This commit is contained in:
@@ -1,63 +1,48 @@
|
||||
"""
|
||||
Calibration Training Script
|
||||
===========================
|
||||
Trains Isotonic Regression calibration models for all betting markets.
|
||||
Calibration Training Script (REWRITTEN)
|
||||
=======================================
|
||||
Trains Isotonic Regression calibration models for football markets
|
||||
using REAL model predictions + actual match outcomes.
|
||||
|
||||
This script:
|
||||
1. Fetches historical match data with predictions and actual results
|
||||
2. Trains Isotonic Regression models for each market
|
||||
3. Calculates calibration metrics (Brier Score, ECE)
|
||||
4. Saves models to ai-engine/models/calibration/
|
||||
Data sources (combined):
|
||||
- `predictions` table: Full bet_summary (many markets per match), joined to `matches` for actual results
|
||||
- `prediction_runs` table: main_pick + value_pick predictions with resolved outcomes
|
||||
|
||||
Per market, fits IsotonicRegression(raw_model_prob → actual_hit) so that
|
||||
calibrated_prob mirrors empirical hit rate.
|
||||
|
||||
Usage:
|
||||
# Train on last 90 days of data
|
||||
python3 ai-engine/scripts/train_calibration.py
|
||||
|
||||
# Train on specific date range
|
||||
python3 ai-engine/scripts/train_calibration.py --start 2026-01-01 --end 2026-02-15
|
||||
|
||||
# Train only specific markets
|
||||
python3 ai-engine/scripts/train_calibration.py --markets ou25 btts ms_home
|
||||
python ai-engine/scripts/train_calibration.py
|
||||
python ai-engine/scripts/train_calibration.py --min-samples 30
|
||||
python ai-engine/scripts/train_calibration.py --markets ms_home ou25 btts
|
||||
|
||||
Notes:
|
||||
* Multi-source data extraction tolerates schema drift in payload JSON.
|
||||
* If a market has fewer than --min-samples points, it is skipped
|
||||
(orchestrator will fall back to the multiplier from market_thresholds.json).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import psycopg2
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
from typing import Dict, List, Tuple, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
import psycopg2
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Setup path for ai-engine imports
|
||||
AI_ENGINE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, AI_ENGINE_DIR)
|
||||
|
||||
from models.calibration import get_calibrator, SUPPORTED_MARKETS
|
||||
from models.calibration import get_calibrator # noqa: E402
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CONFIG
|
||||
# =============================================================================
|
||||
TOP_LEAGUES_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(AI_ENGINE_DIR)),
|
||||
"top_leagues.json"
|
||||
)
|
||||
|
||||
# Default: last 90 days
|
||||
DEFAULT_START_DATE = (datetime.utcnow() - timedelta(days=90)).strftime("%Y-%m-%d")
|
||||
DEFAULT_END_DATE = (datetime.utcnow() - timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DB CONNECTION
|
||||
# DB
|
||||
# =============================================================================
|
||||
def get_conn():
|
||||
"""Get PostgreSQL connection."""
|
||||
db_url = os.getenv("DATABASE_URL")
|
||||
if not db_url:
|
||||
raise ValueError("DATABASE_URL not set")
|
||||
@@ -66,354 +51,370 @@ def get_conn():
|
||||
return psycopg2.connect(db_url)
|
||||
|
||||
|
||||
def load_top_league_ids() -> List[str]:
|
||||
"""Load top league IDs from JSON file."""
|
||||
if not os.path.exists(TOP_LEAGUES_PATH):
|
||||
print(f"[Warning] top_leagues.json not found at {TOP_LEAGUES_PATH}")
|
||||
return []
|
||||
|
||||
with open(TOP_LEAGUES_PATH, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Handle both list and dict formats
|
||||
if isinstance(data, dict):
|
||||
return data.get("football", [])
|
||||
return data
|
||||
# =============================================================================
|
||||
# OUTCOME RESOLUTION
|
||||
# =============================================================================
|
||||
def _normalize_pick(pick: Any) -> str:
|
||||
return str(pick or "").strip().casefold()
|
||||
|
||||
|
||||
def _is_over(pick: str) -> bool:
|
||||
norm = _normalize_pick(pick)
|
||||
return "over" in norm or "üst" in norm or "ust" in norm
|
||||
|
||||
|
||||
def _is_under(pick: str) -> bool:
|
||||
norm = _normalize_pick(pick)
|
||||
return "under" in norm or "alt" in norm
|
||||
|
||||
|
||||
def _is_yes(pick: str) -> bool:
|
||||
norm = _normalize_pick(pick)
|
||||
return "yes" in norm or "var" in norm
|
||||
|
||||
|
||||
def resolve_actual(
|
||||
market: str,
|
||||
pick: str,
|
||||
score_home: Optional[int],
|
||||
score_away: Optional[int],
|
||||
ht_home: Optional[int],
|
||||
ht_away: Optional[int],
|
||||
) -> Optional[int]:
|
||||
"""Return 1 if the (market, pick) hit, 0 if it missed, None if undetermined."""
|
||||
if score_home is None or score_away is None:
|
||||
return None
|
||||
market = (market or "").upper()
|
||||
p = _normalize_pick(pick)
|
||||
total = score_home + score_away
|
||||
ht_total = (ht_home or 0) + (ht_away or 0) if ht_home is not None else None
|
||||
|
||||
if market == "MS":
|
||||
if p == "1":
|
||||
return int(score_home > score_away)
|
||||
if p in {"x", "0", "x/0"}:
|
||||
return int(score_home == score_away)
|
||||
if p == "2":
|
||||
return int(score_away > score_home)
|
||||
return None
|
||||
|
||||
if market == "DC":
|
||||
norm = p.replace("-", "").upper()
|
||||
if norm == "1X":
|
||||
return int(score_home >= score_away)
|
||||
if norm == "X2":
|
||||
return int(score_away >= score_home)
|
||||
if norm == "12":
|
||||
return int(score_home != score_away)
|
||||
return None
|
||||
|
||||
if market in {"OU15", "OU25", "OU35"}:
|
||||
line = {"OU15": 1.5, "OU25": 2.5, "OU35": 3.5}[market]
|
||||
if _is_over(p):
|
||||
return int(total > line)
|
||||
if _is_under(p):
|
||||
return int(total < line)
|
||||
return None
|
||||
|
||||
if market == "BTTS":
|
||||
both_scored = score_home > 0 and score_away > 0
|
||||
if _is_yes(p):
|
||||
return int(both_scored)
|
||||
if "no" in p or "yok" in p:
|
||||
return int(not both_scored)
|
||||
return None
|
||||
|
||||
if market == "HT":
|
||||
if ht_home is None or ht_away is None:
|
||||
return None
|
||||
if p == "1":
|
||||
return int(ht_home > ht_away)
|
||||
if p in {"x", "0"}:
|
||||
return int(ht_home == ht_away)
|
||||
if p == "2":
|
||||
return int(ht_away > ht_home)
|
||||
return None
|
||||
|
||||
if market in {"HT_OU05", "HT_OU15"}:
|
||||
if ht_total is None:
|
||||
return None
|
||||
line = 0.5 if market == "HT_OU05" else 1.5
|
||||
if _is_over(p):
|
||||
return int(ht_total > line)
|
||||
if _is_under(p):
|
||||
return int(ht_total < line)
|
||||
return None
|
||||
|
||||
if market == "OE":
|
||||
if "odd" in p or "tek" in p:
|
||||
return int(total % 2 == 1)
|
||||
if "even" in p or "çift" in p or "cift" in p:
|
||||
return int(total % 2 == 0)
|
||||
return None
|
||||
|
||||
if market == "HTFT":
|
||||
if ht_home is None or ht_away is None or "/" not in p:
|
||||
return None
|
||||
ht_p, ft_p = p.split("/")
|
||||
ht_actual = "1" if ht_home > ht_away else "2" if ht_away > ht_home else "x"
|
||||
ft_actual = "1" if score_home > score_away else "2" if score_away > score_home else "x"
|
||||
return int(ht_p.strip() == ht_actual and ft_p.strip() == ft_actual)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CALIBRATOR KEY (must mirror orchestrator._calibrator_key)
|
||||
# =============================================================================
|
||||
def calibrator_key(market: str, pick: str) -> Optional[str]:
|
||||
m = (market or "").upper()
|
||||
p = _normalize_pick(pick)
|
||||
if m == "MS":
|
||||
if p == "1":
|
||||
return "ms_home"
|
||||
if p in {"x", "0"}:
|
||||
return "ms_draw"
|
||||
if p == "2":
|
||||
return "ms_away"
|
||||
return None
|
||||
if m == "DC":
|
||||
return "dc"
|
||||
if m == "OU15" and _is_over(p):
|
||||
return "ou15"
|
||||
if m == "OU25" and _is_over(p):
|
||||
return "ou25"
|
||||
if m == "OU35" and _is_over(p):
|
||||
return "ou35"
|
||||
if m == "BTTS" and _is_yes(p):
|
||||
return "btts"
|
||||
if m == "HT":
|
||||
if p == "1":
|
||||
return "ht_home"
|
||||
if p in {"x", "0"}:
|
||||
return "ht_draw"
|
||||
if p == "2":
|
||||
return "ht_away"
|
||||
return None
|
||||
if m == "HTFT":
|
||||
return "ht_ft"
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DATA EXTRACTION
|
||||
# =============================================================================
|
||||
def fetch_training_data(
|
||||
cur,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
league_ids: List[str] = None,
|
||||
) -> pd.DataFrame:
|
||||
def fetch_predictions_with_outcomes(cur) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch match data with odds and results for calibration training.
|
||||
|
||||
Returns DataFrame with columns:
|
||||
- match_id
|
||||
- home_team, away_team
|
||||
- ms_h, ms_d, ms_a (odds)
|
||||
- score_home, score_away (actual result)
|
||||
- ht_score_home, ht_score_away
|
||||
- ou25_actual, btts_actual, etc.
|
||||
Source 1: `predictions` table joined with `matches` (FT only).
|
||||
Each row of bet_summary becomes a training sample.
|
||||
"""
|
||||
start_ms = int(datetime.strptime(start_date, "%Y-%m-%d").timestamp() * 1000)
|
||||
end_ms = int(datetime.strptime(end_date, "%Y-%m-%d").timestamp() * 1000) + 86400000 # +1 day
|
||||
|
||||
# Build league filter
|
||||
league_filter = ""
|
||||
params = [start_ms, end_ms]
|
||||
if league_ids:
|
||||
placeholders = ",".join(["%s"] * len(league_ids))
|
||||
league_filter = f"AND m.league_id IN ({placeholders})"
|
||||
params.extend(league_ids)
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
m.id as match_id,
|
||||
m.home_team_id,
|
||||
m.away_team_id,
|
||||
m.score_home,
|
||||
m.score_away,
|
||||
m.ht_score_home,
|
||||
m.ht_score_away,
|
||||
m.mst_utc,
|
||||
-- Odds from odd_categories/selections
|
||||
MAX(CASE WHEN oc.name = 'Maç Sonucu' AND os.name = '1' THEN os.odd_value END) as ms_h,
|
||||
MAX(CASE WHEN oc.name = 'Maç Sonucu' AND os.name = 'X' THEN os.odd_value END) as ms_d,
|
||||
MAX(CASE WHEN oc.name = 'Maç Sonucu' AND os.name = '2' THEN os.odd_value END) as ms_a,
|
||||
MAX(CASE WHEN oc.name = '2,5 Alt/Üst' AND os.name = 'Üst' THEN os.odd_value END) as ou25_over,
|
||||
MAX(CASE WHEN oc.name = '2,5 Alt/Üst' AND os.name = 'Alt' THEN os.odd_value END) as ou25_under,
|
||||
MAX(CASE WHEN oc.name = '1,5 Alt/Üst' AND os.name = 'Üst' THEN os.odd_value END) as ou15_over,
|
||||
MAX(CASE WHEN oc.name = '3,5 Alt/Üst' AND os.name = 'Üst' THEN os.odd_value END) as ou35_over,
|
||||
MAX(CASE WHEN oc.name = 'Karşılıklı Gol' AND os.name = 'Var' THEN os.odd_value END) as btts_yes,
|
||||
MAX(CASE WHEN oc.name = 'Karşılıklı Gol' AND os.name = 'Yok' THEN os.odd_value END) as btts_no
|
||||
FROM matches m
|
||||
LEFT JOIN odd_categories oc ON oc.match_id = m.id
|
||||
LEFT JOIN odd_selections os ON os.odd_category_db_id = oc.db_id
|
||||
WHERE m.mst_utc >= %s
|
||||
AND m.mst_utc < %s
|
||||
AND m.status = 'FT'
|
||||
AND m.score_home IS NOT NULL
|
||||
AND m.score_away IS NOT NULL
|
||||
{league_filter}
|
||||
GROUP BY m.id, m.home_team_id, m.away_team_id, m.score_home, m.score_away,
|
||||
m.ht_score_home, m.ht_score_away, m.mst_utc
|
||||
ORDER BY m.mst_utc DESC
|
||||
"""
|
||||
|
||||
cur.execute(query, params)
|
||||
cur.execute("""
|
||||
SELECT
|
||||
p.match_id,
|
||||
p.prediction_json,
|
||||
m.score_home,
|
||||
m.score_away,
|
||||
m.ht_score_home,
|
||||
m.ht_score_away
|
||||
FROM predictions p
|
||||
JOIN matches m ON m.id = p.match_id
|
||||
WHERE m.sport = 'football'
|
||||
AND m.status = 'FT'
|
||||
AND m.score_home IS NOT NULL
|
||||
AND m.score_away IS NOT NULL
|
||||
""")
|
||||
rows = cur.fetchall()
|
||||
columns = [desc[0] for desc in cur.description]
|
||||
|
||||
df = pd.DataFrame(rows, columns=columns)
|
||||
print(f"[Data] Fetched {len(df)} matches from {start_date} to {end_date}")
|
||||
|
||||
return df
|
||||
samples: List[Dict[str, Any]] = []
|
||||
for match_id, payload, sh, sa, ht_h, ht_a in rows:
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
bet_summary = payload.get("bet_summary")
|
||||
if not isinstance(bet_summary, list):
|
||||
continue
|
||||
for item in bet_summary:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
market = str(item.get("market") or "")
|
||||
pick = str(item.get("pick") or "")
|
||||
raw_conf = item.get("raw_confidence")
|
||||
if raw_conf is None:
|
||||
continue
|
||||
actual = resolve_actual(market, pick, sh, sa, ht_h, ht_a)
|
||||
if actual is None:
|
||||
continue
|
||||
key = calibrator_key(market, pick)
|
||||
if not key:
|
||||
continue
|
||||
samples.append({
|
||||
"source": "predictions",
|
||||
"match_id": match_id,
|
||||
"market": market,
|
||||
"pick": pick,
|
||||
"key": key,
|
||||
"raw_prob": float(raw_conf) / 100.0,
|
||||
"actual": int(actual),
|
||||
})
|
||||
return samples
|
||||
|
||||
|
||||
def calculate_actual_outcomes(df: pd.DataFrame) -> pd.DataFrame:
|
||||
def fetch_prediction_runs_with_outcomes(cur) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Calculate actual binary outcomes for each market.
|
||||
|
||||
Adds columns:
|
||||
- ms_home_actual: 1 if home won, 0 otherwise
|
||||
- ms_draw_actual: 1 if draw, 0 otherwise
|
||||
- ms_away_actual: 1 if away won, 0 otherwise
|
||||
- ou25_over_actual: 1 if total goals > 2.5, 0 otherwise
|
||||
- ou15_over_actual: 1 if total goals > 1.5, 0 otherwise
|
||||
- ou35_over_actual: 1 if total goals > 3.5, 0 otherwise
|
||||
- btts_yes_actual: 1 if both teams scored, 0 otherwise
|
||||
Source 2: `prediction_runs` table with resolved settlement.
|
||||
Each main_pick / value_pick becomes a training sample.
|
||||
"""
|
||||
# Total goals
|
||||
df["total_goals"] = df["score_home"] + df["score_away"]
|
||||
df["ht_total_goals"] = df["ht_score_home"].fillna(0) + df["ht_score_away"].fillna(0)
|
||||
|
||||
# Match result outcomes
|
||||
df["ms_home_actual"] = (df["score_home"] > df["score_away"]).astype(int)
|
||||
df["ms_draw_actual"] = (df["score_home"] == df["score_away"]).astype(int)
|
||||
df["ms_away_actual"] = (df["score_home"] < df["score_away"]).astype(int)
|
||||
|
||||
# Over/Under outcomes
|
||||
df["ou25_over_actual"] = (df["total_goals"] > 2.5).astype(int)
|
||||
df["ou15_over_actual"] = (df["total_goals"] > 1.5).astype(int)
|
||||
df["ou35_over_actual"] = (df["total_goals"] > 3.5).astype(int)
|
||||
|
||||
# BTTS outcome
|
||||
df["btts_yes_actual"] = ((df["score_home"] > 0) & (df["score_away"] > 0)).astype(int)
|
||||
|
||||
# Half-Time result
|
||||
df["ht_home_actual"] = (df["ht_score_home"] > df["ht_score_away"]).astype(int)
|
||||
df["ht_draw_actual"] = (df["ht_score_home"] == df["ht_score_away"]).astype(int)
|
||||
df["ht_away_actual"] = (df["ht_score_home"] < df["ht_score_away"]).astype(int)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def calculate_implied_probabilities(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Calculate implied probabilities from odds.
|
||||
|
||||
Adds columns:
|
||||
- ms_home_prob: implied probability from odds
|
||||
- ms_draw_prob
|
||||
- ms_away_prob
|
||||
- ou25_over_prob
|
||||
- etc.
|
||||
"""
|
||||
def safe_implied_prob(odd_str: str) -> float:
|
||||
"""Convert odds string to implied probability."""
|
||||
if pd.isna(odd_str) or odd_str is None:
|
||||
return np.nan
|
||||
try:
|
||||
odd = float(odd_str)
|
||||
if odd <= 1.0:
|
||||
return np.nan
|
||||
return 1.0 / odd
|
||||
except (ValueError, TypeError):
|
||||
return np.nan
|
||||
|
||||
# Match result implied probabilities
|
||||
df["ms_home_prob"] = df["ms_h"].apply(safe_implied_prob)
|
||||
df["ms_draw_prob"] = df["ms_d"].apply(safe_implied_prob)
|
||||
df["ms_away_prob"] = df["ms_a"].apply(safe_implied_prob)
|
||||
|
||||
# Over/Under implied probabilities
|
||||
df["ou25_over_prob"] = df["ou25_over"].apply(safe_implied_prob)
|
||||
df["ou15_over_prob"] = df["ou15_over"].apply(safe_implied_prob)
|
||||
df["ou35_over_prob"] = df["ou35_over"].apply(safe_implied_prob)
|
||||
|
||||
# BTTS implied probabilities
|
||||
df["btts_yes_prob"] = df["btts_yes"].apply(safe_implied_prob)
|
||||
|
||||
# -----------------------------------------------------
|
||||
# CONTEXT-AWARE BUCKETS
|
||||
# Create separate probability and actual columns for odds buckets
|
||||
# ms_home odds: ms_h (note ms_h is the bookmaker odds for home win)
|
||||
# -----------------------------------------------------
|
||||
# Helper to safe-cast to float
|
||||
df['ms_h_num'] = pd.to_numeric(df['ms_h'], errors='coerce')
|
||||
|
||||
# Bucket 1: Heavy Fav (odds <= 1.40)
|
||||
b1_mask = df['ms_h_num'] <= 1.40
|
||||
df.loc[b1_mask, 'ms_home_heavy_fav_prob'] = df.loc[b1_mask, 'ms_home_prob']
|
||||
df.loc[b1_mask, 'ms_home_heavy_fav_actual'] = df.loc[b1_mask, 'ms_home_actual']
|
||||
|
||||
# Bucket 2: Fav (1.40 < odds <= 1.80)
|
||||
b2_mask = (df['ms_h_num'] > 1.40) & (df['ms_h_num'] <= 1.80)
|
||||
df.loc[b2_mask, 'ms_home_fav_prob'] = df.loc[b2_mask, 'ms_home_prob']
|
||||
df.loc[b2_mask, 'ms_home_fav_actual'] = df.loc[b2_mask, 'ms_home_actual']
|
||||
|
||||
# Bucket 3: Balanced (1.80 < odds <= 2.50)
|
||||
b3_mask = (df['ms_h_num'] > 1.80) & (df['ms_h_num'] <= 2.50)
|
||||
df.loc[b3_mask, 'ms_home_balanced_prob'] = df.loc[b3_mask, 'ms_home_prob']
|
||||
df.loc[b3_mask, 'ms_home_balanced_actual'] = df.loc[b3_mask, 'ms_home_actual']
|
||||
|
||||
# Bucket 4: Underdog (odds > 2.50)
|
||||
b4_mask = df['ms_h_num'] > 2.50
|
||||
df.loc[b4_mask, 'ms_home_underdog_prob'] = df.loc[b4_mask, 'ms_home_prob']
|
||||
df.loc[b4_mask, 'ms_home_underdog_actual'] = df.loc[b4_mask, 'ms_home_actual']
|
||||
|
||||
return df
|
||||
cur.execute("""
|
||||
SELECT
|
||||
pr.match_id,
|
||||
pr.payload_summary,
|
||||
m.score_home,
|
||||
m.score_away,
|
||||
m.ht_score_home,
|
||||
m.ht_score_away
|
||||
FROM prediction_runs pr
|
||||
JOIN matches m ON m.id = pr.match_id
|
||||
WHERE pr.eventual_outcome IS NOT NULL
|
||||
AND m.score_home IS NOT NULL
|
||||
AND m.score_away IS NOT NULL
|
||||
""")
|
||||
rows = cur.fetchall()
|
||||
samples: List[Dict[str, Any]] = []
|
||||
for match_id, payload, sh, sa, ht_h, ht_a in rows:
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
for source_key in ("main_pick", "value_pick"):
|
||||
item = payload.get(source_key)
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
market = str(item.get("market") or "")
|
||||
pick = str(item.get("pick") or "")
|
||||
# Prefer raw_confidence, fall back to calibrated_probability×100 if raw missing
|
||||
raw_conf = item.get("raw_confidence")
|
||||
if raw_conf is None:
|
||||
cal_prob = item.get("calibrated_probability") or item.get("probability")
|
||||
if cal_prob is None:
|
||||
continue
|
||||
raw_conf = float(cal_prob) * 100.0
|
||||
actual = resolve_actual(market, pick, sh, sa, ht_h, ht_a)
|
||||
if actual is None:
|
||||
continue
|
||||
key = calibrator_key(market, pick)
|
||||
if not key:
|
||||
continue
|
||||
samples.append({
|
||||
"source": f"runs.{source_key}",
|
||||
"match_id": match_id,
|
||||
"market": market,
|
||||
"pick": pick,
|
||||
"key": key,
|
||||
"raw_prob": float(raw_conf) / 100.0,
|
||||
"actual": int(actual),
|
||||
})
|
||||
return samples
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MODEL PREDICTIONS (Optional - if you want to calibrate model outputs)
|
||||
# TRAINING
|
||||
# =============================================================================
|
||||
def get_model_predictions(
|
||||
def train_per_key(
|
||||
df: pd.DataFrame,
|
||||
cur,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Get model predictions for each match.
|
||||
|
||||
This is optional - if you want to calibrate model outputs rather than
|
||||
raw odds-implied probabilities.
|
||||
|
||||
TODO: Implement if needed. For now, we use odds-implied probabilities
|
||||
as a proxy for model predictions.
|
||||
"""
|
||||
# For now, return odds-implied probabilities as "model predictions"
|
||||
# In a full implementation, you would:
|
||||
# 1. Load the V20 predictor
|
||||
# 2. Run predictions for each match
|
||||
# 3. Store raw model probabilities
|
||||
|
||||
return df
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MAIN TRAINING
|
||||
# =============================================================================
|
||||
def train_calibration_models(
|
||||
df: pd.DataFrame,
|
||||
markets: List[str] = None,
|
||||
min_samples: int = 100,
|
||||
min_samples: int,
|
||||
markets_filter: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Train calibration models for specified markets.
|
||||
|
||||
Args:
|
||||
df: DataFrame with probabilities and actual outcomes
|
||||
markets: List of markets to train (default: all supported)
|
||||
min_samples: Minimum samples required per market
|
||||
|
||||
Returns:
|
||||
Dict with training results
|
||||
"""
|
||||
if markets is None:
|
||||
markets = SUPPORTED_MARKETS
|
||||
|
||||
calibrator = get_calibrator()
|
||||
|
||||
# Define market config: market -> (prob_col, actual_col)
|
||||
market_config = {
|
||||
"ms_home": ("ms_home_prob", "ms_home_actual"),
|
||||
"ms_home_heavy_fav": ("ms_home_heavy_fav_prob", "ms_home_heavy_fav_actual"),
|
||||
"ms_home_fav": ("ms_home_fav_prob", "ms_home_fav_actual"),
|
||||
"ms_home_balanced": ("ms_home_balanced_prob", "ms_home_balanced_actual"),
|
||||
"ms_home_underdog": ("ms_home_underdog_prob", "ms_home_underdog_actual"),
|
||||
"ms_draw": ("ms_draw_prob", "ms_draw_actual"),
|
||||
"ms_away": ("ms_away_prob", "ms_away_actual"),
|
||||
"ou15": ("ou15_over_prob", "ou15_over_actual"),
|
||||
"ou25": ("ou25_over_prob", "ou25_over_actual"),
|
||||
"ou35": ("ou35_over_prob", "ou35_over_actual"),
|
||||
"btts": ("btts_yes_prob", "btts_yes_actual"),
|
||||
"ht_home": ("ht_home_prob", "ht_home_actual"), # Note: need to add ht probs
|
||||
"ht_draw": ("ht_draw_prob", "ht_draw_actual"),
|
||||
"ht_away": ("ht_away_prob", "ht_away_actual"),
|
||||
}
|
||||
|
||||
# Filter to requested markets
|
||||
market_config = {k: v for k, v in market_config.items() if k in markets}
|
||||
|
||||
# Train all markets
|
||||
results = calibrator.train_all_markets(
|
||||
df=df,
|
||||
market_config=market_config,
|
||||
min_samples=min_samples,
|
||||
)
|
||||
|
||||
results: Dict[str, Any] = {}
|
||||
keys = sorted(df["key"].unique())
|
||||
|
||||
for key in keys:
|
||||
if markets_filter and key not in markets_filter:
|
||||
continue
|
||||
sub = df[df["key"] == key]
|
||||
# Drop duplicates by (match_id, key) to avoid double-counting across sources
|
||||
sub = sub.drop_duplicates(subset=["match_id", "key"], keep="first")
|
||||
sub = sub.dropna(subset=["raw_prob", "actual"])
|
||||
# Clamp probabilities to (0, 1) for isotonic stability
|
||||
sub = sub[(sub["raw_prob"] > 0.0) & (sub["raw_prob"] < 1.0)]
|
||||
|
||||
n = len(sub)
|
||||
if n < min_samples:
|
||||
results[key] = {
|
||||
"status": "skipped",
|
||||
"samples": n,
|
||||
"reason": f"need ≥{min_samples}, have {n}",
|
||||
}
|
||||
continue
|
||||
|
||||
metrics = calibrator.train_calibration(
|
||||
df=sub,
|
||||
market=key,
|
||||
prob_col="raw_prob",
|
||||
actual_col="actual",
|
||||
min_samples=min_samples,
|
||||
save=True,
|
||||
)
|
||||
results[key] = {
|
||||
"status": "trained",
|
||||
"samples": metrics.sample_count,
|
||||
"brier": round(metrics.brier_score, 4),
|
||||
"ece": round(metrics.calibration_error, 4),
|
||||
"mean_predicted": round(metrics.mean_predicted, 4),
|
||||
"mean_actual": round(metrics.mean_actual, 4),
|
||||
}
|
||||
return results
|
||||
|
||||
|
||||
def print_calibration_report(results: Dict[str, Any]):
|
||||
"""Print a formatted calibration report."""
|
||||
print("\n" + "=" * 70)
|
||||
def print_report(results: Dict[str, Any], total_samples: int) -> None:
|
||||
print("\n" + "=" * 78)
|
||||
print("CALIBRATION TRAINING REPORT")
|
||||
print("=" * 70)
|
||||
|
||||
print(f"\n{'Market':<15} {'Brier':<10} {'ECE':<10} {'Samples':<10} {'Status'}")
|
||||
print("-" * 60)
|
||||
|
||||
for market, metrics in results.items():
|
||||
status = "✓ Trained" if metrics.sample_count >= 100 else "⚠ Insufficient"
|
||||
print(f"{market:<15} {metrics.brier_score:<10.4f} {metrics.calibration_error:<10.4f} "
|
||||
f"{metrics.sample_count:<10} {status}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Interpretation:")
|
||||
print(" - Brier Score: Lower is better (0 = perfect, 0.25 = random)")
|
||||
print(" - ECE (Expected Calibration Error): Lower is better (0 = perfect)")
|
||||
print(" - Models saved to: ai-engine/models/calibration/")
|
||||
print("=" * 70)
|
||||
print("=" * 78)
|
||||
print(f"Total samples across all markets: {total_samples}")
|
||||
print(f"\n{'market':<14} {'status':<10} {'n':<6} {'brier':<9} {'ece':<8} {'pred_avg':<9} {'actual_avg':<10}")
|
||||
print("-" * 78)
|
||||
for key, info in sorted(results.items()):
|
||||
if info["status"] == "trained":
|
||||
print(
|
||||
f"{key:<14} {'✓ ok':<10} {info['samples']:<6} "
|
||||
f"{info['brier']:<9.4f} {info['ece']:<8.4f} "
|
||||
f"{info['mean_predicted']:<8.3f} {info['mean_actual']:<8.3f}"
|
||||
)
|
||||
else:
|
||||
print(f"{key:<14} {'⊘ skip':<10} {info['samples']:<6} -- {info.get('reason', '')}")
|
||||
print("=" * 78)
|
||||
print("Trained models saved to: ai-engine/models/calibration/")
|
||||
print("Skipped markets fall back to the multiplier in market_thresholds.json.")
|
||||
print("=" * 78)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI
|
||||
# =============================================================================
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Train calibration models")
|
||||
parser.add_argument("--start", type=str, default=DEFAULT_START_DATE,
|
||||
help="Start date (YYYY-MM-DD)")
|
||||
parser.add_argument("--end", type=str, default=DEFAULT_END_DATE,
|
||||
help="End date (YYYY-MM-DD)")
|
||||
parser = argparse.ArgumentParser(description="Train isotonic calibration on real data")
|
||||
parser.add_argument("--min-samples", type=int, default=30,
|
||||
help="Minimum samples required per market (default: 30)")
|
||||
parser.add_argument("--markets", nargs="+", default=None,
|
||||
help="Markets to train (default: all)")
|
||||
parser.add_argument("--min-samples", type=int, default=100,
|
||||
help="Minimum samples per market")
|
||||
parser.add_argument("--top-leagues-only", action="store_true",
|
||||
help="Only use top leagues data")
|
||||
|
||||
help="Limit to specific calibrator keys (e.g., ms_home ou25)")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"\n[Calibration Training] {args.start} to {args.end}")
|
||||
|
||||
# Load top leagues if requested
|
||||
league_ids = None
|
||||
if args.top_leagues_only:
|
||||
league_ids = load_top_league_ids()
|
||||
print(f"[Data] Filtering to {len(league_ids)} top leagues")
|
||||
|
||||
# Fetch data
|
||||
|
||||
conn = get_conn()
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
df = fetch_training_data(cur, args.start, args.end, league_ids)
|
||||
|
||||
if len(df) == 0:
|
||||
print("[Error] No data found for the specified date range")
|
||||
s1 = fetch_predictions_with_outcomes(cur)
|
||||
s2 = fetch_prediction_runs_with_outcomes(cur)
|
||||
print(f"[Data] predictions table: {len(s1)} samples")
|
||||
print(f"[Data] prediction_runs: {len(s2)} samples")
|
||||
all_samples = s1 + s2
|
||||
if not all_samples:
|
||||
print("[Error] No training samples available")
|
||||
return
|
||||
|
||||
# Calculate outcomes and probabilities
|
||||
df = calculate_actual_outcomes(df)
|
||||
df = calculate_implied_probabilities(df)
|
||||
|
||||
# Train models
|
||||
results = train_calibration_models(
|
||||
df=df,
|
||||
markets=args.markets,
|
||||
min_samples=args.min_samples,
|
||||
)
|
||||
|
||||
# Print report
|
||||
print_calibration_report(results)
|
||||
|
||||
df = pd.DataFrame(all_samples)
|
||||
print(f"[Data] Combined: {len(df)} samples")
|
||||
print(f"[Data] Unique matches: {df['match_id'].nunique()}")
|
||||
print(f"[Data] Per-key counts:")
|
||||
for key, count in df["key"].value_counts().items():
|
||||
print(f" {key:<14} {count}")
|
||||
|
||||
results = train_per_key(df, args.min_samples, args.markets)
|
||||
print_report(results, total_samples=len(df))
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
Reference in New Issue
Block a user