Files
iddaai-be/ai-engine/data/db.py
T
2026-05-24 02:43:10 +03:00

93 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Synchronous psycopg2 database helper for the AI Engine.
Uses a thread-safe connection pool for legacy V20+ endpoints.
"""
from __future__ import annotations

import os
import threading
from contextlib import contextmanager
from typing import Generator

import psycopg2
from psycopg2 import pool
from psycopg2.extensions import connection as PgConnection
from dotenv import load_dotenv
# Load variables from a .env file into the process environment so that
# DATABASE_URL / PGCONNECT_TIMEOUT can be supplied through it.
load_dotenv()
# Local-development fallback, used by get_clean_dsn() when DATABASE_URL is not
# set. NOTE: this is NOT credential-free and does not refuse to start: it embeds
# the stock postgres/postgres login and points at localhost:15432, so an
# unconfigured process will attempt to connect there. Deployments should always
# set DATABASE_URL explicitly.
_DEFAULT_DSN = "postgresql://postgres:postgres@localhost:15432/boilerplate_db"
def get_clean_dsn() -> str:
    """
    Return a psycopg2-compatible DSN built from the DATABASE_URL env var.

    ``_DEFAULT_DSN`` is used when DATABASE_URL is unset, empty, or contains only
    whitespace. The DSN is then adjusted so raw psycopg2/libpq can use it:

    1. URI form (``postgresql://`` / ``postgres://``): Prisma appends
       ``?schema=public``, which libpq rejects as an invalid URI query
       parameter, so every ``schema=`` query argument is dropped while the
       other query arguments are kept in their original order.
    2. If no ``connect_timeout`` is present, one is added (value taken from
       PGCONNECT_TIMEOUT, default ``"5"``) so connection attempts are bounded
       instead of hanging indefinitely. It is appended as a URI query argument
       for URI DSNs, and as a ``key=value`` pair for libpq key/value DSNs
       (``host=... dbname=...``), which have no ``?`` query syntax.

    Returns:
        The cleaned DSN string.
    """
    # ``or`` (rather than a getenv default) also covers DATABASE_URL="" which
    # docker-compose / .env files commonly produce.
    dsn: str = os.getenv("DATABASE_URL", "").strip() or _DEFAULT_DSN
    connect_timeout: str = os.getenv("PGCONNECT_TIMEOUT", "5").strip() or "5"

    if dsn.startswith(("postgresql://", "postgres://")):
        base, _, query = dsn.partition("?")
        params: list[str] = [
            part
            for part in query.split("&")
            if part and not part.startswith("schema=")
        ]
        # Force bounded DB connect attempts so API calls do not hang.
        if not any(part.startswith("connect_timeout=") for part in params):
            params.append(f"connect_timeout={connect_timeout}")
        return f"{base}?{'&'.join(params)}"

    # libpq key/value form: parameters are whitespace-separated ``key=value``
    # pairs, so a ``?connect_timeout=...`` suffix would be glued onto the last
    # value and corrupt it.
    if "connect_timeout" not in dsn:
        dsn = f"{dsn} connect_timeout={connect_timeout}"
    return dsn
class Database:
    """
    Process-wide psycopg2 connection pool exposed through class methods.

    The pool (``ThreadedConnectionPool`` with minconn=1, maxconn=10) is created
    lazily on the first ``get_conn()`` / ``connection()`` call, or eagerly via
    ``initialize()``, and torn down with ``close_all()``. After ``close_all()``
    the class can be initialized again.

    Typical use::

        with Database.connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
    """

    # The shared pool; None until initialize() runs, and again after close_all().
    _pool: pool.ThreadedConnectionPool | None = None
    # Serializes pool creation/teardown so concurrent first callers cannot each
    # build a pool (connections checked out of an overwritten pool would later
    # be returned to the wrong one).
    _pool_lock = threading.Lock()

    @classmethod
    def initialize(cls) -> None:
        """
        Create the shared connection pool if it does not exist yet.

        Idempotent and safe to call from several threads at once.

        Raises:
            Exception: any error raised while building the pool is printed
                and re-raised unchanged.
        """
        if cls._pool is not None:  # fast path once the pool exists
            return
        with cls._pool_lock:
            if cls._pool is not None:  # another thread created it meanwhile
                return
            dsn: str = get_clean_dsn()
            try:
                cls._pool = pool.ThreadedConnectionPool(
                    minconn=1,
                    maxconn=10,
                    dsn=dsn,
                )
            except Exception as e:
                print(f"❌ Failed to create DB pool: {e}")
                raise
            print("✅ Database connection pool created")

    @classmethod
    def get_conn(cls) -> PgConnection:
        """
        Check out a connection, creating the pool first if necessary.

        The caller must hand it back with ``return_conn()``; prefer the
        ``connection()`` context manager, which does that automatically.
        NOTE(review): with maxconn=10, psycopg2's pool is expected to raise
        rather than block when all connections are checked out — confirm
        against the installed psycopg2 version.

        Raises:
            pool.PoolError: if ``close_all()`` ran in another thread between
                pool initialization and checkout.
        """
        cls.initialize()
        active = cls._pool  # local copy: close_all() may reset the attribute
        if active is None:
            raise pool.PoolError("connection pool is closed")
        return active.getconn()

    @classmethod
    def return_conn(cls, conn: PgConnection) -> None:
        """
        Give a connection obtained from ``get_conn()`` back to the pool.

        If there is no pool anymore (e.g. ``close_all()`` already ran), the
        connection is closed instead so it cannot leak.
        """
        active = cls._pool
        if active is not None:
            active.putconn(conn)
        elif not conn.closed:
            conn.close()

    @classmethod
    @contextmanager
    def connection(cls) -> Generator[PgConnection, None, None]:
        """
        Context manager that checks out a connection and always returns it.

        This helper does not commit; call ``conn.commit()`` inside the block
        to persist changes.
        """
        conn: PgConnection = cls.get_conn()
        try:
            yield conn
        finally:
            cls.return_conn(conn)

    @classmethod
    def close_all(cls) -> None:
        """
        Close every pooled connection and forget the pool.

        Resetting ``_pool`` to None lets a later ``get_conn()`` / ``initialize()``
        build a fresh pool instead of reusing the closed one. Calling this when
        no pool exists is a no-op.
        """
        with cls._pool_lock:
            active, cls._pool = cls._pool, None
        if active is not None:
            active.closeall()
            print("️ Database connection pool closed")