|
import os |
|
import streamlit as st |
|
from huggingface_hub import login |
|
from transformers import MllamaForConditionalGeneration, AutoProcessor |
|
from PIL import Image |
|
import torch |
|
|
|
|
|
# Read the Hub access token. The app-specific HUGGINGFACE_TOKEN wins; HF_TOKEN
# (the variable huggingface_hub itself reads) is accepted as a fallback so users
# following the Hub docs are not told their token is missing.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")

if huggingface_token:
    login(token=huggingface_token)
else:
    # Shown as an error but not fatal: the script keeps running (no st.stop()).
    st.error("Hugging Face token not found. Please set it in the Secrets section.")
|
|
|
|
|
# Hugging Face Hub id of the (gated) Llama 3.2 11B vision-instruct checkpoint.
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"


@st.cache_resource(show_spinner="Loading model and processor...")
def _load_model_and_processor(name, token):
    """Load the Mllama model and its processor once per Streamlit server process.

    Streamlit re-executes the whole script on every widget interaction; without
    ``st.cache_resource`` the 11B-parameter checkpoint would be reloaded on
    every button click. A raised exception is not cached, so a failed load is
    retried on the next rerun.

    Args:
        name: Hugging Face Hub model id.
        token: Hub access token, or ``None`` to let huggingface_hub use its
            default credential lookup.

    Returns:
        A ``(model, processor)`` tuple.
    """
    loaded_model = MllamaForConditionalGeneration.from_pretrained(
        name,
        token=token,  # `use_auth_token` is deprecated in transformers
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    loaded_processor = AutoProcessor.from_pretrained(name, token=token)
    return loaded_model, loaded_processor


try:
    model, processor = _load_model_and_processor(model_name, huggingface_token)
    st.success("Model and processor loaded successfully!")
except Exception as e:
    # Top-level UI boundary: report the failure in the page instead of crashing.
    st.error(f"Error loading model or processor: {str(e)}")
|
|
|
|
|
def _generate_response(image, prompt, max_new_tokens=250):
    """Run the vision model on one image/prompt pair and return only the new text.

    Args:
        image: RGB ``PIL.Image.Image`` to describe or reason about.
        prompt: User instruction text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The model's reply with the echoed prompt removed and whitespace stripped.
    """
    # Image placeholder first, then the text, matching the model card usage.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # The chat template already emits <|begin_of_text|>; add_special_tokens=False
    # keeps the processor from prepending a second BOS token. With
    # device_map="auto" the model decides its own placement, so send inputs to
    # model.device rather than a hard-coded "cuda"/"cpu".
    inputs = processor(
        images=[image],
        text=input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation. Decoding with
    # skip_special_tokens=True drops the template markers, so the decoded prompt
    # never string-matches input_text; slice the new tokens off by position.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = output_ids[:, prompt_length:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()


def main():
    """Render the upload/prompt UI and show the model's answer when requested."""
    st.title("Llama 3.2 11B Vision Model")
    st.write("Upload an image and enter a prompt to generate output.")

    image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    prompt = st.text_area("Enter your prompt here:")

    if not st.button("Generate Output"):
        return

    # A whitespace-only prompt counts as missing.
    if image_file is None or not (prompt and prompt.strip()):
        st.warning("Please upload an image and enter a prompt.")
        return

    try:
        # Decoding the upload is inside the try so a corrupt/unsupported file
        # is reported in the UI instead of crashing the script.
        image = Image.open(image_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        with st.spinner("Generating output..."):
            generated_output = _generate_response(image, prompt)

        st.write("Generated Output:", generated_output)
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        st.error(f"Error during prediction: {str(e)}")


if __name__ == "__main__":
    main()
|
|