Как использовать модель без модуля GAT?
Есть вот такая модель (ResNet18 + GAT):
class AntispoofModel(nn.Module):
def __init__(self, device="cpu", **kwargs):
super().__init__()
resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
self.resnet = nn.Sequential(*[i for i in list(resnet.children())[:-2]]).to(device)
for ch in self.resnet.children():
for param in ch.parameters():
param.requires_grad = False
self.gat = GAT(**kwargs).to(device)
self.device = device
self.adj = torch.tensor(grid_to_graph(7, 7, return_as=np.ndarray)).to(device)
def forward(self, x):
x = self.resnet(x.to(self.device))
x = x.view(-1, 49, 512)
#adj = torch.stack([self.adj for i in range(x.shape[0])]).to(self.device)
x = self.gat(x, self.adj)
return torch.sigmoid(x)
Пытаюсь сделать то же самое, но без GAT, например:
class AntispoofModel(nn.Module):
def __init__(self, device="cpu", **kwargs):
super().__init__()
resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
self.resnet = nn.Sequential(*[i for i in list(resnet.children())[:-2]]).to(device)
for ch in self.resnet.children():
for param in ch.parameters():
param.requires_grad = False
self.device = device
def forward(self, x):
x = self.resnet(x.to(self.device))
x = x.view(-1, 49, 512)
return torch.sigmoid(x)
Но тогда получаю ошибку Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([64, 49, 512])) is deprecated. Please ensure they have the same size.
Как можно правильно убрать модуль GAT?