import requests

import config


class SmartRouter:
    """Assign a pair of models to the two backend lanes, preferring cache hits.

    The router remembers which model it last sent to each lane and uses that
    memory to decide whether the next pair should go "straight" or "swapped".
    The memory is only a best guess: it records what this router requested,
    not what the backend actually has loaded.
    """

    def __init__(self):
        # Tracks what is currently loaded in the backend (best guess).
        self.lane_state = {
            "primary": None,  # URL: ...-generate-primary.modal.run
            "secondary": None,  # URL: ...-generate-secondary.modal.run
        }

    def get_routing_plan(self, model_left_id, model_right_id):
        """Decide which model goes to which lane to minimize cold starts.

        Each candidate assignment is scored by counting cache misses
        (0 = the lane is already believed to hold that model, 1 = it is not).
        The swapped assignment is chosen only when it is strictly cheaper;
        ties keep the straight assignment. The chosen assignment is recorded
        in ``self.lane_state`` for the next call.

        Args:
            model_left_id: Identifier of the model shown on the left.
            model_right_id: Identifier of the model shown on the right.

        Returns:
            A tuple ``(lane_for_left_model, lane_for_right_model)`` that is
            either ``("primary", "secondary")`` or ``("secondary", "primary")``.
        """
        primary_model = self.lane_state["primary"]
        secondary_model = self.lane_state["secondary"]

        # Option A: straight (left -> primary, right -> secondary).
        cost_straight = int(primary_model != model_left_id) + int(
            secondary_model != model_right_id
        )
        # Option B: swapped (left -> secondary, right -> primary).
        cost_swapped = int(secondary_model != model_left_id) + int(
            primary_model != model_right_id
        )

        if cost_swapped < cost_straight:
            print("🔀 Smart Router: Swapping lanes to optimize cache!")
            # Update state for next time.
            self.lane_state["secondary"] = model_left_id
            self.lane_state["primary"] = model_right_id
            return "secondary", "primary"

        print("⬇️ Smart Router: keeping straight lanes.")
        # Update state for next time.
        self.lane_state["primary"] = model_left_id
        self.lane_state["secondary"] = model_right_id
        return "primary", "secondary"


# Module-level router instance.
router = SmartRouter()


def _describe_request_error(error):
    """Turn a ``requests`` exception into the user-facing error string."""
    response = error.response
    # Compare against None explicitly: a Response is falsy whenever its status
    # is 4xx/5xx, so a plain truthiness test would skip exactly the HTTP
    # errors this function exists to describe.
    if response is None:
        return f"API Error: {error}"
    if response.status_code == 401:
        return "Error: Authentication failed. The token is invalid."
    try:
        # Try to get the error detail from the API's JSON error body.
        error_detail = response.json().get("detail", str(error))
    except (ValueError, AttributeError, requests.exceptions.RequestException):
        # Body was not JSON, was not a JSON object, or could not be read.
        return f"API Error: {error}"
    return f"API Error: {response.status_code} - {error_detail}"


# --- Streaming client for the Modal API ---
def call_modal_api(model_repo_id, prompt, lane):
    """Call the Modal API on a specific lane and yield text as it arrives.

    This is a generator. Failures are not raised to the caller; they are
    yielded as a single human-readable error string instead.

    Args:
        model_repo_id: Model identifier sent to the API as ``model_id``.
        prompt: Prompt text sent to the API as ``prompt``.
        lane: ``"primary"`` selects the primary endpoint; any other value
            selects the secondary endpoint.

    Yields:
        Chunks of the response body as ``str``, or one error/help message.
    """
    if not model_repo_id:
        yield "Please select a model from the dropdown."
        return  # Stop the generator

    if not config.MY_AUTH_TOKEN:
        yield "Error: `ARENA_AUTH_TOKEN` is not set on the Gradio server."
        return

    # Construct the URL based on the lane.
    if lane == "primary":
        endpoint = f"{config.MODAL_BASE_URL}-generate-primary.modal.run"
    else:
        endpoint = f"{config.MODAL_BASE_URL}-generate-secondary.modal.run"

    print(f"🚀 Streaming from {model_repo_id} on [{lane.upper()}]...")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.MY_AUTH_TOKEN}",
    }
    payload = {"model_id": model_repo_id, "prompt": prompt}

    response = None
    try:
        # stream=True defers downloading the body so it can be relayed
        # incrementally instead of after the whole generation finishes.
        response = requests.post(
            endpoint,
            json=payload,
            timeout=300,
            headers=headers,
            stream=True,
        )
        response.raise_for_status()

        # With decode_unicode=True, iter_content still yields raw bytes when
        # no encoding is known; fall back to UTF-8 so callers always get str.
        if response.encoding is None:
            response.encoding = "utf-8"

        # Yield text as it arrives.
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                yield chunk
    except requests.exceptions.RequestException as e:
        yield _describe_request_error(e)
    except Exception as e:
        yield f"An unexpected error occurred: {e}"
    finally:
        # A streamed response holds its connection until it is fully read or
        # closed. Closing here (after the error handlers have had a chance to
        # read the error body) also covers a consumer abandoning the generator.
        if response is not None:
            response.close()