Spaces:
Running
Running
| import requests | |
| import config | |
class SmartRouter:
    """Pick a backend lane ("primary" or "secondary") for each of two models.

    The router remembers which model id it last sent to each lane and prefers
    the assignment that reuses those lanes, so fewer lanes have to load a
    different model. The record is only a best guess: it is updated when a
    plan is produced, not when the backend confirms anything.
    """

    def __init__(self):
        # Model id last routed to each lane; None until the lane is first used.
        self.lane_state = dict.fromkeys(("primary", "secondary"))

    def get_routing_plan(self, model_left_id, model_right_id):
        """Choose lanes for the left and right models and record the choice.

        Args:
            model_left_id: Identifier of the model shown on the left.
            model_right_id: Identifier of the model shown on the right.

        Returns:
            A tuple ``(lane_for_left_model, lane_for_right_model)`` that is
            either ``("primary", "secondary")`` or ``("secondary", "primary")``.
            The swapped plan is chosen only when it has strictly fewer cache
            misses; ties keep the straight plan.
        """

        def miss(loaded, wanted):
            # 0 = the lane already holds this model, 1 = it would have to load it.
            return 0 if loaded == wanted else 1

        in_primary = self.lane_state["primary"]
        in_secondary = self.lane_state["secondary"]

        # Straight: left -> primary, right -> secondary.
        straight_cost = miss(in_primary, model_left_id) + miss(in_secondary, model_right_id)
        # Swapped: left -> secondary, right -> primary.
        swapped_cost = miss(in_secondary, model_left_id) + miss(in_primary, model_right_id)

        if swapped_cost < straight_cost:
            print("π Smart Router: Swapping lanes to optimize cache!")
            plan = ("secondary", "primary")
        else:
            print("β¬οΈ Smart Router: keeping straight lanes.")
            plan = ("primary", "secondary")

        # Remember the assignment so the next call can score against it.
        left_lane, right_lane = plan
        self.lane_state[left_lane] = model_left_id
        self.lane_state[right_lane] = model_right_id
        return plan
| # Module-level SmartRouter instance; its lane_state persists between calls for as long as this module stays loaded. | |
| router = SmartRouter() | |
| # --- STEP 3: REWRITE call_modal_api FOR STREAMING --- | |
def _describe_request_error(error):
    """Turn a ``requests`` exception into the one-line message shown to the user."""
    response = getattr(error, "response", None)
    # Compare against None explicitly: a Response object is falsy whenever its
    # status is 4xx/5xx, so a truthiness test would skip every HTTP error.
    if response is None:
        return f"API Error: {error}"
    if response.status_code == 401:
        return "Error: Authentication failed. The token is invalid."
    try:
        # Prefer the "detail" field of a JSON error body when there is one.
        detail = response.json().get("detail", str(error))
    except (ValueError, AttributeError, requests.exceptions.RequestException):
        # Body unreadable, not JSON, or not a JSON object: fall back to the
        # exception text rather than failing while reporting a failure.
        return f"API Error: {error}"
    return f"API Error: {response.status_code} - {detail}"


def call_modal_api(model_repo_id, prompt, lane):
    """Stream generated text from the Modal endpoint of one lane.

    This is a generator. Problems are reported by yielding a single
    human-readable message instead of raising.

    Args:
        model_repo_id: Model identifier sent as ``model_id``; a falsy value
            yields a "please select a model" message and stops.
        prompt: Prompt text sent as ``prompt``.
        lane: ``"primary"`` selects the primary endpoint; any other value
            selects the secondary endpoint.

    Yields:
        Text chunks as they arrive from the endpoint, or one error message.
    """
    if not model_repo_id:
        yield "Please select a model from the dropdown."
        return  # Stop the generator
    if not config.MY_AUTH_TOKEN:
        yield "Error: `ARENA_AUTH_TOKEN` is not set on the Gradio server."
        return

    # The lane name picks which of the two deployed endpoints is called.
    if lane == "primary":
        endpoint = f"{config.MODAL_BASE_URL}-generate-primary.modal.run"
    else:
        endpoint = f"{config.MODAL_BASE_URL}-generate-secondary.modal.run"
    print(f"π Streaming from {model_repo_id} on [{lane.upper()}]...")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.MY_AUTH_TOKEN}",
    }
    payload = {"model_id": model_repo_id, "prompt": prompt}

    response = None
    try:
        # stream=True defers the body download so chunks can be forwarded as
        # they arrive instead of after the whole response has been received.
        response = requests.post(
            endpoint,
            json=payload,
            timeout=300,
            headers=headers,
            stream=True,
        )
        response.raise_for_status()
        # decode_unicode=True only decodes when an encoding is known; without
        # one the chunks come back as raw bytes.
        # NOTE(review): assumes the endpoint streams UTF-8 text - confirm.
        if response.encoding is None:
            response.encoding = "utf-8"
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                yield chunk
    except requests.exceptions.RequestException as e:
        yield _describe_request_error(e)
    except Exception as e:
        yield f"An unexpected error occurred: {e}"
    finally:
        # Release the connection even if the consumer stops iterating early.
        # This runs after the handlers above, so an error body is still
        # readable while the message is being built.
        if response is not None:
            response.close()