Ошибка во время деплоя модели градиентного спуска для задачи классификации текста
Для задачи классификации текста использовал модель SGD(ноутбук). Начал деплоить. Но возникла не понятная ошибка. Шаги:
- Сериализация модели
joblib.dump(SGDClassifier, 'spam_sgd_model.pkl')
joblib.dump(vectorizer, 'tfidf_vectorizer_model.pkl')
joblib.dump(sc, 'standard_scaler_model.pkl')
- Подключение FastAPI
from fastapi import FastAPI
app = FastAPI()
- Построение Deployment API
import joblib
from pydantic import BaseModel
model_load = joblib.load('spam_sgd_model.pkl')
vectorizer_load = joblib.load('tfidf_vectorizer_model.pkl')
scaler_load = joblib.load('standard_scaler_model.pkl')
class Email(BaseModel):
text: str
@app.post("/predict/")
async def predict(email: Email):
text_vectorized = vectorizer_load.transform([email.text])
text_scaled = scaler_load.transform(text_vectorized)
prediction = model_load.predict(text_scaled)
return {"prediction": int(prediction[0])}
- Хостинг
nest_asyncio.apply()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8080)
Далее решил потестить работу программы и запустил в отдельном поле код:
import requests
sample_text = "viiiiiiagraaaa only for the ones that want to make her scream. prodigy scrawny crow define upgrade. caant do..."
url = "http://127.0.0.1:8080/predict/"
try:
response = requests.post(url, json={"text": sample_text})
response.raise_for_status() # Raises an error for bad status codes
print(response.json())
except requests.exceptions.HTTPError as http_err:
print(f"HTTP error occurred: {http_err}")
except requests.exceptions.RequestException as req_err:
print(f"Request error occurred: {req_err}")
except ValueError as json_err:
print(f"JSON decode error: {json_err}")
print("Response content:", response.text)
В итоге возникла неожиданная ошибка(полный текст ошибки в ноутбуке по ссылке сверху):
File "C:\Users\Murad Mammadzade\AppData\Local\Temp\ipykernel_12564\220705840.py", line 8, in predict
prediction = model_load.predict(text_scaled)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: LinearClassifierMixin.predict() missing 1 required positional argument: 'X'
Попробовал уже разные методы + ChatGPT, но тщетно.
Ответы (1 шт):
Автор решения: Stan
→ Ссылка
Тут ошибка в сериализации модели 'spam_sgd_model.pkl'
. Вместо того чтобы вставить уже обученную модель под названием model
, вставлена необученная модель градиентного спуска SGDClassifier
, в этом и вся ошибка.