# LLM_Finetuning / backend.py
# Source: farhananis005's Hugging Face Space — "LLM finetuning demo" (commit ec0af28, verified)
import requests
import config
class SmartRouter:
    """Assigns the two arena models to the two backend serving lanes.

    Keeps a best-guess record of which model each lane last served so a
    repeat request can be routed to a lane that already has that model
    loaded (a cache hit), minimizing cold starts.
    """

    def __init__(self):
        # Best guess of what each backend lane currently has loaded.
        #   "primary"   -> URL: ...-generate-primary.modal.run
        #   "secondary" -> URL: ...-generate-secondary.modal.run
        self.lane_state = {"primary": None, "secondary": None}

    def get_routing_plan(self, model_left_id, model_right_id):
        """
        Decide which lane each model should use to minimize cold starts.

        Returns:
            (lane_for_left_model, lane_for_right_model) — each one of
            "primary" / "secondary".
        """
        loaded_primary = self.lane_state["primary"]
        loaded_secondary = self.lane_state["secondary"]

        # Each lane whose loaded model differs from its assigned model
        # costs one cold start (cache miss).
        misses_straight = int(loaded_primary != model_left_id) + int(loaded_secondary != model_right_id)
        misses_swapped = int(loaded_secondary != model_left_id) + int(loaded_primary != model_right_id)

        if misses_swapped < misses_straight:
            print(f"πŸ”€ Smart Router: Swapping lanes to optimize cache!")
            # Record the swapped assignment for the next call.
            self.lane_state["secondary"] = model_left_id
            self.lane_state["primary"] = model_right_id
            return "secondary", "primary"

        # Ties deliberately keep the straight assignment.
        print(f"⬇️ Smart Router: keeping straight lanes.")
        self.lane_state["primary"] = model_left_id
        self.lane_state["secondary"] = model_right_id
        return "primary", "secondary"
# Module-level singleton: lane_state must persist across requests so the
# router's cache guesses stay meaningful.
router = SmartRouter()
# --- Streaming inference call (generator-based) ---
def call_modal_api(model_repo_id, prompt, lane):
    """
    Call the Modal API on a specific lane and yield tokens as they arrive.
    This is a GENERATOR: it streams response chunks (or a single error
    message) as strings.

    Args:
        model_repo_id: Model repo id to run; falsy means nothing selected.
        lane: "primary" or "secondary" — selects the backend endpoint.

    Yields:
        str: response chunks, or one human-readable error message.
    """
    if not model_repo_id:
        yield "Please select a model from the dropdown."
        return  # Stop the generator

    if not config.MY_AUTH_TOKEN:
        yield "Error: `ARENA_AUTH_TOKEN` is not set on the Gradio server."
        return

    # Construct the URL based on the lane.
    if lane == "primary":
        endpoint = f"{config.MODAL_BASE_URL}-generate-primary.modal.run"
    else:
        endpoint = f"{config.MODAL_BASE_URL}-generate-secondary.modal.run"

    print(f"πŸš€ Streaming from {model_repo_id} on [{lane.upper()}]...")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.MY_AUTH_TOKEN}"
    }
    payload = {"model_id": model_repo_id, "prompt": prompt}

    try:
        # stream=True keeps the connection open so chunks arrive incrementally.
        response = requests.post(
            endpoint,
            json=payload,
            timeout=300,
            headers=headers,
            stream=True
        )
        response.raise_for_status()

        # Yield tokens as they arrive.
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                yield chunk

    except requests.exceptions.RequestException as e:
        # BUG FIX: requests.Response.__bool__ is False for any 4xx/5xx
        # status, so the old `if e.response and ...` skipped BOTH error
        # branches for exactly the responses they were written to handle.
        # Compare against None explicitly instead.
        if e.response is not None and e.response.status_code == 401:
            yield "Error: Authentication failed. The token is invalid."
        elif e.response is not None:
            # Try to get error detail from the streaming API.
            try:
                error_detail = e.response.json().get("detail", str(e))
                yield f"API Error: {e.response.status_code} - {error_detail}"
            except Exception:
                # Body was not JSON / not a dict; fall back to the raw error.
                yield f"API Error: {e}"
        else:
            yield f"API Error: {e}"
    except Exception as e:
        yield f"An unexpected error occurred: {e}"