Pytest+FastAPI+SQLAlchemy+Postgres: ошибка InterfaceError
Разрабатывая апп на FastAPI, SQLAlchemy и PostgreSQL встретил ошибки при организации тестирования (однако с SQLite всё работает). Сделал репо с минимальной версией аппа и тестированием через Pytest на Docker Compose.
Основная ошибка – sqlalchemy.exc.InterfaceError('cannot perform operation: another operation is in progress'). Подозреваю, что оно связано с иницилизацией приложения/БД, хотя я проверил, что всё происходит последовательно. Также пробовал использовать единственный экземпляр TestClient для всех тест-кейсов, но это не помогло. Надеюсь найти решение и грамотный способ такого тестирования ?
Самые важные куски кода:
app.py:
app = FastAPI()
some_items = dict()
@app.on_event("startup")
async def startup():
await create_database()
# Extract some data from env, local files, or S3
some_items["pi"] = 3.1415926535
some_items["eu"] = 2.7182818284
@app.post("/{name}")
async def create_obj(name: str, request: Request):
data = await request.json()
if data.get("code") in some_items:
data["value"] = some_items[data["code"]]
async with async_session() as session:
async with session.begin():
await create_object(session, name, data)
return JSONResponse(status_code=200, content=data)
else:
return JSONResponse(status_code=404, content={})
@app.get("/{name}")
async def get_connected_register(name: str):
async with async_session() as session:
async with session.begin():
objects = await get_objects(session, name)
result = []
for obj in objects:
result.append({
"id": obj.id, "name": obj.name, **obj.data,
})
return result
tests.py:
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="module")
@pytest.mark.asyncio
async def get_db():
await delete_database()
await create_database()
@pytest.mark.parametrize("test_case", test_cases_post)
def test_post(get_db, test_case):
with TestClient(app)() as client:
response = client.post(f"/{test_case['name']}", json=test_case["data"])
assert response.status_code == test_case["res"]
@pytest.mark.parametrize("test_case", test_cases_get)
def test_get(get_db, test_case):
with TestClient(app)() as client:
response = client.get(f"/{test_case['name']}")
assert len(response.json()) == test_case["count"]
db.py:
DATABASE_URL = environ.get("DATABASE_URL", "sqlite+aiosqlite:///./test.db")
engine = create_async_engine(DATABASE_URL, future=True, echo=True)
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
Base = declarative_base()
async def delete_database():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
async def create_database():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
class Model(Base):
__tablename__ = "smth"
id = Column(Integer, primary_key=True)
name = Column(String, nullable=False)
data = Column(JSON, nullable=False)
idx_main = Index("name", "id")
async def create_object(db: Session, name: str, data: dict):
connection = Model(name=name, data=data)
db.add(connection)
await db.flush()
async def get_objects(db: Session, name: str):
raw_q = select(Model) \
.where(Model.name == name) \
.order_by(Model.id)
q = await db.execute(raw_q)
return q.scalars().all()