import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, TextStreamer
from peft import PeftModel
import gradio as gr
# Load base model and processor
base_model_id = "unsloth/gemma-3-12b-it-unsloth-bnb-4bit"
adapter_model_id = "adarsh3601/my_gemma_pt3"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained(base_model_id)
model = AutoModelForImageTextToText.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
# Apply the LoRA adapter (PEFT) on top of the base model
model = PeftModel.from_pretrained(model, adapter_model_id)
model.eval()

# Streams generated tokens to the server console as they are produced
streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
# Helper to format messages using the chat template
def format_chat(messages):
    return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Chat function (Gradio's default tuple-style history is a list of (user, assistant) pairs)
def chat(message, history):
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    prompt = format_chat(messages)
    # Pass the prompt via the text keyword: the processor's first positional
    # argument is images, so processor(prompt) would be misinterpreted
    inputs = processor(text=prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512, streamer=streamer)

    # Decode only the newly generated tokens. Splitting the full decoded string on
    # <start_of_turn>/<end_of_turn> would not work here: skip_special_tokens=True
    # strips those markers before any string split could see them.
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    response = processor.decode(generated, skip_special_tokens=True).strip()
    return response
# Gradio interface
gui = gr.ChatInterface(fn=chat, title="Gemma-3 Chatbot", description="Fine-tuned on adarsh3601/my_gemma_pt3")
gui.launch()
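
# A minimal sketch of how this Space might be run locally, assuming the
# dependencies below are available (package names only; versions unpinned):
#
#   pip install torch transformers peft gradio accelerate bitsandbytes
#   python app.py
#
# Note: the bnb-4bit base checkpoint expects a CUDA GPU with bitsandbytes
# support; on a CPU-only machine it may not load at all, and the float32
# fallback above would in any case be extremely slow for a 12B model.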