93 lines
2.8 KiB
Python
93 lines
2.8 KiB
Python
"""
|
||
Synchronous psycopg2 database helper for the AI Engine.
|
||
Uses a thread-safe connection pool for legacy V20+ endpoints.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import os
import threading
from contextlib import contextmanager
from typing import Generator

import psycopg2
from psycopg2 import pool
from psycopg2.extensions import connection as PgConnection
from dotenv import load_dotenv
|
||
|
||
# Read KEY=VALUE pairs from a .env file (if one is found) into the process
# environment at import time, before get_clean_dsn() looks up DATABASE_URL
# and PGCONNECT_TIMEOUT.
load_dotenv()

# Local-development fallback used by get_clean_dsn() when DATABASE_URL is not
# set. NOTE: this DSN is NOT credential-free -- it targets user/password
# postgres:postgres on localhost:15432 (database "boilerplate_db"), so do not
# rely on it outside local development. If nothing is listening there, the
# connect_timeout that get_clean_dsn() appends makes the attempt fail quickly
# instead of hanging.
_DEFAULT_DSN = "postgresql://postgres:postgres@localhost:15432/boilerplate_db"
|
||
|
||
|
||
def get_clean_dsn() -> str:
    """
    Return a psycopg2-compatible DSN built from DATABASE_URL.

    Falls back to ``_DEFAULT_DSN`` when DATABASE_URL is unset, empty, or only
    whitespace. Two cleanups are applied:

    1. Prisma appends ``?schema=public`` to its URLs, which libpq rejects as an
       unknown connection option. Every ``schema`` query argument is removed;
       all other query arguments are kept in their original order.
    2. A ``connect_timeout`` (seconds, taken from PGCONNECT_TIMEOUT, default
       ``5``) is added unless the DSN already sets one, so connection attempts
       are bounded instead of hanging indefinitely.

    Both URI DSNs (``postgresql://...`` / ``postgres://...``) and libpq
    key/value DSNs (``host=... dbname=...``) are handled. For the key/value
    form the timeout is appended as another whitespace-separated
    ``key=value`` pair.

    Returns:
        The cleaned DSN string.
    """
    # Treat an empty/whitespace-only DATABASE_URL like an unset one; otherwise
    # the result would be a bare "?connect_timeout=..." string.
    dsn: str = (os.getenv("DATABASE_URL") or "").strip() or _DEFAULT_DSN
    connect_timeout: str = os.getenv("PGCONNECT_TIMEOUT", "5").strip() or "5"

    if not dsn.startswith(("postgresql://", "postgres://")):
        # libpq key/value form: options are whitespace-separated, so a
        # URI-style "?connect_timeout=..." suffix would corrupt the last value.
        # A token's key is whatever precedes its first "=", which covers both
        # "connect_timeout=3" and "connect_timeout = 3".
        keys = {token.split("=", 1)[0] for token in dsn.split()}
        if "connect_timeout" not in keys:
            dsn = f"{dsn} connect_timeout={connect_timeout}"
        return dsn

    # URI form: drop Prisma's schema argument(s) while preserving the rest.
    base, _, query = dsn.partition("?")
    kept_parts: list[str] = [
        part
        for part in query.split("&")
        if part and part.split("=", 1)[0] != "schema"
    ]

    # Force bounded DB connect attempts so API calls do not hang indefinitely.
    # Only the query keys are inspected, never the host/credential section.
    if not any(part.split("=", 1)[0] == "connect_timeout" for part in kept_parts):
        kept_parts.append(f"connect_timeout={connect_timeout}")

    return f"{base}?{'&'.join(kept_parts)}"
|
||
|
||
|
||
class Database:
    """Process-wide psycopg2 connection pool exposed through classmethods.

    The pool is created lazily on first use (or by an explicit
    ``initialize()`` call) from ``get_clean_dsn()`` and holds between 1 and
    10 connections. Prefer ``Database.connection()`` over pairing
    ``get_conn()`` / ``return_conn()`` by hand, so connections are always
    handed back.
    """

    # Shared pool; None until initialize() succeeds, and again after close_all().
    _pool: pool.ThreadedConnectionPool | None = None
    # Guards creation and teardown of _pool so that concurrent first callers
    # of get_conn() cannot each build (and leak) a separate pool.
    _lock = threading.Lock()

    @classmethod
    def initialize(cls) -> None:
        """Create the shared pool if it does not exist yet (no-op otherwise).

        Raises:
            Exception: anything raised by ``get_clean_dsn()`` or by the pool
                constructor. Constructor errors are printed before being
                re-raised; ``_pool`` stays None in either case.
        """
        # Fast path: no locking once the pool exists.
        if cls._pool is not None:
            return
        with cls._lock:
            # Re-check under the lock: another thread may have created the
            # pool while this one was waiting to acquire it.
            if cls._pool is not None:
                return
            dsn: str = get_clean_dsn()
            try:
                cls._pool = pool.ThreadedConnectionPool(
                    minconn=1,
                    maxconn=10,
                    dsn=dsn,
                )
                print("✅ Database connection pool created")
            except Exception as e:
                print(f"❌ Failed to create DB pool: {e}")
                raise

    @classmethod
    def get_conn(cls) -> PgConnection:
        """Borrow a connection from the pool, creating the pool on first use.

        The caller must hand the connection back with ``return_conn()``.
        NOTE(review): psycopg2's pool raises ``PoolError`` instead of waiting
        when all ``maxconn`` connections are in use -- confirm callers expect
        that.

        Raises:
            RuntimeError: if the pool was torn down by a concurrent
                ``close_all()`` immediately after being initialised.
        """
        # Read the attribute once so a concurrent close_all() cannot swap it
        # to None between the check and the getconn() call.
        current = cls._pool
        if current is None:
            cls.initialize()
            current = cls._pool
        if current is None:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise RuntimeError("Database connection pool is not available")
        return current.getconn()

    @classmethod
    def return_conn(cls, conn: PgConnection) -> None:
        """Hand a borrowed connection back to the pool.

        Does nothing when no pool exists (for example after ``close_all()``,
        whose ``closeall()`` already closed every connection the pool had
        handed out).
        """
        current = cls._pool
        if current is not None:
            current.putconn(conn)

    @classmethod
    @contextmanager
    def connection(cls) -> Generator[PgConnection, None, None]:
        """Context manager that borrows a connection and always returns it.

        Usage::

            with Database.connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1")

        The connection goes back to the pool even if the ``with`` body
        raises. Committing or rolling back is left to the caller.
        """
        conn: PgConnection = cls.get_conn()
        try:
            yield conn
        finally:
            cls.return_conn(conn)

    @classmethod
    def close_all(cls) -> None:
        """Close every pooled connection and forget the pool.

        Resetting ``_pool`` to None lets a later ``get_conn()`` build a fresh
        pool instead of failing against the closed one. Safe to call when no
        pool exists.
        """
        with cls._lock:
            current = cls._pool
            cls._pool = None
            if current is not None:
                current.closeall()
                print("ℹ️ Database connection pool closed")
|