Как использовать модель без модуля GAT?

Есть вот такая модель (ResNet18 + GAT):


class AntispoofModel(nn.Module):
    def __init__(self, device="cpu", **kwargs):
        super().__init__()
        resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
        self.resnet = nn.Sequential(*[i for i in list(resnet.children())[:-2]]).to(device)
        for ch in self.resnet.children():
            for param in ch.parameters():
                param.requires_grad = False
        self.gat = GAT(**kwargs).to(device)
        self.device = device
        self.adj = torch.tensor(grid_to_graph(7, 7, return_as=np.ndarray)).to(device)
        
    def forward(self, x):
        x = self.resnet(x.to(self.device))
        x = x.view(-1, 49, 512)
        #adj = torch.stack([self.adj for i in range(x.shape[0])]).to(self.device)
        x = self.gat(x, self.adj)
        return torch.sigmoid(x)

Пытаюсь сделать то же самое, но без GAT, например:

class AntispoofModel(nn.Module):
    def __init__(self, device="cpu", **kwargs):
        super().__init__()
        resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
        self.resnet = nn.Sequential(*[i for i in list(resnet.children())[:-2]]).to(device)
        for ch in self.resnet.children():
            for param in ch.parameters():
                param.requires_grad = False
        self.device = device
        
    def forward(self, x):
        x = self.resnet(x.to(self.device))
        x = x.view(-1, 49, 512)
        return torch.sigmoid(x)

Но тогда получаю ошибку Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([64, 49, 512])) is deprecated. Please ensure they have the same size.

Как можно правильно убрать модуль GAT?


Ответы (0 шт):