"""Streamlit app for interactive Llama 3.2 Vision inference.

The user uploads an image and types a prompt in the sidebar; the app runs
``meta-llama/Llama-3.2-90B-Vision-Instruct`` through Hugging Face
``transformers`` and displays the generated answer.
"""

import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

# Set page configuration. Keep this before any other Streamlit call: older
# Streamlit versions require set_page_config to be the first st command.
st.set_page_config(page_title="Llama 3.2 Vision Model", page_icon="🖼️")

# Title and description
st.title("Llama 3.2 Vision Model Inference")
st.write("Upload an image and provide a prompt to get model insights!")
# Hugging Face Hub id of the checkpoint served by this app.
MODEL_ID = "meta-llama/Llama-3.2-90B-Vision-Instruct"


@st.cache_resource
def _load_model_cached(model_id):
    """Load the processor/model pair for *model_id* once per server process.

    ``st.cache_resource`` keeps the returned objects alive across script
    reruns and user sessions, so the (very large) weights are not reloaded
    on every interaction. Exceptions are deliberately NOT caught here:
    Streamlit only caches successful return values, so a failed load is
    retried on the next call instead of being remembered forever.
    """
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        # Load in bf16 rather than float32 to halve weight memory.
        # NOTE(review): confirm the deployment hardware handles bf16; for
        # multi-GPU sharding also pass device_map="auto" (needs `accelerate`).
        torch_dtype=torch.bfloat16,
    )
    return processor, model


def load_model():
    """Return the cached ``(processor, model)`` pair.

    Returns:
        ``(processor, model)`` on success. On any loading error, an error
        message is shown in the app via ``st.error`` and ``(None, None)``
        is returned so callers can degrade gracefully.
    """
    try:
        return _load_model_cached(MODEL_ID)
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading model: {e}")
        return None, None
# Inference function
def generate_response(image, prompt, max_new_tokens=512):
    """Run one image + prompt turn through the vision model.

    Args:
        image: PIL image uploaded by the user (any mode; converted to RGB).
        prompt: The user's question or instruction about the image.
        max_new_tokens: Upper bound on generated tokens. Without an explicit
            limit, ``generate`` falls back to the generation config's length
            limit (possibly the library default of 20 tokens), which
            truncates answers.

    Returns:
        The model's answer as plain text. If the model could not be loaded,
        or inference raised, a fixed error message is returned instead;
        inference error details are shown via ``st.error``.
    """
    processor, model = load_model()
    if processor is None or model is None:
        return "Model could not be loaded."
    try:
        # Llama 3.2 Vision needs the image placeholder token inside the text;
        # the processor's chat template inserts it and appends the
        # assistant-turn header so the model answers rather than continues.
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"}, {"type": "text", "text": prompt}],
            }
        ]
        input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

        # Uploaded PNGs may be RGBA, palette or grayscale; the image
        # processor expects 3-channel input.
        rgb_image = image.convert("RGB")

        # The chat template already emits the BOS token, so don't add it twice.
        inputs = processor(
            images=rgb_image,
            text=input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # generate() returns prompt + completion; decode only the new tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        return processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:  # UI boundary: report the failure, don't crash the app
        st.error(f"Error during inference: {e}")
        return "An error occurred during image processing."
# Sidebar for user inputs
st.sidebar.header("Image and Prompt")
# Image uploader
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
# Prompt input
prompt = st.sidebar.text_input(
    "Enter your prompt:",
    placeholder="Describe what you want to know about the image",
)

# Main content area
if uploaded_file is None:
    st.info("Upload an image and enter a prompt to get started!")
else:
    image = None
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force decoding now so truncated or corrupt
        # files fail here instead of deep inside st.image / the model.
        image.load()
    except (OSError, ValueError) as e:  # PIL.UnidentifiedImageError is an OSError
        st.error(f"Could not read the uploaded image: {e}")

    if image is not None:
        # Display uploaded image.
        # NOTE(review): Streamlit 1.40+ deprecates use_column_width in favour
        # of use_container_width; switch once the pinned version is confirmed.
        st.image(image, caption="Uploaded Image", use_column_width=True)
        # Generate button
        if st.sidebar.button("Generate Response"):
            clean_prompt = prompt.strip()
            # Treat a whitespace-only prompt the same as an empty one.
            if clean_prompt:
                with st.spinner("Generating response..."):
                    response = generate_response(image, clean_prompt)
                st.subheader("Model Response")
                st.write(response)
            else:
                st.warning("Please enter a prompt!")

# Sidebar footer note
st.sidebar.markdown("---")
st.sidebar.info("Note: Model performance depends on image quality and prompt specificity.")