При загрузке модели keras ошибка AttributeError: 'NoneType' object has no attribute 'get'
Язык: python Фреймворк: tensorflow/keras
При попытке загрузить ранее обученную и сохраненную модель, возникает вот такая ошибка:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-6-aa7422fa3072> in <cell line: 1>()
----> 1 keras.models.load_model('/content/drive/MyDrive/Colab Notebooks/test111_stand')
2 frames
/usr/local/lib/python3.10/dist-packages/keras/src/saving/legacy/serialization.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
533 obj = object_registration._GLOBAL_CUSTOM_OBJECTS[object_name]
534 else:
--> 535 obj = module_objects.get(object_name)
536 if obj is None:
537 raise ValueError(
AttributeError: 'NoneType' object has no attribute 'get'
Архитектура модели:
model = tf.keras.Sequential([keras.layers.Input(shape=(1), dtype=tf.string),
vectorize_layer,
keras.layers.Embedding(input_dim=177475, output_dim=18, input_length=max_word),
keras.layers.Dropout(0.2),
keras.layers.GlobalAveragePooling1D(),
keras.layers.Dropout(0.2),
keras.layers.Dense(3, activation='sigmoid')])
Определение vectorize_layer:
vectorize_layer = keras.layers.TextVectorization(
standardize=standartize_str,
max_tokens=177474,
split="character",
output_mode='int',
output_sequence_length=max_word)
vectorize_layer.adapt(df['data'])
Определение пользовательской функции для стандартизации:
def standartize_str(s):
s = tf.strings.regex_replace(s, " ","")
s = tf.strings.lower(s)
return s
Насколько я понял, проблема в использовании пользовательской функции стандартизации standartize_str. Судя по всему при сохранении модели она не серилизуется. Если обойтись без этой функции и воспользоваться стандартной функцией стандартизации, которую предоставляет слой TextVectorization, то ранее сохраненная модель загружается корректно.
Я пробовал добавлять декоратор, к функции стандартизации, но это не помогло:
@keras.saving.register_keras_serializable(package='custom_package', name='standartize_str')
def standartize_str(s):
s = tf.strings.regex_replace(s, " ","")
s = tf.strings.lower(s)
return s
Подскажите, пожалуйста, как корректно применять пользовательскую функцию стандартизации в слое векторизации TextVectorization, чтобы после сохранения модели, в дальнейшем модель загружалась без ошибок?