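"""Gradio demo for running CPU inference with TEQ-quantized models.

Lists TEQ repositories from the fbaldassarri Hugging Face account, downloads
the selected weights and config files, and invokes teq_inference.py as a
subprocess to generate text from a prompt.
"""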
import os
import gradio as gr
from huggingface_hub import list_models, list_repo_files, hf_hub_download
import subprocess

HF_USER = "fbaldassarri"
TEQ_KEYWORD = "TEQ"

def list_teq_models():
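    """Return repo ids under HF_USER that contain the TEQ keyword."""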
    models = list_models(author=HF_USER)
    return [model.modelId for model in models if TEQ_KEYWORD in model.modelId]

def list_model_files(model_id):
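    """Split a repo's files into weight (.pt) and config (.json) candidates."""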
    files = list_repo_files(model_id)
    weights = [f for f in files if f.endswith('.pt')]
    configs = [f for f in files if f.endswith('.json')]
    return weights, configs

def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug):
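    """Download the selected files and run teq_inference.py on CPU, returning the generated text."""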
    if not model_id or not weights_file or not config_file:
        return "Please select a model, a weights file, and a config file."
    local_model_dir = f"./models/{model_id.replace('/', '_')}"
    os.makedirs(local_model_dir, exist_ok=True)
    hf_hub_download(model_id, weights_file, local_dir=local_model_dir)
    hf_hub_download(model_id, config_file, local_dir=local_model_dir)
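    # Build the command line for the standalone teq_inference.py script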
    cmd = [
        "python", "teq_inference.py",
        "--model_dir", local_model_dir,
        "--weights_file", weights_file,
        "--config_file", config_file,
        "--base_model", base_model,
        "--prompt", prompt,
        "--max_new_tokens", str(max_new_tokens),
        "--device", "cpu"
    ]
    if debug:
        cmd.append("--debug")
    result = subprocess.run(cmd, capture_output=True, text=True)
    output = result.stdout + "\n" + result.stderr
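    # The inference script prints the result after this marker; fall back to the raw output otherwise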
    marker = "Generated text:"
    if marker in output:
        return output.split(marker)[-1].strip()
    return output

def update_files(model_id):
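    """Refresh the weights/config dropdowns when the selected model changes."""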
    weights, configs = list_model_files(model_id)
    # Set value to first item if available, else None
    weights_val = weights[0] if weights else None
    configs_val = configs[0] if configs else None
    return gr.Dropdown(choices=weights, value=weights_val, label="Weights File (.pt)", interactive=True), \
           gr.Dropdown(choices=configs, value=configs_val, label="Config File (.json)", interactive=True)

def build_ui():
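    """Assemble the Gradio Blocks UI and wire up the event handlers."""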
    teq_models = list_teq_models()
    with gr.Blocks() as demo:
        gr.Markdown("# TEQ Quantized Model Inference Demo")
        model_id = gr.Dropdown(teq_models, label="Select TEQ Model", interactive=True)
        weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)", interactive=True)
        config_file = gr.Dropdown(choices=[], label="Config File (.json)", interactive=True)
        base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
        prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
        max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
        debug = gr.Checkbox(label="Debug Mode")
        output = gr.Textbox(label="Generated Text", lines=10)
        run_btn = gr.Button("Run Inference")

        # Dynamically update the dropdowns for weights and config
        model_id.change(
            update_files,
            inputs=model_id,
            outputs=[weights_file, config_file]
        )
        run_btn.click(
            run_teq_inference,
            inputs=[model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug],
            outputs=output
        )
    return demo

if __name__ == "__main__":
    build_ui().launch()