This commit is contained in:
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
XGBoost Market Model Trainer (Basketball)
|
||||
=========================================
|
||||
Trains specialized XGBoost models for basketball betting markets.
|
||||
Models:
|
||||
1. ML (Match Result) - Binary (Home Win / Away Win)
|
||||
2. Totals (Over/Under) - Binary (Over / Under dynamic line)
|
||||
3. Spread (Handicap) - Binary (Home Cover / Away Cover)
|
||||
|
||||
Usage:
|
||||
python3 scripts/train_basketball_markets.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
|
||||
|
||||
# Config
|
||||
AI_ENGINE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
DATA_PATH = os.path.join(AI_ENGINE_DIR, "data", "basketball_training_data.csv")
|
||||
MODELS_DIR = os.path.join(AI_ENGINE_DIR, "models", "xgboost", "basketball")
|
||||
|
||||
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||
|
||||
# Feature Columns
|
||||
FEATURES = [
|
||||
# Form
|
||||
"home_points_avg", "home_conceded_avg",
|
||||
"away_points_avg", "away_conceded_avg",
|
||||
"home_winning_streak", "away_winning_streak",
|
||||
"home_win_rate", "away_win_rate",
|
||||
|
||||
# H2H
|
||||
"h2h_total_matches", "h2h_home_win_rate",
|
||||
"h2h_avg_points", "h2h_over140_rate",
|
||||
|
||||
# Odds
|
||||
"odds_ml_h", "odds_ml_a",
|
||||
"odds_tot_o", "odds_tot_u", "odds_tot_line",
|
||||
"odds_spread_h", "odds_spread_a", "odds_spread_line"
|
||||
]
|
||||
|
||||
def load_data():
|
||||
if not os.path.exists(DATA_PATH):
|
||||
print(f"❌ Data file not found: {DATA_PATH}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"📦 Loading data from {DATA_PATH}...")
|
||||
df = pd.read_csv(DATA_PATH)
|
||||
df.fillna(0, inplace=True)
|
||||
print(f" Shape: {df.shape}")
|
||||
return df
|
||||
|
||||
def train_binary_model(df, target_col, model_name):
|
||||
"""Generic trainer for Binary XGBoost models (ML, Totals, Spread)."""
|
||||
print(f"\n🚀 Training {model_name} (Target: {target_col})...")
|
||||
|
||||
valid_df = df[df[target_col].notna()].copy()
|
||||
if valid_df.empty:
|
||||
print(f" ⚠️ No valid data for {target_col}, skipping.")
|
||||
return
|
||||
|
||||
X = valid_df[FEATURES]
|
||||
y = valid_df[target_col].astype(int)
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
params = {
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss',
|
||||
'eta': 0.05,
|
||||
'max_depth': 6,
|
||||
'subsample': 0.8,
|
||||
'colsample_bytree': 0.8,
|
||||
'nthread': 4,
|
||||
'seed': 42
|
||||
}
|
||||
|
||||
model = xgb.XGBClassifier(**params, n_estimators=1000, early_stopping_rounds=50)
|
||||
|
||||
model.fit(
|
||||
X_train, y_train,
|
||||
eval_set=[(X_test, y_test)],
|
||||
verbose=False
|
||||
)
|
||||
|
||||
y_pred = model.predict(X_test)
|
||||
y_prob = model.predict_proba(X_test)[:, 1]
|
||||
|
||||
acc = accuracy_score(y_test, y_pred)
|
||||
try:
|
||||
auc = roc_auc_score(y_test, y_prob)
|
||||
except:
|
||||
auc = 0.0
|
||||
|
||||
print(f" ✅ Finished! Best Iteration: {model.best_iteration}")
|
||||
print(f" 📊 Accuracy: {acc:.4f} | ROC AUC: {auc:.4f}")
|
||||
print(classification_report(y_test, y_pred, zero_division=0))
|
||||
|
||||
# Save Model
|
||||
model_path = os.path.join(MODELS_DIR, f"{model_name}.pkl")
|
||||
with open(model_path, "wb") as f:
|
||||
pickle.dump(model, f)
|
||||
print(f" 💾 Saved to {model_path}")
|
||||
|
||||
# Save Top Features
|
||||
try:
|
||||
booster = model.get_booster()
|
||||
importance = booster.get_score(importance_type="gain")
|
||||
sorted_imp = sorted(importance.items(), key=lambda x: x[1], reverse=True)[:5]
|
||||
print(" 🔍 Top 5 Features (Gain):")
|
||||
for ft, score in sorted_imp:
|
||||
print(f" - {ft}: {score:.2f}")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Could not extract feature importance: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
df = load_data()
|
||||
|
||||
# 1. Moneyline (ML) Model -> Targets Home Win (0) vs Away Win (1)
|
||||
train_binary_model(df, "label_ml", "basketball_ml_v1")
|
||||
|
||||
# 2. Totals (Over/Under) Model -> Targets Under (0) vs Over (1) against 'odds_tot_line'
|
||||
train_binary_model(df, "label_tot", "basketball_tot_v1")
|
||||
|
||||
# 3. Spread (Handicap) Model -> Targets Away Cover (0) vs Home Cover (1) against 'odds_spread_line'
|
||||
train_binary_model(df, "label_spread", "basketball_spread_v1")
|
||||
|
||||
print("\n🎉 All Basketball Models Trained Successfully!")
|
||||
Reference in New Issue
Block a user