import os

import numpy as np
import pandas as pd
from fastapi import HTTPException
from groq import Groq

from models.train_model import (
    load_and_preprocess_data,
    predict_player_score,
    predict_team_performance,
    train_player_score_model,
    train_team_performance_model,
)
|
|
|
|
|
# Module-level caches for the trained models and the data they rely on.
# Every slot starts as None and is assigned by initialize_models().
TEAM_WIN_MODEL = None
TEAM_SCORE_MODEL = None
TEAM_DATA = None
TEAM_SCALER = None
PLAYER_SCORE_MODEL = None
PLAYER_SCALER = None
PLAYER_DATA = None
MATCH_DF = None
BALL_DF = None
|
|
|
|
|
# SECURITY: the Groq API key is read from the environment so that no
# credential is stored in the source tree. A literal key used to be committed
# on this line; it must be treated as compromised and rotated.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Shared Groq client used by generate_summary().
# NOTE(review): when GROQ_API_KEY is unset this passes api_key=None; the SDK is
# expected to reject that at construction time -- confirm that failing fast at
# import is acceptable for this service.
client = Groq(api_key=GROQ_API_KEY)
|
|
|
|
|
def initialize_models():
    """Load the datasets, train the models and cache everything in module globals.

    Populates MATCH_DF and BALL_DF from load_and_preprocess_data(), then the
    team models (TEAM_WIN_MODEL, TEAM_SCORE_MODEL, TEAM_DATA, TEAM_SCALER) and
    the player model (PLAYER_SCORE_MODEL, PLAYER_SCALER, PLAYER_DATA).

    Returns:
        None. The results are stored in the module-level names listed above.
    """
    global TEAM_WIN_MODEL, TEAM_SCORE_MODEL, TEAM_DATA, TEAM_SCALER
    global PLAYER_SCORE_MODEL, PLAYER_SCALER, PLAYER_DATA, MATCH_DF, BALL_DF

    MATCH_DF, BALL_DF = load_and_preprocess_data()
    (
        TEAM_WIN_MODEL,
        TEAM_SCORE_MODEL,
        TEAM_DATA,
        TEAM_SCALER,
    ) = train_team_performance_model(MATCH_DF)
    PLAYER_SCORE_MODEL, PLAYER_SCALER, PLAYER_DATA = train_player_score_model(MATCH_DF, BALL_DF)
    print("Models trained and loaded into memory.")
|
|
|
|
|
# Train once at import time so the functions below can read the cached
# DataFrames and models.
initialize_models()


def _most_common_team(teams):
    """Return the most frequent value in ``teams``, or None if there is no mode."""
    # mode() is computed once per group; the first entry is used when several
    # values tie.
    modes = teams.mode()
    return modes.iloc[0] if len(modes) > 0 else None


