"""Streamlit app that generates a Pepe the Frog meme caption and image.

A Mistral-7B-Instruct model (Hugging Face Inference API) writes a meme caption
and an image prompt from the user's idea; the image prompt is then sent to a
hosted Gradio Space whose result is displayed as the meme image.
"""

import re

import streamlit as st
from gradio_client import Client
from huggingface_hub import InferenceClient

# Set the page config (Streamlit requires this before any other st.* call).
st.set_page_config(layout="wide")

TEXT_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
IMAGE_SPACE = "phenixrhyder/nsfw-waifu-gradio"
STYLE_PATH = "style.css"

# Mistral marks end-of-sequence with "</s>", which can leak into streamed text.
_TRAILING_EOS_RE = re.compile(r"\s*</s>\s*$")


def _load_css(path=STYLE_PATH):
    """Inject the stylesheet at *path* into the page, if the file exists."""
    try:
        with open(path, encoding="utf-8") as f:
            css = f.read()
    except FileNotFoundError:
        st.warning(f"Stylesheet '{path}' not found; using default styling.")
        return
    # Wrap in <style> so the browser applies it as CSS instead of showing text.
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


_load_css()


@st.cache_resource
def _get_text_client():
    """Create the HuggingFace text-generation client once and reuse it across reruns."""
    return InferenceClient(model=TEXT_MODEL)


@st.cache_resource
def _get_image_client():
    """Connect to the Gradio image Space once and reuse it across reruns."""
    return Client(IMAGE_SPACE)


# Initialize the HuggingFace Inference Client and the Gradio image client.
text_client = _get_text_client()
image_client = _get_image_client()


def format_prompt_for_description(caption_text):
    """Build the instruction prompt asking the model for a meme caption."""
    # Mistral-Instruct is tuned on instructions wrapped in [INST] ... [/INST].
    return (
        "[INST] Generate a funny and relatable meme caption for Pepe the Frog: "
        f"{caption_text} [/INST]"
    )


def format_prompt_for_image(caption_text):
    """Build the instruction prompt asking the model for an image-generation prompt."""
    return (
        "[INST] Generate an image prompt for a Pepe the Frog meme with the "
        f"following caption: {caption_text} [/INST]"
    )


def clean_generated_text(text):
    """Remove a trailing ``</s>`` end-of-sequence tag and surrounding whitespace."""
    return _TRAILING_EOS_RE.sub("", text).strip()


def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Generate text for *prompt* by streaming tokens from the text model.

    Args:
        prompt: Full instruction prompt sent to the model.
        temperature: Sampling temperature; clamped to at least 0.01.
        max_new_tokens: Maximum number of tokens to generate; clamped to at least 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The cleaned generated text, or "" if the request failed (the error is
        shown in the UI via ``st.error``).
    """
    generate_kwargs = dict(
        # The inference API requires a strictly positive temperature when sampling.
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=max(int(max_new_tokens), 1),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: same inputs/settings should give repeatable text
    )
    try:
        stream = text_client.text_generation(
            prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        # Consuming the stream is inside the try: network errors can surface mid-stream.
        output = "".join(response.token.text for response in stream)
    except Exception as e:
        st.error(f"Error generating text: {e}")
        return ""
    return clean_generated_text(output)


def generate_image(prompt):
    """Request an image for *prompt* from the Gradio Space.

    Returns:
        A one-element list holding the Space's result (assumed to be a single
        image path/URL that ``st.image`` can display), or an empty list on an
        empty prompt, an empty result, or an API error (errors are shown in the UI).
    """
    if not prompt:
        st.error("Cannot generate an image from an empty prompt.")
        return []
    try:
        result = image_client.predict(param_0=prompt, api_name="/predict")
    except Exception as e:
        st.error(f"Error generating image: {e}")
        st.write("Full error details:", e)
        return []
    if not result:
        st.error("Unexpected result format from the Gradio API.")
        return []
    return [result]


def _init_session_state():
    """Ensure the keys that hold generated results exist in session state."""
    for key, default in (("meme_caption", ""), ("image_prompt", ""), ("image_paths", [])):
        if key not in st.session_state:
            st.session_state[key] = default


def _generate_meme(caption_text, temperature, max_new_tokens, top_p, repetition_penalty):
    """Generate caption, image prompt and image, storing each in session state."""
    settings = (temperature, max_new_tokens, top_p, repetition_penalty)
    with st.spinner("Generating Pepe meme..."):
        st.session_state.meme_caption = generate_text(
            format_prompt_for_description(caption_text), *settings
        )
        st.session_state.image_prompt = generate_text(
            format_prompt_for_image(caption_text), *settings
        )
        # Only call the image API when there is a prompt to send.
        st.session_state.image_paths = (
            generate_image(st.session_state.image_prompt)
            if st.session_state.image_prompt
            else []
        )
    if st.session_state.meme_caption and st.session_state.image_paths:
        st.success("Pepe meme generated!")
    else:
        st.warning("Some parts of the Pepe meme could not be generated; see the messages above.")


def _render_results():
    """Display whatever caption, image prompt and images are stored in session state."""
    state = st.session_state
    if state.meme_caption:
        st.subheader("Generated Meme Caption")
        st.write(state.meme_caption)
    if state.image_prompt:
        st.subheader("Image Prompt")
        st.write(state.image_prompt)
    if state.image_paths:
        st.subheader("Generated Image")
        for image_path in state.image_paths:
            st.image(image_path, caption="Generated Pepe Meme Image")


def main():
    """Render the inputs, handle the Generate button, and show stored results."""
    st.title("Pepe Meme Generator")
    _init_session_state()

    col1, col2 = st.columns(2)

    with col1:
        # User inputs
        caption_text = st.text_input("Enter a caption or meme idea for Pepe")

        # Advanced settings
        with st.expander("Advanced Settings"):
            temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
            # Minimum of 64: generating zero new tokens is not a valid request.
            max_new_tokens = st.slider("Max new tokens", 64, 8192, 512, step=64)
            top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
            repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)

        if st.button("Generate Pepe Meme"):
            if not caption_text.strip():
                st.warning("Please enter a caption or meme idea first.")
            else:
                _generate_meme(caption_text, temperature, max_new_tokens, top_p, repetition_penalty)

    with col2:
        _render_results()


if __name__ == "__main__":
    main()