| | import os |
| | import gdown |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer |
| | import gradio as gr |
| |
|
# Ensure a local directory exists to cache the downloaded model files.
os.makedirs("model", exist_ok=True)

# Google Drive source for the fine-tuned weights and the local path they are
# cached at between runs.
MODEL_URL = "https://drive.google.com/uc?id=1Kg8KSGIgjBopeOKSbYbFWEgUlYOcqyXX"
MODEL_PATH = "model/model.safetensors"

# Download the weights only on first run; later runs reuse the cached file.
# NOTE(review): gdown.download returns None on failure instead of raising, so
# a failed download can leave a missing/partial file — verify before loading.
if not os.path.exists(MODEL_PATH):
    print("⬇ Downloading model weights...")
    gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
else:
    print("✅ Model file already exists")

# Load tokenizer and model from the local "model" directory in fp16.
# NOTE(review): assumes the directory also contains config/tokenizer files;
# the download above fetches only the .safetensors weights — confirm.
print("🔧 Loading model & tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("model")
model = AutoModelForCausalLM.from_pretrained("model", torch_dtype=torch.float16)

# Run on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Token streamer for incremental console output.
# NOTE(review): constructed here but never passed to model.generate() below,
# so it currently has no effect.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
| |
|
def respond(message, history, max_tokens, temperature, top_p):
    """Generate a chat reply for Gradio's ChatInterface.

    Args:
        message: The latest user message.
        history: Prior turns as (user, assistant) pairs (Gradio tuple format).
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The assistant's reply text for the latest message.
    """
    # Rebuild the whole conversation using the model's chat-template markers.
    # (Fix: the original also ran tokenizer.encode(message) here and discarded
    # the result — dead work removed.)
    history_text = ""
    if history:
        for user, bot in history:
            history_text += f"<|user|>{user}<|assistant|>{bot}"

    full_input = history_text + f"<|user|>{message}<|assistant|>"

    inputs = tokenizer(full_input, return_tensors="pt").to(device)

    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode the full sequence (prompt + completion) and keep only the text
    # after the final assistant marker.
    # NOTE(review): assumes <|user|>/<|assistant|> are plain text rather than
    # registered special tokens, so they survive skip_special_tokens=True —
    # confirm against the tokenizer config.
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    answer = output_text.split("<|assistant|>")[-1].strip()
    return answer
| |
|
| | |
# Gradio chat UI; the additional_inputs sliders are forwarded to respond()
# positionally as max_tokens / temperature / top_p.
chat = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Slider(64, 1024, value=256, label="Max Tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
    ],
    title="🦙 TinyLLaMA Chatbot",
    description="Fine-tuned TinyLLaMA using QLoRA.",
)

# Launch the web server only when executed as a script, not on import.
if __name__ == "__main__":
    chat.launch()
| |
|