# Maps each striker name to the batting team recorded most often for them.
player_team_mapping = BALL_DF.groupby('striker')['batting_team'].agg(_most_common_team).to_dict()
|
|
|
|
|
def clean_json(data):
    """Recursively convert ``data`` into plain, JSON-friendly Python values.

    Conversion rules:
        * dicts are rebuilt with cleaned values (keys are left untouched);
        * lists, tuples and NumPy arrays become lists of cleaned items;
        * NumPy bool/integer/floating scalars become ``bool``/``int``/``float``;
        * NaN and infinite floats become ``0.0``;
        * other missing scalars (``None``, ``pd.NaT``, ``pd.NA``) become ``None``;
        * ``pd.Timestamp`` becomes a ``'YYYY-MM-DD'`` string;
        * ``int`` and ``bool`` are returned unchanged;
        * anything else is converted with ``str()``.

    Args:
        data: Any value or nested container of values.

    Returns:
        The cleaned value.
    """
    if isinstance(data, dict):
        return {key: clean_json(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return [clean_json(item) for item in data]
    if isinstance(data, np.ndarray):
        # tolist() yields a (nested) list, or a bare scalar for 0-d arrays;
        # the recursive call copes with both.
        return clean_json(data.tolist())

    # NumPy scalars are not instances of the builtin int/bool, so without
    # these branches they would fall through to str() below.
    if isinstance(data, np.bool_):
        return bool(data)
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, np.floating):
        data = float(data)

    if isinstance(data, float):
        return 0.0 if pd.isna(data) or np.isinf(data) else data

    # pd.isna() returns an array for array-like objects; only a scalar True
    # means "this value is missing".
    missing = pd.isna(data)
    if isinstance(missing, (bool, np.bool_)) and missing:
        return None

    if isinstance(data, pd.Timestamp):
        return data.strftime('%Y-%m-%d')
    if isinstance(data, (int, bool)):
        return data
    return str(data)
|
|
|
|
|
def generate_summary(data, context_type, model="mixtral-8x7b-32768"):
    """Ask the Groq chat API for a one-sentence summary of ``data``.

    This is best-effort: any failure while building the prompt or calling the
    API is reported in the returned string instead of being raised.

    Args:
        data: The payload to summarise. For ``"match_history"`` it must be a
            mapping with ``team1``, ``team2`` and ``matches`` keys.
        context_type: One of ``"player_stats"``, ``"team_stats"``,
            ``"match_history"``, ``"prediction_score"`` or ``"prediction_team"``.
        model: Name of the chat model to request.
            NOTE(review): hosted model names can be retired by the provider;
            confirm the default is still served.

    Returns:
        The stripped summary text, or a string starting with
        ``"Summary unavailable due to error:"`` on failure.
    """
    prompt_prefixes = {
        "player_stats": "Summarize this player data in one sentence: ",
        "team_stats": "Summarize this team data in one sentence: ",
        "prediction_score": "Summarize this prediction in one sentence: ",
        "prediction_team": "Summarize this team prediction in one sentence: ",
    }

    # Do not spend an API call on an empty prompt for an unknown context.
    if context_type != "match_history" and context_type not in prompt_prefixes:
        return f"Summary unavailable due to error: unsupported context type '{context_type}'"

    try:
        if context_type == "match_history":
            prompt = (
                f"Summarize this match history between {data['team1']} and "
                f"{data['team2']} in one sentence: {data['matches']}"
            )
        else:
            prompt = f"{prompt_prefixes[context_type]}{data}"

        chat_completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a concise cricket analyst."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=50,
            temperature=0.7,
        )
        return chat_completion.choices[0].message.content.strip()
    except Exception as e:
        # Deliberately broad: a failed summary must not fail the whole request.
        return f"Summary unavailable due to error: {str(e)}"
