# ApiEndPointDemo / app.py
import torch
import gradio as gr
from llava.model.builder import load_pretrained_model  # LLaVA model builder
from llava.mm_utils import get_model_name_from_path

# Model name
model_path = "MONAI/Llama3-VILA-M3-8B"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the LLaVA model. load_pretrained_model returns
# (tokenizer, model, image_processor, context_len).
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    device=device,
)
def generate_response(prompt):
    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        # max_new_tokens bounds the generated continuation regardless of prompt length
        output = model.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(output[0], skip_special_tokens=True)
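
# Example direct call (a minimal sketch for local testing; the prompt below is
# hypothetical, not from the original app):
#
#     print(generate_response("Describe typical findings of pneumonia on a chest X-ray."))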
# Gradio Interface
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt..."),
    outputs="text",
    title="LLaVA Llama3-VILA-M3-8B Chatbot",
    description="A chatbot powered by LLaVA and Llama3-VILA-M3-8B",
)

iface.launch()
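
# Since this Space is intended as an API endpoint, the Interface above can also
# be queried programmatically. A minimal client sketch, assuming the Space is
# published under "Gajendra5490/ApiEndPointDemo" (inferred from the repo name,
# not confirmed by the source); "/predict" is Gradio's default api_name for a
# single-function Interface:
#
#     from gradio_client import Client
#
#     client = Client("Gajendra5490/ApiEndPointDemo")
#     print(client.predict("Summarize the uses of chest CT.", api_name="/predict"))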