Проблема с дообучением seq2seq модели(Tensorflow)

Пытаюсь сделать чат-бота на основе seq2seq модели. Для создания бота решил воспользоваться туториалом от tensorflow(https://www.tensorflow.org/text/tutorials/nmt_with_attention). По-сути просто пока что скопировал код, подставив свой датасет. Суть бота в том, что он одолжен обучаться в реальном времени, то есть каждый желающий может вносить правки в бота(создается новый датасет и модель дообучается).

train_model.fit(new_dataset, epochs=3)

На данный момент модель вовсе не хочет дообучаться и вносить новые изменения, хотя вроде в веса она вносит изменения, но словарный запас у нее почему-то не пополняется..

Я предполагаю, что возможно дело в векторизации текста, а именно где определяется словарный запас:

input_text_processor = tf.keras.layers.TextVectorization(
          standardize=tf_lower_and_split_punct,
          max_tokens=max_tokens)
input_text_processor.adapt(inputs)

Возможно ли как-то добавить новый словарный запас tf.keras.layers.TextVectorization?

фулкод(https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/nmt_with_attention.ipynb):

embedding_dim = 1024
units = 1024

dataset = build_dataset(data := (['hi',
                                  'hello',
                                  'howdy'
                                  ],
                                 ['greetings',
                                  'hello',
                                  'hi'
                                  ]))

dataset2 = build_dataset(data2 := (['bye',
                                    'bye-bye',
                                    'bb'
                                    ],
                                   ['bye',
                                    'bye',
                                    'bye'
                                    ]))

input_text_processor = text_processor()
input_text_processor.adapt(data[0])

output_text_processor = text_processor()
output_text_processor.adapt(data[1])

train_model = TrainModel(embedding_dim, units,
                         input_text_processor=input_text_processor,
                         output_text_processor=output_text_processor)

train_model.compile(
    optimizer='adam',
    loss=MaskedLoss()
)
batch_loss = BatchLogs('batch_loss')

train_model.fit(dataset, epochs=3,
                callbacks=[batch_loss],
                )

# дообучаем на новом датасете
# train_model.input_text_processor.adapt(data2[0]) # это перезапишет старый словарный запас..
# train_model.output_text_processor.adapt(data2[1])
train_model.fit(dataset2, epochs=3,
                callbacks=[batch_loss],
                )

# тестируем
translator = Translator(
    encoder=train_model.encoder,
    decoder=train_model.decoder,
    input_text_processor=train_model.input_text_processor,
    output_text_processor=train_model.output_text_processor,
)

input_text = tf.constant([
    'bye',
    'hello',
])

result = translator.translate(
    input_text=input_text)

# везде выводится фразы из первого датасета(dataset).
print(result['text'][0].numpy().decode())  # hello
print(result['text'][1].numpy().decode())  # hi

Возможно кто-то может посоветовать иные туториалы, примеры c исходниками чат-ботов, основанных на нейронных сетях?


Ответы (0 шт):