Files
iddaai-be/ai-engine/data/database.py
T
2026-05-24 02:43:10 +03:00

98 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Async Database Module — V2 Betting Engine
==========================================
Provides async SQLAlchemy sessions (asyncpg driver) for the V2 router.

Usage::

    async with get_session() as session:
        result = await session.execute(text("SELECT ..."))
"""
from __future__ import annotations
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from dotenv import load_dotenv
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
# Load variables from a .env file into the process environment at import
# time, so DATABASE_URL can be read by _get_async_dsn().
load_dotenv()
# Lazily-initialised module singletons: populated by _ensure_engine() on first
# use and reset to None by dispose_engine().
_engine: AsyncEngine | None = None
_session_maker: async_sessionmaker[AsyncSession] | None = None
def _get_async_dsn() -> str:
"""
Convert DATABASE_URL to asyncpg-compatible format.
Handles:
1. Prisma's ``?schema=public`` suffix → stripped
2. ``postgresql://`` driver prefix → ``postgresql+asyncpg://``
"""
dsn = os.getenv(
"DATABASE_URL",
"postgresql://suggestbet:SuGGesT2026SecuRe@localhost:15432/boilerplate_db",
)
# Strip Prisma's ?schema= parameter
if "?" in dsn:
base, query = dsn.split("?", 1)
kept_parts = [
part for part in query.split("&") if part and not part.startswith("schema=")
]
dsn = base if not kept_parts else f"{base}?{'&'.join(kept_parts)}"
# Convert driver prefix for asyncpg
if dsn.startswith("postgresql://"):
dsn = dsn.replace("postgresql://", "postgresql+asyncpg://", 1)
elif dsn.startswith("postgres://"):
dsn = dsn.replace("postgres://", "postgresql+asyncpg://", 1)
return dsn
def _ensure_engine() -> AsyncEngine:
    """
    Return the process-wide async engine, building it on the first call.

    On the first call this does the following:
      * creates the engine from ``_get_async_dsn()``;
      * creates the matching ``async_sessionmaker`` with
        ``expire_on_commit=False``;
      * stores both in the module globals;
      * prints a one-line notice.

    Later calls return the cached engine untouched. ``dispose_engine()``
    resets the globals, so the next call builds a fresh engine.
    """
    global _engine, _session_maker
    # Guard clause: reuse the engine if it already exists.
    if _engine is not None:
        return _engine

    # This function never awaits. On a single event loop, the check above and
    # the assignments below therefore cannot interleave with another coroutine.
    new_engine = create_async_engine(
        _get_async_dsn(),
        echo=False,
        pool_pre_ping=True,
        pool_size=5,
        max_overflow=5,
        pool_timeout=10,
    )
    _engine = new_engine
    _session_maker = async_sessionmaker(
        bind=new_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    print("✅ Async database engine created (asyncpg)")
    return new_engine
@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Async context manager that yields a session from the shared factory.

    It ensures the engine and session factory exist, then opens a session
    with the factory's own ``async with``. That context closes the session
    when the caller's block exits.
    """
    _ensure_engine()
    factory = _session_maker
    # Narrows the Optional for type checkers; _ensure_engine() has just set it.
    assert factory is not None
    async with factory() as session:
        yield session
async def dispose_engine() -> None:
    """
    Shut down the shared async engine, if one has been created.

    The module globals are cleared *before* awaiting ``dispose()``. This
    gives two guarantees:
      * a ``get_session()`` that runs while disposal is awaiting creates a
        fresh engine, rather than reusing the one being shut down;
      * the module is reset even if ``dispose()`` raises.

    Calling this when no engine exists is a no-op.
    """
    global _engine, _session_maker
    engine = _engine
    if engine is None:
        return
    # Detach first, so nothing new can pick up the engine being disposed.
    _engine = None
    _session_maker = None
    await engine.dispose()
    print("️ Async database engine disposed")