Bhaiya Hari Narayan Singh committed on
Commit
39baba6
·
verified ·
1 Parent(s): b329d87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -25
app.py CHANGED
@@ -1,30 +1,33 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from transformers.utils import get_json_schema
 
4
 
5
  # -----------------------
6
- # Load model (CPU friendly)
7
  # -----------------------
8
  model_name = "bhaiyahnsingh45/functiongemma-multiagent-router"
9
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
-
12
  model = AutoModelForCausalLM.from_pretrained(
13
- model_name
 
 
14
  )
15
 
16
  # -----------------------
17
  # Agents
18
  # -----------------------
19
  def technical_support_agent(issue_type: str, priority: str) -> str:
20
- return f"🛠️ Technical Support → {issue_type} ({priority})"
21
 
22
  def billing_agent(request_type: str, urgency: str) -> str:
23
- return f"💰 Billing → {request_type} ({urgency})"
24
 
25
  def product_info_agent(query_type: str, category: str) -> str:
26
- return f"📦 Product Info → {query_type} ({category})"
27
 
 
28
  AGENT_TOOLS = [
29
  get_json_schema(technical_support_agent),
30
  get_json_schema(billing_agent),
@@ -34,7 +37,7 @@ AGENT_TOOLS = [
34
  SYSTEM_MSG = "You are an intelligent routing agent that directs customer queries to the appropriate specialized agent."
35
 
36
  # -----------------------
37
- # Inference
38
  # -----------------------
39
  def route_query(user_query: str):
40
 
@@ -47,24 +50,28 @@ def route_query(user_query: str):
47
  messages,
48
  tools=AGENT_TOOLS,
49
  add_generation_prompt=True,
 
50
  return_tensors="pt"
51
  )
52
 
 
 
53
  outputs = model.generate(
54
- inputs,
55
- max_new_tokens=100
 
56
  )
57
 
58
  result = tokenizer.decode(
59
- outputs[0],
60
  skip_special_tokens=True
61
  )
62
 
63
- return result.split("assistant")[-1].strip()
64
 
65
 
66
  # -----------------------
67
- # Chat UI
68
  # -----------------------
69
  def chat_fn(message, history):
70
  response = route_query(message)
@@ -72,21 +79,19 @@ def chat_fn(message, history):
72
  return history, history
73
 
74
 
 
 
 
75
  with gr.Blocks() as demo:
76
- gr.Markdown("## 🤖 Multi-Agent Router (Fast CPU Demo)")
77
-
78
- chatbot = gr.Chatbot(height=400)
79
- msg = gr.Textbox(placeholder="Ask something...")
80
-
81
- gr.Examples(
82
- examples=[
83
- "My app crashes when uploading files",
84
- "I want a refund",
85
- "What features are in premium plan?"
86
- ],
87
- inputs=msg
88
- )
89
 
90
  msg.submit(chat_fn, [msg, chatbot], [chatbot, chatbot])
 
91
 
 
92
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from transformers.utils import get_json_schema
4
+ import torch
5
 
6
  # -----------------------
7
+ # Load model
8
  # -----------------------
9
  model_name = "bhaiyahnsingh45/functiongemma-multiagent-router"
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
12
  model = AutoModelForCausalLM.from_pretrained(
13
+ model_name,
14
+ device_map="auto",
15
+ torch_dtype="auto"
16
  )
17
 
18
  # -----------------------
19
  # Agents
20
  # -----------------------
21
  def technical_support_agent(issue_type: str, priority: str) -> str:
22
+ return f"🛠️ Routing to Technical Support: {issue_type} ({priority})"
23
 
24
  def billing_agent(request_type: str, urgency: str) -> str:
25
+ return f"💰 Routing to Billing: {request_type} ({urgency})"
26
 
27
  def product_info_agent(query_type: str, category: str) -> str:
28
+ return f"📦 Routing to Product Info: {query_type} ({category})"
29
 
30
+ # Tool schemas
31
  AGENT_TOOLS = [
32
  get_json_schema(technical_support_agent),
33
  get_json_schema(billing_agent),
 
37
  SYSTEM_MSG = "You are an intelligent routing agent that directs customer queries to the appropriate specialized agent."
38
 
39
  # -----------------------
40
+ # Core inference
41
  # -----------------------
42
  def route_query(user_query: str):
43
 
 
50
  messages,
51
  tools=AGENT_TOOLS,
52
  add_generation_prompt=True,
53
+ return_dict=True,
54
  return_tensors="pt"
55
  )
56
 
57
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
58
+
59
  outputs = model.generate(
60
+ **inputs,
61
+ max_new_tokens=128,
62
+ pad_token_id=tokenizer.eos_token_id
63
  )
64
 
65
  result = tokenizer.decode(
66
+ outputs[0][len(inputs["input_ids"][0]):],
67
  skip_special_tokens=True
68
  )
69
 
70
+ return result
71
 
72
 
73
  # -----------------------
74
+ # Chatbot logic
75
  # -----------------------
76
  def chat_fn(message, history):
77
  response = route_query(message)
 
79
  return history, history
80
 
81
 
82
+ # -----------------------
83
+ # UI
84
+ # -----------------------
85
  with gr.Blocks() as demo:
86
+ gr.Markdown("## 🤖 Multi-Agent Router Chatbot")
87
+ gr.Markdown("Ask anything about billing, product, or technical issues.")
88
+
89
+ chatbot = gr.Chatbot()
90
+ msg = gr.Textbox(placeholder="Type your query here...")
91
+ clear = gr.Button("Clear")
 
 
 
 
 
 
 
92
 
93
  msg.submit(chat_fn, [msg, chatbot], [chatbot, chatbot])
94
+ clear.click(lambda: None, None, chatbot, queue=False)
95
 
96
+ # Launch
97
  demo.launch()