Проблема с дообучением seq2seq модели(Tensorflow)
Пытаюсь сделать чат-бота на основе seq2seq модели. Для создания бота решил воспользоваться туториалом от tensorflow(https://www.tensorflow.org/text/tutorials/nmt_with_attention). По-сути просто пока что скопировал код, подставив свой датасет. Суть бота в том, что он одолжен обучаться в реальном времени, то есть каждый желающий может вносить правки в бота(создается новый датасет и модель дообучается).
train_model.fit(new_dataset, epochs=3)
На данный момент модель вовсе не хочет дообучаться и вносить новые изменения, хотя вроде в веса она вносит изменения, но словарный запас у нее почему-то не пополняется..
Я предполагаю, что возможно дело в векторизации текста, а именно где определяется словарный запас:
input_text_processor = tf.keras.layers.TextVectorization(
standardize=tf_lower_and_split_punct,
max_tokens=max_tokens)
input_text_processor.adapt(inputs)
Возможно ли как-то добавить новый словарный запас tf.keras.layers.TextVectorization?
embedding_dim = 1024
units = 1024
dataset = build_dataset(data := (['hi',
'hello',
'howdy'
],
['greetings',
'hello',
'hi'
]))
dataset2 = build_dataset(data2 := (['bye',
'bye-bye',
'bb'
],
['bye',
'bye',
'bye'
]))
input_text_processor = text_processor()
input_text_processor.adapt(data[0])
output_text_processor = text_processor()
output_text_processor.adapt(data[1])
train_model = TrainModel(embedding_dim, units,
input_text_processor=input_text_processor,
output_text_processor=output_text_processor)
train_model.compile(
optimizer='adam',
loss=MaskedLoss()
)
batch_loss = BatchLogs('batch_loss')
train_model.fit(dataset, epochs=3,
callbacks=[batch_loss],
)
# дообучаем на новом датасете
# train_model.input_text_processor.adapt(data2[0]) # это перезапишет старый словарный запас..
# train_model.output_text_processor.adapt(data2[1])
train_model.fit(dataset2, epochs=3,
callbacks=[batch_loss],
)
# тестируем
translator = Translator(
encoder=train_model.encoder,
decoder=train_model.decoder,
input_text_processor=train_model.input_text_processor,
output_text_processor=train_model.output_text_processor,
)
input_text = tf.constant([
'bye',
'hello',
])
result = translator.translate(
input_text=input_text)
# везде выводится фразы из первого датасета(dataset).
print(result['text'][0].numpy().decode()) # hello
print(result['text'][1].numpy().decode()) # hi
Возможно кто-то может посоветовать иные туториалы, примеры c исходниками чат-ботов, основанных на нейронных сетях?