"""Streamlit app that captions an uploaded image and renders it as a meme."""

import argparse
import base64
import io
import json
import os
import re

import requests
import streamlit as st
import torch
from peft import PeftModel, PeftConfig
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.image_utils import overlay_caption
from utils.model_utils import get_model_caption


@st.cache_resource
def load_models():
    """Load the Gemma-2B base model, its tokenizer and the mood adapters.

    Decorated with ``st.cache_resource`` so the loading is not repeated on
    every Streamlit rerun.

    Returns:
        A ``(base_model, tokenizer, model_happy, model_angry, device)`` tuple.
    """
    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

    # NOTE(review): both PEFT wrappers are built on the same ``base_model``
    # object -- confirm the two adapters do not interfere with each other.
    model_angry = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_angry")
    model_happy = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_happy")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_model.to(device)
    model_happy.to(device)
    model_angry.to(device)

    # Load the adapters for specific moods
    base_model.load_adapter("NursNurs/outputs_gemma2b_happy", "happy")
    base_model.load_adapter("NursNurs/outputs_gemma2b_angry", "angry")

    return base_model, tokenizer, model_happy, model_angry, device


# x = st.slider('Select a value')
# st.write(x, 'squared is', x * x)


def generate_meme_from_image(img_path, base_model, tokenizer, hf_token, output_dir, device='cuda'):
    """Caption an image and overlay the caption on it.

    Args:
        img_path: Image handed to ``get_model_caption`` and ``overlay_caption``.
            NOTE(review): ``main`` passes an opened PIL image here, not a
            path -- confirm both helpers accept that.
        base_model: Model forwarded to ``get_model_caption``.
        tokenizer: Tokenizer forwarded to ``get_model_caption``.
        hf_token: Hugging Face token forwarded to ``get_model_caption``.
        output_dir: Output location forwarded to ``overlay_caption``.
        device: Currently unused by this function; kept for API compatibility.

    Returns:
        A ``(image, caption)`` tuple: the result of ``overlay_caption`` and
        the caption text it was given.
    """
    caption = get_model_caption(img_path, base_model, tokenizer, hf_token)
    image = overlay_caption(caption, img_path, output_dir)
    return image, caption


# NOTE(review): main() sets its own title as well, so two titles are rendered.
st.title("Image Upload and Processing App")


def main():
    """Render the upload form and, on request, generate and show the meme."""
    st.title("Meme Generator with Mood")

    base_model, tokenizer, model_happy, model_angry, device = load_models()

    # Input widget to upload an image
    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])

    # Input widget to add Hugging Face token
    hf_token = st.text_input("Enter your Hugging Face Token", type="password")

    # Dropdown to select mood
    # mood = st.selectbox("Select Mood", options=["happy", "angry"])

    # Directory for saving the meme (optional, but you can let users set this if needed)
    output_dir = "results"

    # Nothing to do until both an image and a token have been supplied.
    if uploaded_image is None or not hf_token:
        return

    # Convert uploaded image to a PIL image
    img = Image.open(uploaded_image)

    # Generate meme when button is pressed
    if not st.button("Generate Meme"):
        return

    with st.spinner('Generating meme...'):
        # Pass output_dir explicitly and device by keyword: the previous
        # positional call put `device` into the `output_dir` slot.
        image, caption = generate_meme_from_image(
            img, base_model, tokenizer, hf_token, output_dir, device=device
        )

    # Display the output
    st.image(image, caption=f"Generated Meme: {caption}")

    # Optionally allow downloading the meme
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    byte_im = buf.getvalue()
    st.download_button(
        label="Download Meme",
        data=byte_im,
        file_name="generated_meme.png",
        mime="image/png",
    )


if __name__ == '__main__':
    main()

# # Upload the image
# uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])

# # Process and display if image is uploaded
# if uploaded_image is not None:
#     image = Image.open(uploaded_image)
#     st.image(image, caption="Uploaded Image", use_column_width=True)