bkhmsi committed on
Commit
4e8105c
·
1 Parent(s): 8730f5f

changed model ordering

Browse files
Files changed (2) hide show
  1. app.py +5 -5
  2. router_backend.py +1 -0
app.py CHANGED
@@ -36,14 +36,14 @@ from router_backend import get_expert_routing
36
  EXPERTS = ["Language", "Logic", "Social", "World"]
37
 
38
  DEFAULT_MODELS = [
 
 
39
  "micro-llama-1b",
40
  "micro-llama-3b",
41
  "micro-llama-1b-dpo",
42
- "micro-moe-llama-1b",
43
- "micro-smollm2-135m",
44
- "micro-smollm2-360m",
45
  "micro-moe-smollm2-135m",
46
  "micro-moe-smollm2-360m",
 
47
  ]
48
 
49
  def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
@@ -139,7 +139,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
139
 
140
  with gr.Row():
141
  model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
142
- hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="hf token", lines=1)
143
 
144
  with gr.Row():
145
  user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
@@ -151,7 +151,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
151
 
152
  run = gr.Button("Run Routing", variant="primary")
153
 
154
- generation_output = gr.Textbox(lines=4, label="Generated continuation", placeholder="Generated text will appear here...", interactive=False)
155
 
156
  with gr.Row():
157
  table = gr.Dataframe(label="Routing Percentages", interactive=False)
 
36
  EXPERTS = ["Language", "Logic", "Social", "World"]
37
 
38
  DEFAULT_MODELS = [
39
+ "micro-smollm2-135m",
40
+ "micro-smollm2-360m",
41
  "micro-llama-1b",
42
  "micro-llama-3b",
43
  "micro-llama-1b-dpo",
 
 
 
44
  "micro-moe-smollm2-135m",
45
  "micro-moe-smollm2-360m",
46
+ "micro-moe-llama-1b",
47
  ]
48
 
49
  def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
 
139
 
140
  with gr.Row():
141
  model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
142
+ hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="Required for Llama-based models", lines=1)
143
 
144
  with gr.Row():
145
  user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
 
151
 
152
  run = gr.Button("Run Routing", variant="primary")
153
 
154
+ generation_output = gr.Textbox(lines=4, label="Generated Response", placeholder="Generated text will appear here...", interactive=False)
155
 
156
  with gr.Row():
157
  table = gr.Dataframe(label="Routing Percentages", interactive=False)
router_backend.py CHANGED
@@ -32,6 +32,7 @@ def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dic
32
 
33
  if isinstance(prompt, str):
34
  generation, routing_weights = generate_continuation(model, tokenizer, prompt)
 
35
  elif isinstance(prompt, list):
36
  generation = None
37
  routing_weights = get_routing_weights(model, tokenizer, [prompt])
 
32
 
33
  if isinstance(prompt, str):
34
  generation, routing_weights = generate_continuation(model, tokenizer, prompt)
35
+ generation = generation[0] if type(generation) is list else generation
36
  elif isinstance(prompt, list):
37
  generation = None
38
  routing_weights = get_routing_weights(model, tokenizer, [prompt])