| | import gradio as gr |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| |
|
| | |
# UI display name -> Hugging Face Hub repo id for every model the app offers.
# Keys are what the dropdown shows; values are passed to from_pretrained().
MODEL_OPTIONS = {
    "Mistral-7B-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
    "Qwen2.5-3B-Instruct": "Qwen/Qwen2.5-3B-Instruct",
    "Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",
    "StableLM2-1.6B": "stabilityai/stablelm-2-zephyr-1_6b",
    "SmolLM3-3B": "HuggingFaceTB/SmolLM3-3B",
    "BTLM-3B-8k-base": "cerebras/btlm-3b-8k-base"
}
| |
|
# Process-wide cache of initialized models: display name -> (tokenizer, model),
# so each model is downloaded/instantiated at most once per run.
loaded = {}
# System prompt prepended to every conversation sent to the model.
SYSTEM_PROMPT = "You are HugginGPT — helpful, friendly, and clear with memory."
| |
|
def load_model(model_key):
    """Return ``(tokenizer, model)`` for *model_key*, loading on first use.

    ``model_key`` must be a display name present in ``MODEL_OPTIONS``.
    Results are memoized in the module-level ``loaded`` cache, so repeated
    calls with the same key return the same pair without reloading.
    """
    # Fast path: reuse an already-initialized pair.
    if model_key in loaded:
        return loaded[model_key]

    repo_id = MODEL_OPTIONS[model_key]
    tok = AutoTokenizer.from_pretrained(repo_id)
    lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="auto",       # let accelerate place layers on available devices
        torch_dtype=torch.float16,
    )

    loaded[model_key] = (tok, lm)
    return tok, lm
| |
|
def generate_response(message, history, model_choice):
    """Generate one assistant reply for *message* given the chat *history*.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list | None
        Prior turns in either Gradio history format: a list of
        ``(user, assistant)`` pairs, or a list of ``{"role", "content"}``
        dicts (ChatInterface "messages" format). The original code unpacked
        each turn as ``u, a = turn``, which for a dict silently binds the
        *keys* ("role", "content") and corrupts the prompt — handled here.
    model_choice : str
        A display-name key from ``MODEL_OPTIONS``.

    Returns
    -------
    str
        The model's reply with the prompt scaffolding stripped.
    """
    tokenizer, model = load_model(model_choice)

    # Build a flat role-tagged transcript; the trailing "assistant:" cues
    # the model to continue with its own turn.
    parts = [f"system: {SYSTEM_PROMPT}\n"]
    for turn in history or []:
        if isinstance(turn, dict):
            # messages format: one dict per utterance
            parts.append(f"{turn.get('role', 'user')}: {turn.get('content', '')}\n")
        else:
            # tuple/list format: one (user, assistant) pair per exchange
            u, a = turn
            parts.append(f"user: {u}\nassistant: {a}\n")
    parts.append(f"user: {message}\nassistant:")
    context = "".join(parts)

    inputs = tokenizer(context, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
    )
    # decode() yields prompt + completion; keep only the text after the
    # final "assistant:" marker, i.e. the newly generated turn.
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return text.split("assistant:")[-1].strip()
| |
|
# UI wiring. NOTE: the original passed a lambda that read
# ``model_selector.value`` at call time — in Gradio that attribute is the
# component's *initial default*, not the live UI state, so changing the
# dropdown never changed the model actually used. The fix is to declare the
# dropdown as an additional input; Gradio then supplies the current
# selection as the third argument to generate_response(message, history,
# model_choice).
with gr.Blocks() as demo:
    gr.Markdown("## HugginGPT")

    model_selector = gr.Dropdown(
        choices=list(MODEL_OPTIONS.keys()),
        value="Mistral-7B-Instruct",
        label="Select model",
    )

    chat = gr.ChatInterface(
        fn=generate_response,
        additional_inputs=[model_selector],
        title="HugginGPT",
    )

demo.launch()
| |
|