This commit is contained in:
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
XGBoost Model Training (Advanced Basketball V21)
|
||||
================================================
|
||||
Trains XGBoost models for Match Winner (ML), Totals (O/U), and Spread.
|
||||
Builds upon 60+ deep tactical features (Rebounds, FG%, Q1/Q2 pacing, advanced odds).
|
||||
|
||||
Usage:
|
||||
python3 scripts/train_advanced_basketball.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||
from datetime import datetime
|
||||
|
||||
# Configuration
|
||||
AI_ENGINE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, AI_ENGINE_DIR)
|
||||
|
||||
DATA_PATH = os.path.join(AI_ENGINE_DIR, "data", "advanced_basketball_training_data.csv")
|
||||
MODEL_DIR = os.path.join(AI_ENGINE_DIR, "models", "bin")
|
||||
|
||||
os.makedirs(MODEL_DIR, exist_ok=True)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Deep Statistical Feature Matrix (54 Features)
|
||||
# -----------------------------------------------------------------------------
|
||||
FEATURES = [
|
||||
# Form
|
||||
"home_winning_streak", "away_winning_streak",
|
||||
"home_win_rate", "away_win_rate",
|
||||
|
||||
# Home Team Offense
|
||||
"home_pts_avg", "home_reb_avg", "home_ast_avg", "home_stl_avg", "home_blk_avg", "home_tov_avg",
|
||||
"home_fg_pct", "home_3pt_pct", "home_ft_pct",
|
||||
"home_q1_avg", "home_q2_avg", "home_q3_avg", "home_q4_avg",
|
||||
|
||||
# Home Team Defense
|
||||
"home_conc_pts", "home_conc_reb", "home_conc_ast", "home_conc_tov",
|
||||
"home_conc_fg_pct", "home_conc_3pt_pct",
|
||||
|
||||
# Away Team Offense
|
||||
"away_pts_avg", "away_reb_avg", "away_ast_avg", "away_stl_avg", "away_blk_avg", "away_tov_avg",
|
||||
"away_fg_pct", "away_3pt_pct", "away_ft_pct",
|
||||
"away_q1_avg", "away_q2_avg", "away_q3_avg", "away_q4_avg",
|
||||
|
||||
# Away Team Defense
|
||||
"away_conc_pts", "away_conc_reb", "away_conc_ast", "away_conc_tov",
|
||||
"away_conc_fg_pct", "away_conc_3pt_pct",
|
||||
|
||||
# H2H Features
|
||||
"h2h_total_matches", "h2h_home_win_rate",
|
||||
"h2h_avg_points", "h2h_over140_rate",
|
||||
|
||||
# Odds Features
|
||||
"odds_ml_h", "odds_ml_a",
|
||||
"odds_tot_o", "odds_tot_u", "odds_tot_line",
|
||||
"odds_spread_h", "odds_spread_a", "odds_spread_line",
|
||||
]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Core Training Function
|
||||
# -----------------------------------------------------------------------------
|
||||
def train_model(df, target_col, model_name, params=None):
|
||||
print(f"\n--- Training {model_name} ---")
|
||||
|
||||
# For Totals and Spread we need to drop purely empty lines if odds aren't matched
|
||||
if target_col in ["label_tot", "label_spread"]:
|
||||
# If line implies 0 and wasn't populated heavily, we may want to skip
|
||||
if target_col == "label_tot":
|
||||
df_filtered = df[(df["odds_tot_line"] > 50) & (df["odds_tot_line"] < 300)].copy()
|
||||
elif target_col == "label_spread":
|
||||
df_filtered = df[(abs(df["odds_spread_line"]) > 0.0) | (df["odds_spread_h"] != 1.9)].copy()
|
||||
else:
|
||||
df_filtered = df.copy()
|
||||
|
||||
X = df_filtered[FEATURES]
|
||||
y = df_filtered[target_col]
|
||||
|
||||
print(f"Data Shape: {X.shape}")
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=42)
|
||||
|
||||
# Defaults for XGBoost
|
||||
if params is None:
|
||||
params = {
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss',
|
||||
'max_depth': 6,
|
||||
'learning_rate': 0.05,
|
||||
'n_estimators': 300,
|
||||
'subsample': 0.8,
|
||||
'colsample_bytree': 0.8,
|
||||
'random_state': 42
|
||||
}
|
||||
|
||||
clf = xgb.XGBClassifier(**params)
|
||||
clf.fit(
|
||||
X_train, y_train,
|
||||
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||
verbose=50
|
||||
)
|
||||
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
acc = accuracy_score(y_test, y_pred)
|
||||
prec = precision_score(y_test, y_pred, zero_division=0)
|
||||
rec = recall_score(y_test, y_pred, zero_division=0)
|
||||
|
||||
print(f"\n[{model_name}] Metrics:")
|
||||
print(f"Accuracy : {acc:.4f}")
|
||||
if len(np.unique(y_train)) == 2:
|
||||
print(f"Precision: {prec:.4f}")
|
||||
print(f"Recall : {rec:.4f}")
|
||||
|
||||
# Display Top 10 Feature Importances
|
||||
importances = clf.feature_importances_
|
||||
sorted_idx = np.argsort(importances)[::-1]
|
||||
print("\nTop 10 Feature Importances:")
|
||||
for i in range(10):
|
||||
print(f" {i+1}. {FEATURES[sorted_idx[i]]}: {importances[sorted_idx[i]]:.4f}")
|
||||
|
||||
# Save
|
||||
save_path = os.path.join(MODEL_DIR, f"{model_name}.json")
|
||||
clf.save_model(save_path)
|
||||
print(f"Saved to: {save_path}")
|
||||
return clf
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.path.exists(DATA_PATH):
|
||||
print(f"ERROR: Training data not found at {DATA_PATH}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loading data from {DATA_PATH}")
|
||||
df = pd.read_csv(DATA_PATH)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 1. Match Winner (Moneyline)
|
||||
# ---------------------------------------------------------
|
||||
ml_params = {
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss',
|
||||
'max_depth': 5,
|
||||
'learning_rate': 0.03,
|
||||
'n_estimators': 250,
|
||||
'subsample': 0.85,
|
||||
'colsample_bytree': 0.8,
|
||||
'random_state': 42
|
||||
}
|
||||
train_model(df, "label_ml", "basketball_v21_ml", ml_params)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 2. Match Totals (Over / Under)
|
||||
# ---------------------------------------------------------
|
||||
# Finding O/U against dynamic line needs complex relationships
|
||||
tot_params = {
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss',
|
||||
'max_depth': 6,
|
||||
'learning_rate': 0.05,
|
||||
'n_estimators': 350,
|
||||
'subsample': 0.8,
|
||||
'colsample_bytree': 0.8,
|
||||
'random_state': 42
|
||||
}
|
||||
train_model(df, "label_tot", "basketball_v21_tot", tot_params)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 3. Spread (Handicap Cover)
|
||||
# ---------------------------------------------------------
|
||||
spread_params = {
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss',
|
||||
'max_depth': 6,
|
||||
'learning_rate': 0.04,
|
||||
'n_estimators': 300,
|
||||
'subsample': 0.8,
|
||||
'colsample_bytree': 0.8,
|
||||
'random_state': 42
|
||||
}
|
||||
train_model(df, "label_spread", "basketball_v21_spread", spread_params)
|
||||
|
||||
print("\n🏁 Advanced V21 Basketball Models trained successfully.")
|
||||
Reference in New Issue
Block a user