import os
import subprocess

import gradio as gr
from huggingface_hub import list_models, list_repo_files, hf_hub_download

HF_USER = "fbaldassarri"
TEQ_KEYWORD = "TEQ"

def list_teq_models():
    """Return the IDs of this author's Hub models whose names contain the TEQ keyword."""
    models = list_models(author=HF_USER)
    return [model.id for model in models if TEQ_KEYWORD in model.id]

def list_model_files(model_id):
    """Split a repo's file listing into candidate weights (.pt) and config (.json) files."""
    files = list_repo_files(model_id)
    weights = [f for f in files if f.endswith(".pt")]
    configs = [f for f in files if f.endswith(".json")]
    return weights, configs
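
# Example (hypothetical repo and file names, shown for illustration only):
#
#   weights, configs = list_model_files("fbaldassarri/opt-350m-TEQ")
#   # weights -> ["pytorch_model.pt"], configs -> ["qconfig.json"]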

def run_teq_inference(model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug):
    if not weights_file or not config_file:
        return "Please select both a weights file and a config file."
    # Download the selected files into a per-model local directory.
    local_model_dir = f"./models/{model_id.replace('/', '_')}"
    os.makedirs(local_model_dir, exist_ok=True)
    hf_hub_download(model_id, weights_file, local_dir=local_model_dir)
    hf_hub_download(model_id, config_file, local_dir=local_model_dir)
    # Run the inference script in a subprocess so a crash cannot take down the UI.
    cmd = [
        "python", "teq_inference.py",
        "--model_dir", local_model_dir,
        "--weights_file", weights_file,
        "--config_file", config_file,
        "--base_model", base_model,
        "--prompt", prompt,
        "--max_new_tokens", str(max_new_tokens),
        "--device", "cpu",
    ]
    if debug:
        cmd.append("--debug")
    result = subprocess.run(cmd, capture_output=True, text=True)
    output = result.stdout + "\n" + result.stderr
    # Return only the generated text if the script printed its marker;
    # otherwise return the full output to aid troubleshooting.
    marker = "Generated text:"
    if marker in output:
        return output.split(marker)[-1].strip()
    return output
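
# teq_inference.py is not shown here; the command above assumes it exposes a
# CLI with exactly the flags passed in cmd. A minimal argparse sketch of that
# assumed interface (hypothetical, for illustration only):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   for flag in ("--model_dir", "--weights_file", "--config_file",
#                "--base_model", "--prompt", "--device"):
#       parser.add_argument(flag)
#   parser.add_argument("--max_new_tokens", type=int, default=100)
#   parser.add_argument("--debug", action="store_true")
#   args = parser.parse_args()
#   # ... load the TEQ-quantized model, generate, then print the marker line
#   # that run_teq_inference() looks for:
#   print("Generated text:", generated)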

def update_files(model_id):
    weights, configs = list_model_files(model_id)
    # Set value to first item if available, else None
    weights_val = weights[0] if weights else None
    configs_val = configs[0] if configs else None
    return gr.Dropdown(choices=weights, value=weights_val, label="Weights File (.pt)", interactive=True), \
           gr.Dropdown(choices=configs, value=configs_val, label="Config File (.json)", interactive=True)
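
# Note: returning fresh component instances (gr.Dropdown(...)) from an event
# handler is how Gradio 4.x updates existing components in place; on Gradio
# 3.x the equivalent would be gr.Dropdown.update(...).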

def build_ui():
    teq_models = list_teq_models()
    with gr.Blocks() as demo:
        gr.Markdown("# TEQ Quantized Model Inference Demo")
        model_id = gr.Dropdown(teq_models, label="Select TEQ Model", interactive=True)
        weights_file = gr.Dropdown(choices=[], label="Weights File (.pt)", interactive=True)
        config_file = gr.Dropdown(choices=[], label="Config File (.json)", interactive=True)
        base_model = gr.Textbox(label="Base Model Name", value="facebook/opt-350m")
        prompt = gr.Textbox(label="Prompt", value="Once upon a time, a little girl")
        max_new_tokens = gr.Slider(10, 512, value=100, label="Max New Tokens")
        debug = gr.Checkbox(label="Debug Mode")
        output = gr.Textbox(label="Generated Text", lines=10)
        run_btn = gr.Button("Run Inference")
        # Dynamically update the dropdowns for weights and config
        model_id.change(
            update_files,
            inputs=model_id,
            outputs=[weights_file, config_file],
        )
        run_btn.click(
            run_teq_inference,
            inputs=[model_id, weights_file, config_file, base_model, prompt, max_new_tokens, debug],
            outputs=output,
        )
    return demo

if __name__ == "__main__":
    build_ui().launch()