0r0b0r0s committed · verified
Commit 25b21d1 · 1 Parent(s): f2bd555

Update app.py

Files changed (1):
  1. app.py +23 -34
app.py CHANGED
@@ -1,44 +1,31 @@
 import os
 import gradio as gr
 import requests
-import inspect
+import re  # Added missing import
 import pandas as pd
 from langgraph.graph import StateGraph, END
 from huggingface_hub import InferenceClient
+import time  # Added missing import
 
-
-# (Keep Constants as is)
 # --- Constants ---
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
-# --- Basic Agent Definition ---
-# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
-# --- GAIA-Optimized Agent Implementation ---
-# Configure fallback models
+# --- Optimized Agent Implementation ---
 MODELS = [
     "Qwen/Qwen2-0.5B-Instruct",
     "google/flan-t5-xxl",
     "mistralai/Mistral-7B-Instruct-v0.2"
 ]
 
-# Initialize clients with automatic retry
 clients = [InferenceClient(model=model, token=os.environ["HF_TOKEN"]) for model in MODELS]
 
-# Define state structure using dictionary
-initial_state = {
-    "question": "",
-    "retries": 0,
-    "current_model": 0,
-    "answer": ""
-}
-
 def model_router(state: dict) -> dict:
     """Rotate through available models"""
     state["current_model"] = (state["current_model"] + 1) % len(MODELS)
     return state
 
 def query_model(state: dict) -> dict:
-    """Attempt to get answer from current model"""
+    """Generate answer with error handling"""
     try:
         response = clients[state["current_model"]].text_generation(
             prompt=f"""<|im_start|>system
@@ -51,48 +38,50 @@ Answer with ONLY the exact value requested.<|im_end|>
             max_new_tokens=50,
             stop_sequences=["<|im_end|>"]
         )
-        state["answer"] = response.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
-        state["answer"] = re.sub(r'[^a-zA-Z0-9]', '', state["answer"]).lower()
+        # Fixed answer extraction
+        answer_part = response.split("<|im_start|>assistant")[-1]
+        answer = answer_part.split("<|im_end|>")[0].strip()
+        state["answer"] = re.sub(r'[^a-zA-Z0-9]', '', answer).lower()
     except Exception as e:
         print(f"Model error: {str(e)}")
         state["answer"] = ""
     return state
 
-def validate_answer(state: dict) -> str:
-    """Check if we have a valid answer"""
-    return "final_answer" if state["answer"] else "retry"
+def should_continue(state: dict) -> str:
+    """Conditional edge function (not a node)"""
+    return END if state["answer"] else "route_model"
 
 # Build workflow
 workflow = StateGraph(dict)
 workflow.add_node("route_model", model_router)
 workflow.add_node("query", query_model)
-workflow.add_node("validate", validate_answer)
 
+workflow.set_entry_point("route_model")
 workflow.add_edge("route_model", "query")
-workflow.add_edge("query", "validate")
-
 workflow.add_conditional_edges(
-    "validate",
-    lambda x: "final_answer" if x["answer"] else "retry",
-    {"final_answer": END, "retry": "route_model"}
+    "query",
+    should_continue,
+    {END: END, "route_model": "route_model"}
 )
 
-workflow.set_entry_point("route_model")
 compiled_agent = workflow.compile()
 
-# GAIA Interface
 class BasicAgent:
     def __call__(self, question: str) -> str:
-        state = initial_state.copy()
-        state["question"] = question
+        state = {
+            "question": question,
+            "retries": 0,
+            "current_model": 0,
+            "answer": ""
+        }
 
         for _ in range(3):  # Max 3 attempts
             state = compiled_agent.invoke(state)
            if state["answer"]:
                 return state["answer"]
-            time.sleep(1)  # Backoff
+            time.sleep(1)
 
-        return ""  # Return empty to preserve scoring
+        return ""
 
 def run_and_submit_all( profile: gr.OAuthProfile | None):
     """