Spaces:
Running
Running
from typing import Annotated, Dict, List | |
from dotenv import load_dotenv | |
from fastapi import FastAPI, Header, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from mistralai import Mistral | |
import hackathon.agent.arbitrary as ar | |
import hackathon.game_mechanics.entities as ent | |
import hackathon.game_mechanics.pre_game_mechanics as pre | |
from hackathon.agent.arbitrary import EmotionAgent | |
from hackathon.agent.character import AIAgent | |
from hackathon.agent.engagement import Engagement | |
from hackathon.agent.presenter import Presenter | |
from hackathon.config import settings | |
from hackathon.server.schemas import ( | |
CardsVoiceRequest, | |
CardsVoiceResponse, | |
InferenceRequest, | |
InferenceResponse, | |
StartRequest, | |
StartResponse, | |
) | |
from hackathon.speech.speech import read_audio_config, text_to_speech_file | |
# Load environment variables from a .env file, if one is found.
# NOTE(review): `hackathon.config.settings` is imported above this call; if it
# reads the environment at import time it will not see values from .env —
# confirm.
load_dotenv()

# The ASGI application that the route handlers below belong to.
app = FastAPI()

# Fully permissive CORS: every origin, method and header is accepted and
# credentials are allowed. This *opens* cross-origin access; it does not
# disable CORS handling.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# In-memory registry of running games, keyed by the `game-id` request header.
# It lives in process memory only and is lost on restart.
app.state.games = {}
class GameEngine:
    """State of a single game: two AI candidates, a presenter, a card deck,
    an engagement tracker and per-voice audio settings.

    Everything is loaded from files under ``settings.API_BASE_PATH``:
    candidate, context, card and audio YAML files from ``config/``, and
    generated audio goes to ``data/``.
    """

    def __init__(
        self,
        candidate_1_name: str,
        candidate_2_name: str,
        api_key: str = settings.MISTRAL_API_KEY,
        model_name: str = "mistral-large-latest",
    ):
        """Build a new game.

        Args:
            candidate_1_name: Stem of the first candidate's YAML file in the
                config directory (``<name>.yaml``).
            candidate_2_name: Stem of the second candidate's YAML file.
            api_key: Mistral API key. The default is read from ``settings``
                once, when this class is defined.
            model_name: Mistral model used by the emotion agent, the
                presenter and the card agent.
        """
        self.model_name = model_name
        self.api_key = api_key

        config_dir = settings.API_BASE_PATH / "config"
        candidate_1_yaml = config_dir / f"{candidate_1_name}.yaml"
        candidate_2_yaml = config_dir / f"{candidate_2_name}.yaml"
        context_yaml = config_dir / "context.yaml"
        cards_trump_yaml = config_dir / "cards_trump.yaml"
        cards_kamala_yaml = config_dir / "cards_kamala.yaml"
        cards_neutral_yaml = config_dir / "cards_neutral.yaml"
        self.audio_yaml = config_dir / "audio.yaml"
        self.data_folder = settings.API_BASE_PATH / "data"

        self.client = Mistral(api_key=api_key)
        # A single emotion agent instance is shared by both candidates.
        emotion_agent = EmotionAgent(self.client, model=self.model_name)
        self.candidate_1 = AIAgent.from_yaml(
            candidate_1_yaml, context_yaml, self.client, emotion_agent
        )
        self.candidate_2 = AIAgent.from_yaml(
            candidate_2_yaml, context_yaml, self.client, emotion_agent
        )

        self.engagement = Engagement()
        # Both candidates load the same context.yaml, so candidate 1's
        # general context is presumably valid for the presenter too.
        self.presenter = Presenter(
            self.candidate_1.general_context, self.client, model_name
        )

        # Honour `model_name` here as well, so every agent uses the same model.
        card_agent = ar.CardAgent(self.client, model=self.model_name)
        self.deck = ent.Deck(cards_trump_yaml, cards_kamala_yaml, cards_neutral_yaml)
        self.deck.sample()
        # Inject the sampled cards into each candidate's personal context.
        pre.add_cards_to_personal_context_full_prompt(
            card_agent, [self.candidate_1, self.candidate_2], self.deck
        )

        self.audio_config = read_audio_config(self.audio_yaml)
        # Game clock: advanced by the `infer` and `cards` handlers and compared
        # with `self.engagement.timestamp` to decide when engagement is stale.
        self.timestamp = 0
