425 lines
14 KiB
Python
425 lines
14 KiB
Python
"""
|
||
Calibration Training Script (REWRITTEN)
|
||
=======================================
|
||
Trains Isotonic Regression calibration models for football markets
|
||
using REAL model predictions + actual match outcomes.
|
||
|
||
Data sources (combined):
|
||
- `predictions` table: Full bet_summary (many markets per match), joined to `matches` for actual results
|
||
- `prediction_runs` table: main_pick + value_pick predictions with resolved outcomes
|
||
|
||
Per market, fits IsotonicRegression(raw_model_prob → actual_hit) so that
|
||
calibrated_prob mirrors empirical hit rate.
|
||
|
||
Usage:
|
||
python ai-engine/scripts/train_calibration.py
|
||
python ai-engine/scripts/train_calibration.py --min-samples 30
|
||
python ai-engine/scripts/train_calibration.py --markets ms_home ou25 btts
|
||
|
||
Notes:
|
||
* Multi-source data extraction tolerates schema drift in payload JSON.
|
||
* If a market has fewer than --min-samples points, it is skipped
|
||
(orchestrator will fall back to the multiplier from market_thresholds.json).
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import pandas as pd
|
||
import psycopg2
|
||
from dotenv import load_dotenv
|
||
|
||
AI_ENGINE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
sys.path.insert(0, AI_ENGINE_DIR)
|
||
|
||
from models.calibration import get_calibrator # noqa: E402
|
||
|
||
load_dotenv()
|
||
|
||
|
||
# =============================================================================
|
||
# DB
|
||
# =============================================================================
|
||
def get_conn():
    """Open a psycopg2 connection from the ``DATABASE_URL`` environment variable.

    The URL may carry a ``schema`` query parameter (e.g. ``?schema=public``),
    which is removed before connecting. Only that one parameter is dropped:
    any other query parameters are preserved verbatim, and ``schema`` is found
    regardless of its position in the query string.

    Returns:
        The connection object returned by ``psycopg2.connect``.

    Raises:
        ValueError: if ``DATABASE_URL`` is unset or empty.
    """
    from urllib.parse import urlsplit, urlunsplit

    db_url = os.getenv("DATABASE_URL")
    if not db_url:
        raise ValueError("DATABASE_URL not set")

    parts = urlsplit(db_url)
    pairs = parts.query.split("&") if parts.query else []
    kept = [pair for pair in pairs if pair.partition("=")[0] != "schema"]
    # Rebuild the URL only when a schema parameter was actually present, so
    # every other URL is handed to the driver exactly as configured.
    if len(kept) != len(pairs):
        db_url = urlunsplit(parts._replace(query="&".join(kept)))
    return psycopg2.connect(db_url)
|
||
|
||
|
||
# =============================================================================
|
||
# OUTCOME RESOLUTION
|
||
# =============================================================================
|
||
def _normalize_pick(pick: Any) -> str:
    """Return *pick* as a trimmed, case-folded string suitable for comparison.

    ``None`` maps to the empty string. Other falsy values are stringified
    rather than discarded, so an integer ``0`` survives as ``"0"`` (a draw
    code the market resolvers in this file accept).
    """
    if pick is None:
        return ""
    return str(pick).strip().casefold()
|
||
|
||
|
||
def _is_over(pick: str) -> bool:
    """Tell whether the normalized pick contains an "over" marker.

    Recognised substrings are ``over``, ``üst`` and ``ust`` (the latter two
    presumably Turkish spellings with and without the diacritic).
    """
    text = _normalize_pick(pick)
    return any(token in text for token in ("over", "üst", "ust"))
|
||
|
||
|
||
def _is_under(pick: str) -> bool:
    """Tell whether the normalized pick contains an "under" marker.

    Recognised substrings are ``under`` and ``alt`` (presumably the Turkish
    label for the under side).
    """
    text = _normalize_pick(pick)
    return any(token in text for token in ("under", "alt"))
|
||
|
||
|
||
def _is_yes(pick: str) -> bool:
    """Tell whether the normalized pick contains a "yes" marker.

    Recognised substrings are ``yes`` and ``var`` (presumably the Turkish
    label for the yes side).
    """
    text = _normalize_pick(pick)
    return any(token in text for token in ("yes", "var"))
|
||
|
||
|
||
def _score_outcome(home: int, away: int) -> str:
    """Return "1" (home ahead), "2" (away ahead) or "x" (level) for a score pair."""
    if home > away:
        return "1"
    if away > home:
        return "2"
    return "x"


def resolve_actual(
    market: str,
    pick: str,
    score_home: Optional[int],
    score_away: Optional[int],
    ht_home: Optional[int],
    ht_away: Optional[int],
) -> Optional[int]:
    """Return 1 if the (market, pick) hit, 0 if it missed, None if undetermined.

    Args:
        market: Market code; compared case-insensitively. Handled codes are
            MS, DC, OU15/OU25/OU35, BTTS, HT, HT_OU05/HT_OU15, OE and HTFT.
        pick: Selection within the market; normalized via ``_normalize_pick``.
        score_home, score_away: Full-time goals. If either is None the result
            is None for every market.
        ht_home, ht_away: Half-time goals. Half-time based markets (HT,
            HT_OU05, HT_OU15, HTFT) return None unless BOTH are present.

    Returns:
        1 for a hit, 0 for a miss, or None when the market is unknown, the
        pick is not recognised for that market, or the required scores are
        missing.
    """
    if score_home is None or score_away is None:
        return None
    market = (market or "").upper()
    p = _normalize_pick(pick)
    total = score_home + score_away
    # Half-time markets need both half-time scores; a lone value must not be
    # combined with an implicit 0 for the missing side.
    have_ht = ht_home is not None and ht_away is not None

    if market == "MS":
        if p == "1":
            return int(score_home > score_away)
        if p in {"x", "0", "x/0"}:
            return int(score_home == score_away)
        if p == "2":
            return int(score_away > score_home)
        return None

    if market == "DC":
        norm = p.replace("-", "").upper()
        if norm == "1X":
            return int(score_home >= score_away)
        if norm == "X2":
            return int(score_away >= score_home)
        if norm == "12":
            return int(score_home != score_away)
        return None

    if market in {"OU15", "OU25", "OU35"}:
        line = {"OU15": 1.5, "OU25": 2.5, "OU35": 3.5}[market]
        if _is_over(p):
            return int(total > line)
        if _is_under(p):
            return int(total < line)
        return None

    if market == "BTTS":
        both_scored = score_home > 0 and score_away > 0
        if _is_yes(p):
            return int(both_scored)
        if "no" in p or "yok" in p:
            return int(not both_scored)
        return None

    if market == "HT":
        if not have_ht:
            return None
        if p == "1":
            return int(ht_home > ht_away)
        if p in {"x", "0"}:
            return int(ht_home == ht_away)
        if p == "2":
            return int(ht_away > ht_home)
        return None

    if market in {"HT_OU05", "HT_OU15"}:
        if not have_ht:
            return None
        ht_total = ht_home + ht_away
        line = 0.5 if market == "HT_OU05" else 1.5
        if _is_over(p):
            return int(ht_total > line)
        if _is_under(p):
            return int(ht_total < line)
        return None

    if market == "OE":
        if "odd" in p or "tek" in p:
            return int(total % 2 == 1)
        if "even" in p or "çift" in p or "cift" in p:
            return int(total % 2 == 0)
        return None

    if market == "HTFT":
        if not have_ht:
            return None
        parts = [part.strip() for part in p.split("/")]
        # Exactly "<ht>/<ft>" is required; anything else is undetermined
        # instead of raising on tuple unpacking.
        if len(parts) != 2:
            return None
        # Accept "0" as a draw alias, matching the MS and HT markets.
        ht_p, ft_p = ("x" if part == "0" else part for part in parts)
        valid = {"1", "x", "2"}
        if ht_p not in valid or ft_p not in valid:
            # Unrecognised codes cannot be graded; do not record them as misses.
            return None
        return int(
            ht_p == _score_outcome(ht_home, ht_away)
            and ft_p == _score_outcome(score_home, score_away)
        )

    return None
|
||
|
||
|
||
# =============================================================================
|
||
# CALIBRATOR KEY (must mirror orchestrator._calibrator_key)
|
||
# =============================================================================
|
||
def calibrator_key(market: str, pick: str) -> Optional[str]:
    """Map a (market, pick) pair to the name of its calibration model.

    Returns None when the pair has no dedicated calibrator: unknown markets,
    unrecognised MS/HT picks, OU picks that are not "over", and BTTS picks
    that are not "yes". DC and HTFT map to a single key regardless of pick.
    Per the section header, this must stay in sync with the orchestrator's
    ``_calibrator_key``.
    """
    code = (market or "").upper()
    choice = _normalize_pick(pick)

    result_keys = {
        "MS": {"1": "ms_home", "x": "ms_draw", "0": "ms_draw", "2": "ms_away"},
        "HT": {"1": "ht_home", "x": "ht_draw", "0": "ht_draw", "2": "ht_away"},
    }
    over_keys = {"OU15": "ou15", "OU25": "ou25", "OU35": "ou35"}

    if code in result_keys:
        return result_keys[code].get(choice)
    if code == "DC":
        return "dc"
    if code == "HTFT":
        return "ht_ft"
    if code in over_keys:
        # Only the "over" side has a calibrator.
        return over_keys[code] if _is_over(choice) else None
    if code == "BTTS":
        # Only the "yes" side has a calibrator.
        return "btts" if _is_yes(choice) else None
    return None
|
||
|
||
|
||
# =============================================================================
|
||
# DATA EXTRACTION
|
||
# =============================================================================
|
||
def fetch_predictions_with_outcomes(cur) -> List[Dict[str, Any]]:
    """
    Source 1: `predictions` table joined with `matches` (FT only).

    Each entry of ``prediction_json["bet_summary"]`` becomes one training
    sample, provided it carries a numeric ``raw_confidence``, its outcome can
    be resolved from the match scores, and it maps to a calibrator key.
    Malformed rows/entries are skipped instead of aborting the run.

    Args:
        cur: An open DB cursor exposing ``execute`` and ``fetchall``.

    Returns:
        A list of dicts with keys ``source``, ``match_id``, ``market``,
        ``pick``, ``key``, ``raw_prob`` (``raw_confidence`` / 100) and
        ``actual`` (0 or 1).
    """
    cur.execute("""
        SELECT
            p.match_id,
            p.prediction_json,
            m.score_home,
            m.score_away,
            m.ht_score_home,
            m.ht_score_away
        FROM predictions p
        JOIN matches m ON m.id = p.match_id
        WHERE m.sport = 'football'
          AND m.status = 'FT'
          AND m.score_home IS NOT NULL
          AND m.score_away IS NOT NULL
    """)
    rows = cur.fetchall()
    samples: List[Dict[str, Any]] = []
    for match_id, payload, sh, sa, ht_h, ht_a in rows:
        # Only already-decoded JSON objects are usable here.
        if not isinstance(payload, dict):
            continue
        bet_summary = payload.get("bet_summary")
        if not isinstance(bet_summary, list):
            continue
        for item in bet_summary:
            if not isinstance(item, dict):
                continue
            raw_conf = item.get("raw_confidence")
            if raw_conf is None:
                continue
            try:
                raw_prob = float(raw_conf) / 100.0
            except (TypeError, ValueError):
                # Non-numeric confidence (schema drift): skip this entry
                # rather than crash the whole training run.
                continue
            market = str(item.get("market") or "")
            raw_pick = item.get("pick")
            # Keep falsy-but-present picks such as an integer 0 (draw code).
            pick = "" if raw_pick is None else str(raw_pick)
            actual = resolve_actual(market, pick, sh, sa, ht_h, ht_a)
            if actual is None:
                continue
            key = calibrator_key(market, pick)
            if not key:
                continue
            samples.append({
                "source": "predictions",
                "match_id": match_id,
                "market": market,
                "pick": pick,
                "key": key,
                "raw_prob": raw_prob,
                "actual": int(actual),
            })
    return samples
|
||
|
||
|
||
def fetch_prediction_runs_with_outcomes(cur) -> List[Dict[str, Any]]:
    """
    Source 2: `prediction_runs` table with resolved settlement.

    Each ``main_pick`` / ``value_pick`` object in ``payload_summary`` becomes
    one training sample, provided a numeric confidence is available, its
    outcome can be resolved from the match scores, and it maps to a
    calibrator key. Malformed rows/entries are skipped.

    Args:
        cur: An open DB cursor exposing ``execute`` and ``fetchall``.

    Returns:
        A list of dicts with keys ``source`` (``runs.main_pick`` or
        ``runs.value_pick``), ``match_id``, ``market``, ``pick``, ``key``,
        ``raw_prob`` and ``actual`` (0 or 1).
    """
    # The sport filter mirrors the `predictions` query so that both sources
    # feed football matches only.
    cur.execute("""
        SELECT
            pr.match_id,
            pr.payload_summary,
            m.score_home,
            m.score_away,
            m.ht_score_home,
            m.ht_score_away
        FROM prediction_runs pr
        JOIN matches m ON m.id = pr.match_id
        WHERE pr.eventual_outcome IS NOT NULL
          AND m.sport = 'football'
          AND m.score_home IS NOT NULL
          AND m.score_away IS NOT NULL
    """)
    rows = cur.fetchall()
    samples: List[Dict[str, Any]] = []
    for match_id, payload, sh, sa, ht_h, ht_a in rows:
        if not isinstance(payload, dict):
            continue
        for source_key in ("main_pick", "value_pick"):
            item = payload.get(source_key)
            if not isinstance(item, dict):
                continue
            # Prefer raw_confidence (0-100 scale); fall back to
            # calibrated_probability / probability (treated as 0-1) if raw
            # is missing.
            raw_conf = item.get("raw_confidence")
            try:
                if raw_conf is None:
                    cal_prob = item.get("calibrated_probability") or item.get("probability")
                    if cal_prob is None:
                        continue
                    raw_conf = float(cal_prob) * 100.0
                raw_prob = float(raw_conf) / 100.0
            except (TypeError, ValueError):
                # Non-numeric confidence (schema drift): skip this pick
                # rather than crash the whole training run.
                continue
            market = str(item.get("market") or "")
            raw_pick = item.get("pick")
            # Keep falsy-but-present picks such as an integer 0 (draw code).
            pick = "" if raw_pick is None else str(raw_pick)
            actual = resolve_actual(market, pick, sh, sa, ht_h, ht_a)
            if actual is None:
                continue
            key = calibrator_key(market, pick)
            if not key:
                continue
            samples.append({
                "source": f"runs.{source_key}",
                "match_id": match_id,
                "market": market,
                "pick": pick,
                "key": key,
                "raw_prob": raw_prob,
                "actual": int(actual),
            })
    return samples
|
||
|
||
|
||
# =============================================================================
|
||
# TRAINING
|
||
# =============================================================================
|
||
def train_per_key(
    df: pd.DataFrame,
    min_samples: int,
    markets_filter: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Fit and save one calibration model per calibrator key found in *df*.

    Args:
        df: Samples with at least the columns ``key``, ``match_id``,
            ``raw_prob`` and ``actual``.
        min_samples: Keys with fewer usable rows than this are reported as
            skipped and not trained.
        markets_filter: Optional list of calibrator keys; when given, keys
            outside it are ignored entirely (no entry in the result).

    Returns:
        Dict keyed by calibrator key. Skipped keys carry ``status``,
        ``samples`` and ``reason``; trained keys carry ``status``,
        ``samples``, ``brier``, ``ece``, ``mean_predicted`` and
        ``mean_actual`` taken from the metrics returned by the calibrator.
    """
    calibrator = get_calibrator()
    report: Dict[str, Any] = {}

    for key in sorted(df["key"].unique()):
        if markets_filter and key not in markets_filter:
            continue

        # One row per (match, key) so the same prediction present in both
        # sources is not counted twice; then drop rows with missing values.
        # NOTE(review): for keys shared by several picks (e.g. "dc", "ht_ft")
        # this also keeps only the first pick per match - confirm intended.
        subset = (
            df.loc[df["key"] == key]
            .drop_duplicates(subset=["match_id", "key"], keep="first")
            .dropna(subset=["raw_prob", "actual"])
        )
        # Keep only probabilities strictly inside (0, 1); rows at or beyond
        # the bounds are discarded (not clamped).
        inside_unit = subset["raw_prob"].gt(0.0) & subset["raw_prob"].lt(1.0)
        subset = subset[inside_unit]

        count = len(subset)
        if count < min_samples:
            report[key] = {
                "status": "skipped",
                "samples": count,
                "reason": f"need ≥{min_samples}, have {count}",
            }
            continue

        metrics = calibrator.train_calibration(
            df=subset,
            market=key,
            prob_col="raw_prob",
            actual_col="actual",
            min_samples=min_samples,
            save=True,
        )
        report[key] = {
            "status": "trained",
            "samples": metrics.sample_count,
            "brier": round(metrics.brier_score, 4),
            "ece": round(metrics.calibration_error, 4),
            "mean_predicted": round(metrics.mean_predicted, 4),
            "mean_actual": round(metrics.mean_actual, 4),
        }

    return report