|
|
|
|
|
def get_player_stats(player_name: str, season: str = None, role: str = "Batting"):
    """Return aggregate batting or bowling statistics for one player.

    The name is stripped and title-cased, then matched against the ``striker``
    and ``bowler`` columns of BALL_DF using three spellings: as given, with
    spaces removed, and with the words in reverse order.

    Args:
        player_name: Player name to look up.
        season: Optional season filter; applied only when BALL_DF has a
            ``season`` column.
        role: ``"Batting"`` or ``"Bowling"``.

    Returns:
        A dict of statistics for the requested role plus a generated
        ``summary``, passed through clean_json().

    Raises:
        HTTPException: 400 for an unsupported ``role``; 404 when no deliveries
            match the player (and season).
    """
    # Previously an unknown role fell through both branches and returned None.
    if role not in ("Batting", "Bowling"):
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported role '{role}'. Expected 'Batting' or 'Bowling'.",
        )

    # NOTE(review): title() rewrites all-caps initials (e.g. "MS" -> "Ms");
    # confirm this matches how names are stored in BALL_DF.
    player_name = player_name.strip().title()
    name_variations = [
        player_name,
        player_name.replace(" ", ""),
        " ".join(reversed(player_name.split())),
    ]
    player_data = BALL_DF[
        BALL_DF['striker'].isin(name_variations) | BALL_DF['bowler'].isin(name_variations)
    ]
    if season and 'season' in BALL_DF.columns:
        player_data = player_data[player_data['season'] == season]
    if player_data.empty:
        raise HTTPException(
            status_code=404,
            detail=f"Player '{player_name}' not found. Variations tried: {name_variations}",
        )

    season_label = season if season else "All Seasons"

    if role == "Batting":
        batting_data = player_data[player_data['striker'].isin(name_variations)]
        total_runs = int(batting_data['runs_off_bat'].sum())
        # Every row on which the player is the striker counts as a ball faced.
        balls_faced = int(batting_data.shape[0])
        strike_rate = float(total_runs / balls_faced * 100) if balls_faced > 0 else 0.0

        stats = {
            "player_name": player_name,
            "role": role,
            "total_runs": total_runs,
            "balls_faced": balls_faced,
            "strike_rate": strike_rate,
            "matches_played": int(len(batting_data['match_id'].unique())),
            "season": season_label,
        }
    else:
        bowling_data = player_data[player_data['bowler'].isin(name_variations)]
        # Only these dismissal types are counted as wickets for the bowler.
        bowler_wicket_types = ["caught", "bowled", "lbw", "caught and bowled", "hit wicket"]
        wickets_data = bowling_data[
            bowling_data['player_dismissed'].notna()
            & bowling_data['wicket_type'].isin(bowler_wicket_types)
        ]
        total_wickets = int(wickets_data.shape[0])
        total_runs_conceded = int(bowling_data['total_runs'].sum())
        total_balls_bowled = int(bowling_data.shape[0])
        total_overs_bowled = total_balls_bowled / 6

        # Average and strike rate are undefined without a wicket; report 0.0.
        if total_wickets > 0:
            bowling_average = round(total_runs_conceded / total_wickets, 2)
            bowling_strike_rate = round(total_balls_bowled / total_wickets, 2)
        else:
            bowling_average = 0.0
            bowling_strike_rate = 0.0
        economy_rate = (
            round(total_runs_conceded / total_overs_bowled, 2) if total_overs_bowled > 0 else 0.0
        )

        stats = {
            "player_name": player_name,
            "role": role,
            "total_wickets": total_wickets,
            "bowling_average": bowling_average,
            "economy_rate": economy_rate,
            "bowling_strike_rate": bowling_strike_rate,
            "overs_bowled": round(total_overs_bowled, 1),
            "bowling_matches": int(len(bowling_data['match_id'].unique())),
            "season": season_label,
        }

    stats["summary"] = generate_summary(stats, "player_stats")
    return clean_json(stats)
|
|
|
|
|
def get_team_stats(team_name: str, season: str = None):
    """Return the win/loss record for one team.

    Args:
        team_name: Team name; stripped and title-cased before matching the
            ``team1`` / ``team2`` columns of MATCH_DF.
        season: Optional season filter; applied only when MATCH_DF has a
            ``season`` column.

    Returns:
        A dict with ``total_matches``, ``wins``, ``losses``, ``no_result``,
        ``win_percentage``, ``season`` and a generated ``summary``, passed
        through clean_json().

    Raises:
        HTTPException: 404 when the team has no matches (in that season).
    """
    team_name = team_name.strip().title()
    team_matches = MATCH_DF[(MATCH_DF['team1'] == team_name) | (MATCH_DF['team2'] == team_name)]
    if season and 'season' in MATCH_DF.columns:
        team_matches = team_matches[team_matches['season'] == season]
    if team_matches.empty:
        raise HTTPException(status_code=404, detail="Team not found")

    total_matches = int(team_matches.shape[0])
    wins = int((team_matches['winner'] == team_name).sum())
    # Matches with no recorded winner are not losses; get_match_history()
    # reports the same rows as "Draw".
    no_result = int(team_matches['winner'].isna().sum())
    losses = total_matches - wins - no_result

    stats = {
        "total_matches": total_matches,
        "wins": wins,
        "losses": losses,
        "no_result": no_result,
        # total_matches is at least 1 here because of the empty check above.
        "win_percentage": float(wins / total_matches * 100),
        "season": season if season else "All Seasons",
    }
    stats["summary"] = generate_summary(stats, "team_stats")
    return clean_json(stats)