async def start(request: StartRequest, game_id: Annotated[str | None, Header()] = None):
    """Create (or replace) the game engine registered under ``game_id``.

    The candidate names from the request body become YAML file stems inside
    the config directory, so they are checked to be plain names before any
    file path is built from them.

    NOTE(review): no route decorator is visible for this handler; other
    handlers' error messages refer to it as ``/start``.
    NOTE(review): ``GameEngine`` construction is synchronous and runs inside
    this coroutine, so it blocks the event loop while it runs.

    Raises:
        HTTPException: 400 if the ``game-id`` header is missing or a candidate
            name is not a plain file stem.
    """
    if game_id is None:
        raise HTTPException(
            status_code=400, detail="Game ID not provided in the header."
        )
    for name in (request.candidate_1_name, request.candidate_2_name):
        # Reject anything that could point outside the config directory,
        # e.g. "../secrets", "/etc/x", "C:foo", or hidden/dot names.
        if not name or name.startswith(".") or any(ch in name for ch in "/\\:\x00"):
            raise HTTPException(
                status_code=400, detail=f"Invalid candidate name: {name!r}."
            )

    app.state.games[game_id] = GameEngine(
        candidate_1_name=request.candidate_1_name,
        candidate_2_name=request.candidate_2_name,
    )
    print(f"Created new game ({game_id})")
    return {"status": "Game engine initialized successfully"}
async def infer(
    request: InferenceRequest, game_id: Annotated[str | None, Header()] = None
):
    """Generate the next spoken reply of ``request.current_speaker``.

    The speaker's emotions are updated from the previous speaker's text, a
    reply is generated, and it is synthesised to audio using the speaker's
    entry in the game's audio configuration.

    Returns:
        dict with ``generated_text``, the speaker's ``anger`` emotion value
        and the ``audio`` value returned by ``text_to_speech_file``.

    Raises:
        HTTPException: 400 if the ``game-id`` header is missing, the game has
            not been started, or ``current_speaker`` matches neither candidate.
    """
    if game_id is None:
        raise HTTPException(
            status_code=400, detail="Game ID not provided in the header."
        )
    if game_id not in app.state.games:
        raise HTTPException(
            status_code=400, detail="Game engine not initialized. Call /start first."
        )
    game_engine = app.state.games[game_id]

    if request.current_speaker == game_engine.candidate_1.name:
        current_speaker = game_engine.candidate_1
    elif request.current_speaker == game_engine.candidate_2.name:
        current_speaker = game_engine.candidate_2
    else:
        # A bad speaker name is a client error: answer 400 rather than letting
        # a bare ValueError surface as a 500.
        raise HTTPException(
            status_code=400,
            detail=f"Candidate {request.current_speaker!r} does not exist.",
        )

    # Advance the game clock only for turns that are actually played.
    game_engine.timestamp += 1

    current_audio_config = game_engine.audio_config[current_speaker.name]
    input_text = (
        f"{request.previous_speaker} said :{request.previous_character_text}. "
        f"You have to respond to {request.previous_speaker}. "
        "Limit to less than 50 words."
    )
    current_speaker.update_emotions(input_text)
    msg = current_speaker.respond(input_text)
    audio_signal = text_to_speech_file(
        text=msg,
        voice_id=current_audio_config["voice_id"],
        stability=current_audio_config["stability"],
        similarity=current_audio_config["similarity"],
        style=current_audio_config["style"],
        base_path=str(game_engine.data_folder),
    )
    return {
        "generated_text": msg,
        "anger": current_speaker.emotions["anger"],
        "audio": audio_signal,
    }
