Как подобрать слои для языковой модели на PyTorch?
Всех приветствую, впервые использую StackOverflow, так как столкнулся с одной проблемой. Я создаю нейросеть, которая основываясь на написанном в сообщении будет угадывать слова, которыми можно ответить. Для обучения я взял переписку из Telegram в формате Json. Всю её перевёл в лист из строк и сделал кастомный датасет, в котором каждый ввод представляет собой тензор размерностью с словарь из всех слов в переписке, чтобы каждый индекс отвечал за частоту употребления отдельно взятого слова. Я хочу, чтобы я давал на входе такой Bag of words в формате int и получал точно такой же на выходе, чтобы перевести обратно в слова. Однако я вроде разобрался, модель даже запускается, всё вроде работает, но что после обучения, что до она подбирает веса и выдаёт один и тот же вывод, вне зависимости от ввода. Ранее пытался использовать эмбеддинговые слои, но нейронка вообще превращала тензор в кашу. Можете мне подсказать какие лучше использовать слои и как организовать цикл обучения ?
class Jester(nn.Module):
def __init__(self, input_tensor_size, vocab_size, embedding_dim,hidden_dim ):
super(Jester, self).__init__()
self.embedding = nn.Embedding(vocab_size,embedding_dim)
self.lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_dim,num_layers=3,)
self.out_linear = nn.Linear(hidden_dim,input_tensor_size)
#self.output_layer = nn.Softmax(dim=1)
def forward(self, x):
print(x)
out_embedded = self.embedding(x)
print (out_embedded.shape)
out,hidden = self.lstm(out_embedded)
print(out.shape)
out = self.out_linear(out)
out = torch.round(out)
print(out.shape)
return out , hidden
def train(dataloader, model, optimizer,criterion):
print("[System] Start training loops")
model.train()
running_loss = 0.0
for epoch in range(1):
print(f"[Training] Epoch - {epoch}")
for x,y in tqdm(dataloader):
optimizer.zero_grad()
prediction = model(x)
loss = criterion(prediction,y.float())
loss.backward()
optimizer.step()
running_loss += loss.item()
#print(prediction)
print(f"[Training] Epoch - {epoch} || Loss is {running_loss/len(dataloader)}")