File size: 9,010 Bytes
adca896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3402a8
adca896
c3402a8
 
ef65eaa
 
 
d2ed8d0
c3402a8
adca896
 
 
c3402a8
adca896
 
 
 
c3402a8
adca896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import json
import time
import ast
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from gradio_consilium_roundtable import consilium_roundtable

# === Constants ===
MODEL_NAME = "katanemo/Arch-Router-1.5B"
ARCH_ROUTER = "Arch Router"
WAIT_DEPARTMENT = 5
WAIT_SYSTEM = 5

# === Load model/tokenizer ===
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# === Route Definitions ===
route_config = [
    {"name": "code_generation", "description": "Generating code based on prompts"},
    {"name": "bug_fixing", "description": "Fixing errors or bugs in code"},
    {"name": "performance_optimization", "description": "Improving code performance"},
    {"name": "api_help", "description": "Assisting with APIs and libraries"},
    {"name": "programming", "description": "General programming Q&A"},
    {"name": "legal", "description": "Legal"},
    {"name": "healthcare", "description": "Healthcare and medical related"},
]

departments = {
    "code_generation": ("πŸ’»", "Code Generation"),
    "bug_fixing": ("🐞", "Bug Fixing"),
    "performance_optimization": ("⚑", "Performance Optimization"),
    "api_help": ("πŸ”Œ", "API Help"),
    "programming": ("πŸ“š", "Programming"),
    "legal": ("βš–οΈ", "Legal"),
    "healthcare": ("🩺", "Healthcare"),
    "other": ("❓", "Other / General Inquiry"),
}

# === Prompt Formatting ===
TASK_INSTRUCTION = """
You are a helpful assistant designed to find the best suited route. You are provided with route description within <routes></routes> XML tags:
<routes>
{routes}
</routes>

<conversation>
{conversation}
</conversation>
"""

FORMAT_PROMPT = """
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.

Based on your analysis, provide your response in the following JSON format:
{"route": "route_name"} 
"""

def format_prompt(route_config, conversation):
    return TASK_INSTRUCTION.format(
        routes=json.dumps(route_config), conversation=json.dumps(conversation)
    ) + FORMAT_PROMPT

def parse_route(response_text):
    try:
        start = response_text.find("{")
        end = response_text.rfind("}") + 1
        return ast.literal_eval(response_text[start:end]).get("route", "other")
    except Exception as e:
        print("Parsing failed:", e)
        return "other"

def init_state():
    avatar_emojis = {
        ARCH_ROUTER: "https://avatars.githubusercontent.com/u/112724757?s=200&v=4",
        "code_generation": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f4bb.png",
        "bug_fixing": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f41e.png",
        "performance_optimization": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/26a1.png",
        "api_help": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f50c.png",
        "programming": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f4da.png",
        "legal": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/2696.png",
        "healthcare": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1fa7a.png",
        "other": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/2753.png",
    }
    return {
        "messages": [],
        "participants": [ARCH_ROUTER] + list(departments.keys()),
        "currentSpeaker": None,
        "thinking": [],
        "showBubbles": [ARCH_ROUTER],
        "avatarImages": avatar_emojis,
    }

def route_and_visualize(user_input_text, rt_state, chat_history):
    chat_history = chat_history or []
    rt_state = rt_state or {"messages": []}
    chat_history.append(("User", user_input_text))

    # Step 1: Disable input and show route detection
    rt_state["messages"] = [{"speaker": ARCH_ROUTER, "text": "πŸ”Ž Identifying route, please wait..."}]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 2: Prepare prompt and get route
    conversation = [{"role": "user", "content": user_input_text}]
    route_prompt = format_prompt(route_config, conversation)
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": route_prompt}],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        output = model.generate(input_ids=input_ids, max_new_tokens=512)

    prompt_len = input_ids.shape[1]
    response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
    print("MODEL RAW:", response)
    route = parse_route(response)

    emoji, dept_name = departments.get(route, departments["other"])

    # Step 3: Show route identified
    rt_state["messages"][0] = {
        "speaker": ARCH_ROUTER,
        "text": f"πŸ“Œ Identified department: **{dept_name}**. Forwarding task..."
    }
    chat_history.append((ARCH_ROUTER, f"πŸ“Œ Identified department: {dept_name}. Forwarding task..."))
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 4: Show processing
    time.sleep(3)
    rt_state["messages"].extend([
        {"speaker": route, "text": f"{emoji} {dept_name} simulation is processing your request in {WAIT_DEPARTMENT} secs..."},
        {"speaker": ARCH_ROUTER, "text": "⏳ Waiting for department to respond..."}
    ])
    rt_state["showBubbles"] = [ARCH_ROUTER, route]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 5: Simulate delay and complete
    time.sleep(WAIT_DEPARTMENT)
    rt_state["messages"][-2]["text"] = f"βœ… {dept_name} completed the task."
    rt_state["messages"][-1]["text"] = f"βœ… {dept_name} department has completed the task."
    chat_history.append((ARCH_ROUTER, f"βœ… {dept_name} department completed the task."))
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 6: Reset visible bubbles
    rt_state["showBubbles"] = [ARCH_ROUTER]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 7: System ready
    time.sleep(WAIT_SYSTEM)
    rt_state["messages"].append({"speaker": ARCH_ROUTER, "text": "Arch Router is ready to discuss."})
    yield rt_state, chat_history, rt_state, gr.update(interactive=True)

# === Gradio UI ===
with gr.Blocks(title="Arch Router Simulation: Smart Department Dispatcher", theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
    """
    ## 🧭 Arch Router Simulation: Smart Department Dispatcher
    **This is a demo simulation of <a href="https://huggingface.co/katanemo/Arch-Router-1.5B" target="_blank">katanemo/Arch-Router-1.5B</a>.**
    **Kindly refer official documentation for more details**

    * See how Arch Router identifies the best route **(or Domain – the high-level category)** based on user prompt and take desired **Action (specific type of operation user wants to perform)** by forwarding it to respective department.
    """
    )

    with gr.Row():
        with gr.Column(scale=2):
            rt_state = gr.State(init_state())
            chat_state = gr.State([])
            roundtable = consilium_roundtable(value=init_state())

        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Chat History", max_height=300)
            textbox = gr.Textbox(placeholder="Describe your issue...", label="Ask Arch Router")
            submit_btn = gr.Button("Submit")

            example_inputs = [
                "How do I optimize this loop in Python?",
                "Generate a function to sort an array in python",
                "Help me anonymize patient health records before storing them",
                "I'm getting a TypeError in following code",
                "Do I need to include attribution for MIT-licensed software?",
                "How do I connect to external API from this code?"
            ]

            # Trigger submission via Enter or Button
            for trigger in (textbox.submit, submit_btn.click):
                trigger(
                    route_and_visualize,
                    inputs=[textbox, rt_state, chat_state],
                    outputs=[roundtable, chatbot, rt_state, textbox],
                    concurrency_limit=1
                )

            # Example block
            gr.Examples(
                examples=example_inputs,
                inputs=textbox,
                label="Try one of these examples"
            )

if __name__ == "__main__":
    demo.launch()