AttributeError: 'numpy.ndarray' object has no attribute 'eval'
У меня есть файл, содержащий модель обученной нейронной сети и пример кода, который должен эту модель сохранять/загружать.
PATH ="/content/drive/MyDrive/nn_x1x2.pytorch_model_file"
if False:
print( f'Сохранение в файл "{PATH}" ')
torch.save(nn_x1x2, PATH)
else:
print( f'Загрузка из файла "{PATH}" ')
nn_x1x2 = torch.load(PATH)
nn_x1x2.eval()
Но при попытке запустить его, я получаю ошибку:
AttributeError: 'numpy.ndarray' object has no attribute 'eval'
Я пока только начал изучать анализ данных на python. Подскажите пожалуйста, что здесь не так? И как это можно исправить?
Ответы (1 шт):
Автор решения: Alexey
→ Ссылка
torch.load() скорее всего возвращает np.array.
Чтобы подгрузить модель в pytorch используйте это руководство
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()