|
||
|
||
|
||
def print_report(results: Dict[str, Any], total_samples: int) -> None:
    """Print a fixed-width summary table of the per-key training results.

    Args:
        results: Mapping of calibrator key to an info dict. Entries whose
            ``status`` is ``"trained"`` must provide ``samples``, ``brier``,
            ``ece``, ``mean_predicted`` and ``mean_actual``; any other status
            is rendered as skipped using ``samples`` and optional ``reason``.
        total_samples: Sample count shown in the report header.
    """
    print("\n" + "=" * 78)
    print("CALIBRATION TRAINING REPORT")
    print("=" * 78)
    print(f"Total samples across all markets: {total_samples}")
    print(f"\n{'market':<14} {'status':<10} {'n':<6} {'brier':<9} {'ece':<8} {'pred_avg':<9} {'actual_avg':<10}")
    print("-" * 78)
    for key, info in sorted(results.items()):
        if info["status"] == "trained":
            # Field widths match the header columns (9 for pred_avg, 10 for
            # actual_avg) so values line up under their titles.
            print(
                f"{key:<14} {'✓ ok':<10} {info['samples']:<6} "
                f"{info['brier']:<9.4f} {info['ece']:<8.4f} "
                f"{info['mean_predicted']:<9.3f} {info['mean_actual']:<10.3f}"
            )
        else:
            print(f"{key:<14} {'⊘ skip':<10} {info['samples']:<6} -- {info.get('reason', '')}")
    print("=" * 78)
    print("Trained models saved to: ai-engine/models/calibration/")
    print("Skipped markets fall back to the multiplier in market_thresholds.json.")
    print("=" * 78)
