import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import gradio as gr

# Run on the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "adarsh3601/my_gemma_pt3"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(model_name).to(device)
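# If the checkpoint above were published as a LoRA adapter rather than a
# merged model, it would need to be attached to its base model with PEFT
# instead. A minimal sketch (the base-model id here is an assumption, not
# taken from the checkpoint's config):
#
#   from peft import PeftModel
#   base = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
#   model = PeftModel.from_pretrained(base, model_name).to(device)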
def chat(prompt):
    # Wrap the raw prompt in the chat format the processor expects.
    messages = [
        {"role": "user", "content": [{"type": "text", "text": prompt}]}
    ]

    # Render the conversation with the model's chat template, tokenize it,
    # and append the generation prompt for the assistant turn.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)

    # Generate without gradient tracking; bound the number of new tokens
    # rather than the total sequence length, so long prompts are not cut off.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200)

    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    return processor.decode(generated, skip_special_tokens=True)
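# Optional sanity check before launching the UI (the prompt is arbitrary):
#   print(chat("Describe yourself in one sentence."))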
# live=True would rerun generation on every keystroke, so the interface
# only generates when the user submits.
gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="Gemma Chat Model",
    description="Chat with the Gemma 3 model",
).launch(share=False)
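# share=False keeps the app local; share=True would expose a temporary
# public *.gradio.live URL.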