TestingModelAPI / app.py
made1570's picture
Update app.py
c742ff6 verified
raw
history blame
1.39 kB
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel
import gradio as gr
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Identifier passed to both `from_pretrained` calls below; change it to point
# at your own model.
model_name = "adarsh3601/my_gemma_pt3"

# The processor and the model are loaded from the same identifier, and the
# model is then moved onto the selected device.
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(model_name)
model = model.to(device)

# Optional: wrap the base model with a PEFT adapter. Uncomment both lines and
# replace the placeholder id to enable it.
# adapter_model_id = "your_adapter_model_id"
# model = PeftModel.from_pretrained(model, adapter_model_id)
def chat(prompt: str) -> str:
    """Generate the model's reply to a single user prompt.

    Args:
        prompt: Text typed by the user. Empty, whitespace-only or ``None``
            input is answered with an empty string without running the model.

    Returns:
        The newly generated text only; the echoed prompt and special tokens
        are stripped.
    """
    # The interface is live, so this can be invoked with an empty textbox;
    # skip the (expensive) generation in that case.
    if not prompt or not prompt.strip():
        return ""

    # Chat-style processors expect typed content parts and must render the
    # conversation through the chat template; calling the processor directly
    # on a list of dicts does not tokenize the conversation.
    messages = [
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)

    # Remember how many tokens belong to the prompt so they can be removed
    # from the decoded output.
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # max_new_tokens bounds only the reply; max_length would also count
        # the prompt tokens and could leave no room for an answer.
        outputs = model.generate(**inputs, max_new_tokens=200)

    reply_tokens = outputs[0][prompt_length:]
    return processor.decode(reply_tokens, skip_special_tokens=True)
# Build the Gradio UI: one text input wired to `chat`, one text output.
demo = gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="Gemma Chat Model",
    description="Chat with Gemma3 model",
    live=True,
)

# share=False for Hugging Face Spaces
demo.launch(share=False)