|
||
|
||
|
||
# =============================================================================
|
||
# CLI
|
||
# =============================================================================
|
||
def main():
    """CLI entry point: gather samples from both sources, train, and report.

    Parses ``--min-samples`` and ``--markets``, opens a DB connection, pulls
    samples from the `predictions` and `prediction_runs` sources, prints data
    statistics, trains per-key calibrators and prints the report. The cursor
    and connection are closed on every exit path once they were opened.
    """
    cli = argparse.ArgumentParser(description="Train isotonic calibration on real data")
    cli.add_argument(
        "--min-samples",
        type=int,
        default=30,
        help="Minimum samples required per market (default: 30)",
    )
    cli.add_argument(
        "--markets",
        nargs="+",
        default=None,
        help="Limit to specific calibrator keys (e.g., ms_home ou25)",
    )
    options = cli.parse_args()

    conn = get_conn()
    cur = conn.cursor()
    try:
        from_predictions = fetch_predictions_with_outcomes(cur)
        from_runs = fetch_prediction_runs_with_outcomes(cur)
        print(f"[Data] predictions table: {len(from_predictions)} samples")
        print(f"[Data] prediction_runs: {len(from_runs)} samples")

        combined = from_predictions + from_runs
        if not combined:
            # Nothing to train on; cleanup still runs via the finally clause.
            print("[Error] No training samples available")
            return

        df = pd.DataFrame(combined)
        print(f"[Data] Combined: {len(df)} samples")
        print(f"[Data] Unique matches: {df['match_id'].nunique()}")
        print("[Data] Per-key counts:")
        for key, count in df["key"].value_counts().items():
            print(f" {key:<14} {count}")

        results = train_per_key(df, options.min_samples, options.markets)
        print_report(results, total_samples=len(df))
    finally:
        cur.close()
        conn.close()
|
||
|
||
|
||
# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|