AttributeError: 'numpy.ndarray' object has no attribute 'eval'

У меня есть файл, содержащий модель обученной нейронной сети и пример кода, который должен эту модель сохранять/загружать.

PATH ="/content/drive/MyDrive/nn_x1x2.pytorch_model_file"

if False:
    print( f'Сохранение в файл "{PATH}" ')
    torch.save(nn_x1x2, PATH)
else:
    print( f'Загрузка из файла "{PATH}" ')
    nn_x1x2 = torch.load(PATH)
    nn_x1x2.eval()

Но при попытке запустить его, я получаю ошибку:

AttributeError: 'numpy.ndarray' object has no attribute 'eval'

Я пока только начал изучать анализ данных на python. Подскажите пожалуйста, что здесь не так? И как это можно исправить?


Ответы (1 шт):

Автор решения: Alexey

torch.load() скорее всего возвращает np.array.

Чтобы подгрузить модель в pytorch используйте это руководство

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
→ Ссылка