|
|
|
|
|
def get_match_history(team1: str, team2: str, season: str = None):
    """Return every recorded match between two teams.

    Args:
        team1: First team name; stripped and title-cased before matching.
        team2: Second team name; stripped and title-cased before matching.
        season: Optional season filter; applied only when MATCH_DF has a
            ``season`` column.

    Returns:
        A dict with ``team1``, ``team2``, ``season``, ``matches`` (a list of
        ``date`` / ``team1`` / ``team2`` / ``winner`` records) and a generated
        ``summary``, passed through clean_json().

    Raises:
        HTTPException: 404 when either team is unknown or the two teams have
            no matches (in that season).
    """
    team1 = team1.strip().title()
    team2 = team2.strip().title()

    available_teams = set(MATCH_DF['team1'].unique().tolist()) | set(MATCH_DF['team2'].unique().tolist())
    # team1 is checked first, so it is the one reported when both are unknown.
    for team in (team1, team2):
        if team not in available_teams:
            raise HTTPException(status_code=404, detail=f"Team {team} not found.")

    pairing = (
        ((MATCH_DF['team1'] == team1) & (MATCH_DF['team2'] == team2))
        | ((MATCH_DF['team1'] == team2) & (MATCH_DF['team2'] == team1))
    )
    if season and 'season' in MATCH_DF.columns:
        pairing = pairing & (MATCH_DF['season'] == season)

    # Copy only after all filtering so the column assignments below act on an
    # independent frame rather than on a slice of another frame.
    team_matches = MATCH_DF[pairing].copy()
    if team_matches.empty:
        raise HTTPException(
            status_code=404,
            detail=f"No match history found between {team1} and {team2}.",
        )

    team_matches['date'] = team_matches['date'].apply(
        lambda value: value.strftime('%Y-%m-%d') if pd.notna(value) else None
    )
    # A missing winner is reported as "Draw".
    team_matches['winner'] = team_matches['winner'].fillna("Draw")
    for column in ['team1', 'team2', 'winner']:
        team_matches[column] = team_matches[column].apply(
            lambda value: str(value) if pd.notna(value) else None
        )
    history = team_matches[['date', 'team1', 'team2', 'winner']].to_dict(orient='records')

    response = {
        "team1": team1,
        "team2": team2,
        "season": season if season else "All Seasons",
        "matches": history,
    }
    response["summary"] = generate_summary(response, "match_history")
    return clean_json(response)
|
|
|
|
|
def predict_score(player_name: str, opposition_team: str):
    """Predict the runs a player is expected to score against a team.

    The name is stripped, ``+`` is replaced by a space, and the result is
    title-cased. The player's team is then looked up in player_team_mapping
    using three spellings: as normalised, with spaces removed, and with the
    words in reverse order.

    Args:
        player_name: Player name to look up.
        opposition_team: Opposing team, passed to the model unchanged.

    Returns:
        A dict with ``player``, ``team``, ``opposition``, ``predicted_runs``
        and a generated ``summary``, passed through clean_json().

    Raises:
        HTTPException: 404 when no team is known for the player; 500 when the
            prediction itself fails.
    """
    player_name = player_name.strip().replace("+", " ").title()
    name_variations = [
        player_name,
        player_name.replace(" ", ""),
        " ".join(reversed(player_name.split())),
    ]

    player_team = None
    for candidate in name_variations:
        if candidate in player_team_mapping:
            player_team = player_team_mapping[candidate]
            player_name = candidate
            break
    # An unknown player is a client error, so it is raised outside the
    # try-block below and is not converted into a 500.
    if not player_team:
        raise HTTPException(
            status_code=404,
            detail=f"Player {player_name} not found in historical data",
        )

    try:
        predicted_runs = predict_player_score(
            player=player_name,
            team=player_team,
            opponent=opposition_team,
            venue=None,
            city=None,
            toss_winner=None,
            toss_decision=None,
            score_model=PLAYER_SCORE_MODEL,
            scaler=PLAYER_SCALER,
            player_data=PLAYER_DATA,
        )
        stats = {
            "player": player_name,
            "team": player_team,
            "opposition": opposition_team,
            "predicted_runs": predicted_runs["expected_score"],
        }
        stats["summary"] = generate_summary(stats, "prediction_score")
        return clean_json(stats)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error predicting score for {player_name} against {opposition_team}: {str(e)}",
        ) from e
