Bhaiya Hari Narayan Singh commited on
Commit
d59bc5e
Β·
verified Β·
1 Parent(s): f3150ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -30
app.py CHANGED
@@ -1,33 +1,30 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from transformers.utils import get_json_schema
4
- import torch
5
 
6
  # -----------------------
7
- # Load model
8
  # -----------------------
9
  model_name = "bhaiyahnsingh45/functiongemma-multiagent-router"
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
12
  model = AutoModelForCausalLM.from_pretrained(
13
- model_name,
14
- device_map="auto",
15
- torch_dtype="auto"
16
  )
17
 
18
  # -----------------------
19
  # Agents
20
  # -----------------------
21
  def technical_support_agent(issue_type: str, priority: str) -> str:
22
- return f"πŸ› οΈ Routing to Technical Support: {issue_type} ({priority})"
23
 
24
  def billing_agent(request_type: str, urgency: str) -> str:
25
- return f"πŸ’° Routing to Billing: {request_type} ({urgency})"
26
 
27
  def product_info_agent(query_type: str, category: str) -> str:
28
- return f"πŸ“¦ Routing to Product Info: {query_type} ({category})"
29
 
30
- # Tool schemas
31
  AGENT_TOOLS = [
32
  get_json_schema(technical_support_agent),
33
  get_json_schema(billing_agent),
@@ -37,7 +34,7 @@ AGENT_TOOLS = [
37
  SYSTEM_MSG = "You are an intelligent routing agent that directs customer queries to the appropriate specialized agent."
38
 
39
  # -----------------------
40
- # Core inference
41
  # -----------------------
42
  def route_query(user_query: str):
43
 
@@ -50,28 +47,24 @@ def route_query(user_query: str):
50
  messages,
51
  tools=AGENT_TOOLS,
52
  add_generation_prompt=True,
53
- return_dict=True,
54
  return_tensors="pt"
55
  )
56
 
57
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
58
-
59
  outputs = model.generate(
60
- **inputs,
61
- max_new_tokens=128,
62
- pad_token_id=tokenizer.eos_token_id
63
  )
64
 
65
  result = tokenizer.decode(
66
- outputs[0][len(inputs["input_ids"][0]):],
67
  skip_special_tokens=True
68
  )
69
 
70
- return result
71
 
72
 
73
  # -----------------------
74
- # Chatbot logic
75
  # -----------------------
76
  def chat_fn(message, history):
77
  response = route_query(message)
@@ -79,19 +72,21 @@ def chat_fn(message, history):
79
  return history, history
80
 
81
 
82
- # -----------------------
83
- # UI
84
- # -----------------------
85
  with gr.Blocks() as demo:
86
- gr.Markdown("## πŸ€– Multi-Agent Router Chatbot")
87
- gr.Markdown("Ask anything about billing, product, or technical issues.")
88
-
89
- chatbot = gr.Chatbot()
90
- msg = gr.Textbox(placeholder="Type your query here...")
91
- clear = gr.Button("Clear")
 
 
 
 
 
 
 
92
 
93
  msg.submit(chat_fn, [msg, chatbot], [chatbot, chatbot])
94
- clear.click(lambda: None, None, chatbot, queue=False)
95
 
96
- # Launch
97
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from transformers.utils import get_json_schema
 
4
 
5
  # -----------------------
6
+ # Load model (CPU friendly)
7
  # -----------------------
8
  model_name = "bhaiyahnsingh45/functiongemma-multiagent-router"
9
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
  model = AutoModelForCausalLM.from_pretrained(
13
+ model_name
 
 
14
  )
15
 
16
  # -----------------------
17
  # Agents
18
  # -----------------------
19
  def technical_support_agent(issue_type: str, priority: str) -> str:
20
+ return f"πŸ› οΈ Technical Support β†’ {issue_type} ({priority})"
21
 
22
  def billing_agent(request_type: str, urgency: str) -> str:
23
+ return f"πŸ’° Billing β†’ {request_type} ({urgency})"
24
 
25
  def product_info_agent(query_type: str, category: str) -> str:
26
+ return f"πŸ“¦ Product Info β†’ {query_type} ({category})"
27
 
 
28
  AGENT_TOOLS = [
29
  get_json_schema(technical_support_agent),
30
  get_json_schema(billing_agent),
 
34
  SYSTEM_MSG = "You are an intelligent routing agent that directs customer queries to the appropriate specialized agent."
35
 
36
  # -----------------------
37
+ # Inference
38
  # -----------------------
39
  def route_query(user_query: str):
40
 
 
47
  messages,
48
  tools=AGENT_TOOLS,
49
  add_generation_prompt=True,
 
50
  return_tensors="pt"
51
  )
52
 
 
 
53
  outputs = model.generate(
54
+ inputs,
55
+ max_new_tokens=100
 
56
  )
57
 
58
  result = tokenizer.decode(
59
+ outputs[0],
60
  skip_special_tokens=True
61
  )
62
 
63
+ return result.split("assistant")[-1].strip()
64
 
65
 
66
  # -----------------------
67
+ # Chat UI
68
  # -----------------------
69
  def chat_fn(message, history):
70
  response = route_query(message)
 
72
  return history, history
73
 
74
 
 
 
 
75
  with gr.Blocks() as demo:
76
+ gr.Markdown("## πŸ€– Multi-Agent Router (Fast CPU Demo)")
77
+
78
+ chatbot = gr.Chatbot(height=400)
79
+ msg = gr.Textbox(placeholder="Ask something...")
80
+
81
+ gr.Examples(
82
+ examples=[
83
+ "My app crashes when uploading files",
84
+ "I want a refund",
85
+ "What features are in premium plan?"
86
+ ],
87
+ inputs=msg
88
+ )
89
 
90
  msg.submit(chat_fn, [msg, chatbot], [chatbot, chatbot])
 
91
 
 
92
  demo.launch()