Spaces:
Running
Running
changed model ordering
Browse files
- app.py +5 -5
- router_backend.py +1 -0
app.py
CHANGED
|
@@ -36,14 +36,14 @@ from router_backend import get_expert_routing
|
|
| 36 |
EXPERTS = ["Language", "Logic", "Social", "World"]
|
| 37 |
|
| 38 |
DEFAULT_MODELS = [
|
|
|
|
|
|
|
| 39 |
"micro-llama-1b",
|
| 40 |
"micro-llama-3b",
|
| 41 |
"micro-llama-1b-dpo",
|
| 42 |
-
"micro-moe-llama-1b",
|
| 43 |
-
"micro-smollm2-135m",
|
| 44 |
-
"micro-smollm2-360m",
|
| 45 |
"micro-moe-smollm2-135m",
|
| 46 |
"micro-moe-smollm2-360m",
|
|
|
|
| 47 |
]
|
| 48 |
|
| 49 |
def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
|
|
@@ -139,7 +139,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
|
| 139 |
|
| 140 |
with gr.Row():
|
| 141 |
model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
|
| 142 |
-
hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="
|
| 143 |
|
| 144 |
with gr.Row():
|
| 145 |
user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
|
|
@@ -151,7 +151,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
|
| 151 |
|
| 152 |
run = gr.Button("Run Routing", variant="primary")
|
| 153 |
|
| 154 |
-
generation_output = gr.Textbox(lines=4, label="Generated
|
| 155 |
|
| 156 |
with gr.Row():
|
| 157 |
table = gr.Dataframe(label="Routing Percentages", interactive=False)
|
|
|
|
| 36 |
EXPERTS = ["Language", "Logic", "Social", "World"]
|
| 37 |
|
| 38 |
DEFAULT_MODELS = [
|
| 39 |
+
"micro-smollm2-135m",
|
| 40 |
+
"micro-smollm2-360m",
|
| 41 |
"micro-llama-1b",
|
| 42 |
"micro-llama-3b",
|
| 43 |
"micro-llama-1b-dpo",
|
|
|
|
|
|
|
|
|
|
| 44 |
"micro-moe-smollm2-135m",
|
| 45 |
"micro-moe-smollm2-360m",
|
| 46 |
+
"micro-moe-llama-1b",
|
| 47 |
]
|
| 48 |
|
| 49 |
def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
|
|
|
|
| 139 |
|
| 140 |
with gr.Row():
|
| 141 |
model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
|
| 142 |
+
hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="Required for Llama-based models", lines=1)
|
| 143 |
|
| 144 |
with gr.Row():
|
| 145 |
user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
|
|
|
|
| 151 |
|
| 152 |
run = gr.Button("Run Routing", variant="primary")
|
| 153 |
|
| 154 |
+
generation_output = gr.Textbox(lines=4, label="Generated Response", placeholder="Generated text will appear here...", interactive=False)
|
| 155 |
|
| 156 |
with gr.Row():
|
| 157 |
table = gr.Dataframe(label="Routing Percentages", interactive=False)
|
router_backend.py
CHANGED
|
@@ -32,6 +32,7 @@ def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dic
|
|
| 32 |
|
| 33 |
if isinstance(prompt, str):
|
| 34 |
generation, routing_weights = generate_continuation(model, tokenizer, prompt)
|
|
|
|
| 35 |
elif isinstance(prompt, list):
|
| 36 |
generation = None
|
| 37 |
routing_weights = get_routing_weights(model, tokenizer, [prompt])
|
|
|
|
| 32 |
|
| 33 |
if isinstance(prompt, str):
|
| 34 |
generation, routing_weights = generate_continuation(model, tokenizer, prompt)
|
| 35 |
+
generation = generation[0] if type(generation) is list else generation
|
| 36 |
elif isinstance(prompt, list):
|
| 37 |
generation = None
|
| 38 |
routing_weights = get_routing_weights(model, tokenizer, [prompt])
|