|
|
|
def predict_team_outcome(team1: str, team2: str):
    """Predict the outcome of a match between two teams.

    Args:
        team1: First team, passed to the model unchanged.
        team2: Second team, passed to the model unchanged.

    Returns:
        The prediction returned by predict_team_performance() with a generated
        ``summary`` added, passed through clean_json().

    Raises:
        HTTPException: 500 when the prediction fails, in the same style as
            predict_score().
    """
    try:
        prediction = predict_team_performance(
            team1=team1,
            team2=team2,
            venue=None,
            city=None,
            toss_winner=None,
            toss_decision=None,
            win_model=TEAM_WIN_MODEL,
            score_model=TEAM_SCORE_MODEL,
            data=TEAM_DATA,
            scaler=TEAM_SCALER,
        )
        prediction["summary"] = generate_summary(prediction, "prediction_team")
        return clean_json(prediction)
    except HTTPException:
        # Do not re-wrap an HTTP error raised further down.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error predicting outcome for {team1} vs {team2}: {str(e)}",
        ) from e
|
|
|
|
|
def get_teams():
    """Return the sorted, de-duplicated team names found in MATCH_DF.

    Names are collected from both the ``team1`` and ``team2`` columns.

    Returns:
        ``{"teams": [...]}`` passed through clean_json().
    """
    # Drop missing values first: a NaN mixed in with strings would make
    # sorted() raise.
    names = pd.concat([MATCH_DF['team1'], MATCH_DF['team2']]).dropna().unique().tolist()
    return clean_json({"teams": sorted(names)})
|
|
|
def get_players():
    """Return the sorted, de-duplicated striker names found in BALL_DF.

    Returns:
        ``{"players": [...]}`` passed through clean_json().
    """
    # NOTE(review): only the 'striker' column is read, so a player who appears
    # solely as a bowler is not listed -- confirm this is intended.
    unique_players = sorted(BALL_DF['striker'].dropna().unique().tolist())
    return clean_json({"players": unique_players})
|
|
|
def get_seasons():
    """Return the selectable seasons, led by the ``"All Seasons"`` entry.

    Returns:
        ``{"seasons": ["All Seasons", ...]}`` passed through clean_json(). Only
        ``"All Seasons"`` is listed when MATCH_DF has no ``season`` column.
    """
    seasons = ["All Seasons"]
    # Same column guard the season filters in this module use.
    if 'season' in MATCH_DF.columns:
        seasons += sorted(MATCH_DF['season'].dropna().unique().tolist())
    return clean_json({"seasons": seasons})
|
|
|
|
|
def get_team_trends(team_name: str):
    """Return a team's win record broken down by season.

    Args:
        team_name: Team name; stripped and title-cased before matching the
            ``team1`` / ``team2`` columns of MATCH_DF.

    Returns:
        ``{"team_name": ..., "trends": [...]}`` passed through clean_json(),
        with one entry (``season``, ``wins``, ``total_matches``,
        ``win_percentage``) per season in which the team played.

    Raises:
        HTTPException: 404 when the team has no matches.
    """
    team_name = team_name.strip().title()
    team_matches = MATCH_DF[(MATCH_DF['team1'] == team_name) | (MATCH_DF['team2'] == team_name)]
    if team_matches.empty:
        raise HTTPException(status_code=404, detail="Team not found")

    trends = []
    for season in MATCH_DF['season'].unique():
        season_matches = team_matches[team_matches['season'] == season]
        if season_matches.empty:
            continue
        wins = int((season_matches['winner'] == team_name).sum())
        total_matches = int(season_matches.shape[0])
        trends.append({
            # unique() may yield NumPy scalars; unwrap them to builtin values.
            "season": season.item() if hasattr(season, "item") else season,
            "wins": wins,
            "total_matches": total_matches,
            # total_matches is at least 1 because empty seasons are skipped.
            "win_percentage": wins / total_matches * 100,
        })

    # Cleaned like every other response built in this module.
    return clean_json({"team_name": team_name, "trends": trends})