async def engagement(
    game_id: Annotated[str | None, Header()] = None,
):
    """Return the game's engagement value, refreshing it first when stale.

    The tracker is updated from both candidates' current ``anger`` values
    only when the game's ``timestamp`` is ahead of the tracker's own
    ``timestamp``. Otherwise ``current_value`` is returned without calling
    ``update``.

    Raises:
        HTTPException: 400 if the ``game-id`` header is missing or the game
            has not been started.
    """
    if game_id is None:
        raise HTTPException(
            status_code=400, detail="Game ID not provided in the header."
        )
    try:
        game_engine = app.state.games[game_id]
    except KeyError:
        raise HTTPException(
            status_code=400, detail="Game engine not initialized. Call /start first."
        ) from None

    tracker = game_engine.engagement
    if game_engine.timestamp > tracker.timestamp:
        tracker.update(
            game_engine.candidate_1.emotions["anger"],
            game_engine.candidate_2.emotions["anger"],
        )
    return {"engagement": tracker.current_value}
async def cards(
    request: CardsVoiceRequest,
    game_id: Annotated[str | None, Header()] = None,
):
    """Have the presenter ("chairman") play a card and voice the result.

    Playing a card changes game state: it advances the game ``timestamp``.
    ``presenter.play_card`` also receives both speaker agents.
    NOTE(review): the original docstring warned that cards "have an impact
    here"; confirm whether ``play_card`` mutates the agents.

    Returns:
        dict with the presenter's text (``presenter_question``) and the
        ``audio`` value returned by ``text_to_speech_file``.

    Raises:
        HTTPException: 400 if the ``game-id`` header is missing, the game has
            not been started, ``previous_speaker`` is unknown, or ``card_id``
            does not identify a card in the deck.
    """
    if game_id is None:
        raise HTTPException(
            status_code=400, detail="Game ID not provided in the header."
        )
    game_engine = app.state.games.get(game_id)
    if game_engine is None:
        raise HTTPException(
            status_code=400, detail="Game engine not initialized. Call /start first."
        )

    candidate_1 = game_engine.candidate_1
    candidate_2 = game_engine.candidate_2
    previous_speaker_name = request.previous_speaker
    if previous_speaker_name == candidate_1.name:
        last_speaker, next_speaker = candidate_1, candidate_2
    elif previous_speaker_name == candidate_2.name:
        last_speaker, next_speaker = candidate_2, candidate_1
    elif previous_speaker_name == "player":
        # After the human player, candidate 1 is passed as the last speaker
        # and candidate 2 answers next.
        last_speaker, next_speaker = candidate_1, candidate_2
    else:
        raise HTTPException(
            status_code=400,
            detail=f"Speaker {previous_speaker_name!r} is not known.",
        )

    # NOTE(review): the expected type of `card_id` (list index vs. dict key)
    # is unconfirmed; any failed lookup is reported as a client error.
    try:
        card = game_engine.deck.all_cards[request.card_id]
    except (KeyError, IndexError, TypeError):
        raise HTTPException(
            status_code=400, detail=f"Unknown card id: {request.card_id!r}."
        ) from None

    # Advance the game clock only once the request is known to be valid.
    game_engine.timestamp += 1

    current_audio_config = game_engine.audio_config["chairman"]
    msg = game_engine.presenter.play_card(
        card, request.previous_character_text, last_speaker, next_speaker
    )
    audio_signal = text_to_speech_file(
        text=msg,
        voice_id=current_audio_config["voice_id"],
        stability=current_audio_config["stability"],
        similarity=current_audio_config["similarity"],
        style=current_audio_config["style"],
        base_path=str(game_engine.data_folder),
    )
    return {"presenter_question": msg, "audio": audio_signal}
async def cards_request(
    request: Request, game_id: Annotated[str | None, Header()] = None
):
    """Return the game's card deck as produced by ``Deck.to_list()``.

    ``request`` is accepted but not used.

    Raises:
        HTTPException: 400 if the ``game-id`` header is missing or the game
            has not been started.
    """
    if game_id is None:
        raise HTTPException(
            status_code=400, detail="Game ID not provided in the header."
        )
    games = app.state.games
    if game_id not in games:
        raise HTTPException(
            status_code=400, detail="Game engine not initialized. Call /start first."
        )
    return games[game_id].deck.to_list()