|
|
|
|
|
def get_player_trends(player_name: str, role: str = "Batting"):
    """Return a player's batting or bowling figures broken down by season.

    The name is stripped and title-cased, then matched against the ``striker``
    and ``bowler`` columns of BALL_DF using three spellings: as given, with
    spaces removed, and with the words in reverse order.

    Args:
        player_name: Player name to look up.
        role: ``"Batting"`` or ``"Bowling"``.

    Returns:
        ``{"player_name": ..., "role": ..., "trends": [...]}`` passed through
        clean_json(), with one entry per season in which the player has
        deliveries in the requested role.

    Raises:
        HTTPException: 400 for an unsupported ``role``; 404 when no deliveries
            match the player.
    """
    if role not in ("Batting", "Bowling"):
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported role '{role}'. Expected 'Batting' or 'Bowling'.",
        )

    player_name = player_name.strip().title()
    name_variations = [
        player_name,
        player_name.replace(" ", ""),
        " ".join(reversed(player_name.split())),
    ]
    player_data = BALL_DF[
        BALL_DF['striker'].isin(name_variations) | BALL_DF['bowler'].isin(name_variations)
    ]
    if player_data.empty:
        raise HTTPException(status_code=404, detail=f"Player '{player_name}' not found")

    # Keep only deliveries where the player acts in the requested role, as
    # get_player_stats() does; otherwise balls bowled would be counted as
    # balls faced and vice versa.
    role_column = 'striker' if role == "Batting" else 'bowler'
    role_data = player_data[player_data[role_column].isin(name_variations)]
    # Same dismissal types get_player_stats() counts as wickets for the bowler.
    bowler_wicket_types = ["caught", "bowled", "lbw", "caught and bowled", "hit wicket"]

    trends = []
    for season in BALL_DF['season'].unique():
        season_data = role_data[role_data['season'] == season]
        if season_data.empty:
            continue
        # unique() may yield NumPy scalars; unwrap them to builtin values.
        season_value = season.item() if hasattr(season, "item") else season
        matches_played = int(len(season_data['match_id'].unique()))

        if role == "Batting":
            total_runs = int(season_data['runs_off_bat'].sum())
            balls_faced = int(season_data.shape[0])
            trends.append({
                "season": season_value,
                "total_runs": total_runs,
                "strike_rate": float(total_runs / balls_faced * 100) if balls_faced > 0 else 0.0,
                "matches_played": matches_played,
            })
        else:
            wickets = season_data[
                season_data['player_dismissed'].notna()
                & season_data['wicket_type'].isin(bowler_wicket_types)
            ]
            total_wickets = int(wickets.shape[0])
            total_runs_conceded = int(season_data['total_runs'].sum())
            total_overs_bowled = season_data.shape[0] / 6
            trends.append({
                "season": season_value,
                "total_wickets": total_wickets,
                # Undefined without a wicket; report 0.0 rather than infinity,
                # which is not valid JSON.
                "bowling_average": (
                    float(total_runs_conceded / total_wickets) if total_wickets > 0 else 0.0
                ),
                "economy_rate": (
                    float(total_runs_conceded / total_overs_bowled)
                    if total_overs_bowled > 0 else 0.0
                ),
                "matches_played": matches_played,
            })

    return clean_json({"player_name": player_name, "role": role, "trends": trends})