Spaces:
Build error
Build error
| import streamlit as st | |
| from huggingface_hub import InferenceClient | |
| from gradio_client import Client | |
| import re | |
# Streamlit requires set_page_config to be the first Streamlit command in the script.
st.set_page_config(layout="wide")

# Inject the custom stylesheet. Decode it explicitly as UTF-8 so the result
# does not depend on the host platform's default text encoding.
with open("style.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

# Remote inference clients: a Hugging Face text-generation endpoint for the
# caption / image-prompt text, and a Gradio client for image generation.
text_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
image_client = Client("phenixrhyder/nsfw-waifu-gradio")
def format_prompt_for_description(caption_text):
    """Build the instruction prompt that asks the LLM for a Pepe meme caption.

    Mistral-7B-Instruct was fine-tuned on instructions wrapped in
    ``[INST] ... [/INST]``. Without those markers the model treats the text
    as something to continue rather than an instruction to answer. ``<s>``
    (BOS) is not added here, on the assumption that the endpoint's tokenizer
    prepends it — NOTE(review): confirm.

    Args:
        caption_text: The user's caption or meme idea, inserted verbatim.

    Returns:
        The prompt string to send to the text-generation endpoint.
    """
    instruction = f"Generate a funny and relatable meme caption for Pepe the Frog: {caption_text}"
    return f"[INST] {instruction} [/INST]"
def format_prompt_for_image(caption_text):
    """Build the instruction prompt that asks the LLM to write an image prompt.

    The result is later sent to the image-generation client. As with any
    Mistral-7B-Instruct request, the instruction is wrapped in
    ``[INST] ... [/INST]`` so the model answers it instead of merely
    continuing the text. ``<s>`` (BOS) is assumed to be prepended by the
    endpoint's tokenizer — NOTE(review): confirm.

    Args:
        caption_text: The user's caption or meme idea, inserted verbatim.

    Returns:
        The prompt string to send to the text-generation endpoint.
    """
    instruction = f"Generate an image prompt for a Pepe the Frog meme with the following caption: {caption_text}"
    return f"[INST] {instruction} [/INST]"
def clean_generated_text(text):
    """Remove surrounding whitespace and any trailing ``</s>`` end-of-sequence markers.

    Whitespace is stripped *before* the end-anchored pattern is applied.
    Otherwise trailing spaces after the marker (e.g. ``"joke </s>  "``)
    would stop the pattern from matching and the marker would leak into the
    UI. Repeated trailing markers are removed as well.

    Args:
        text: Raw generated text.

    Returns:
        The cleaned text. Markers that appear mid-text are left untouched.
    """
    stripped = text.strip()
    # One or more "</s>" markers (each optionally preceded by whitespace) at the very end.
    return re.sub(r"(?:\s*</s>)+$", "", stripped).strip()
def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion for *prompt* from the text endpoint and return it as a string.

    Args:
        prompt: Fully formatted prompt string.
        temperature: Sampling temperature. Clamped to at least 0.01 because
            sampling is enabled (presumably the endpoint rejects a
            non-positive temperature when sampling).
        max_new_tokens: Upper bound on generated tokens. Clamped to at least
            1, since a request for zero new tokens cannot produce output.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The cleaned completion text, or ``""`` if the request failed. The
        failure is reported in the UI via ``st.error`` rather than raised.
    """
    temperature = max(temperature, 1e-2)
    max_new_tokens = max(int(max_new_tokens), 1)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed: identical settings should give identical output.
        # NOTE(review): confirm this is intended for a "funny meme" generator.
        seed=42,
    )
    try:
        stream = text_client.text_generation(
            prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        # details=True yields token objects. Special tokens (e.g. end-of-sequence)
        # are dropped so they never reach the UI. The string is built once with
        # join rather than grown token by token.
        output = "".join(
            response.token.text for response in stream if not response.token.special
        )
        return clean_generated_text(output)
    except Exception as e:
        # UI boundary: show the failure in the app instead of aborting the script run.
        st.error(f"Error generating text: {e}")
        return ""
def generate_image(prompt):
    """Send *prompt* to the Gradio image client's ``/predict`` endpoint.

    Args:
        prompt: Text prompt for the image model (passed as ``param_0``).

    Returns:
        A one-element list wrapping whatever ``image_client.predict`` returned
        (assumed to be a single image path — TODO confirm against the Space's
        API), or ``None`` when the result is falsy or the call raises. Errors
        are shown in the Streamlit UI instead of being raised.
    """
    try:
        image_result = image_client.predict(param_0=prompt, api_name="/predict")
        if not image_result:
            st.error("Unexpected result format from the Gradio API.")
            return None
        return [image_result]
    except Exception as err:
        st.error(f"Error generating image: {err}")
        st.write("Full error details:", err)
        return None
def main():
    """Render the Pepe Meme Generator page for one Streamlit script run.

    Inputs and the generate button go in the left column and results in the
    right column. Results are kept in ``st.session_state`` so they survive
    the reruns Streamlit performs on every widget interaction.
    """
    st.title("Pepe Meme Generator")
    _init_session_state()

    col1, col2 = st.columns(2)
    with col1:
        caption_text = st.text_input("Enter a caption or meme idea for Pepe")
        with st.expander("Advanced Settings"):
            temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
            # Lower bound of one step (64): a request for 0 new tokens produces nothing.
            max_new_tokens = st.slider("Max new tokens", 64, 8192, 512, step=64)
            top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
            repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)

        if st.button("Generate Pepe Meme"):
            # Don't spend three remote API calls on an empty idea.
            if not caption_text.strip():
                st.warning("Please enter a caption or meme idea first.")
            else:
                _generate_meme(caption_text, temperature, max_new_tokens, top_p, repetition_penalty)

    with col2:
        _render_results()


def _init_session_state():
    """Create the session-state slots the page reads, if this session lacks them."""
    for key, default in (("meme_caption", ""), ("image_prompt", ""), ("image_paths", [])):
        if key not in st.session_state:
            st.session_state[key] = default


def _generate_meme(caption_text, temperature, max_new_tokens, top_p, repetition_penalty):
    """Generate caption, image prompt and image, then store them in session state."""
    sampling = (temperature, max_new_tokens, top_p, repetition_penalty)
    state = st.session_state
    with st.spinner("Generating Pepe meme..."):
        state.meme_caption = generate_text(format_prompt_for_description(caption_text), *sampling)
        state.image_prompt = generate_text(format_prompt_for_image(caption_text), *sampling)
        # Only call the image API when there is a prompt to render. generate_image
        # returns None on failure, so fall back to [] to keep the slot a list.
        if state.image_prompt:
            state.image_paths = generate_image(state.image_prompt) or []
        else:
            state.image_paths = []
    # Each step reports its own error. Only claim success if something was produced.
    if state.meme_caption or state.image_paths:
        st.success("Pepe meme generated!")
    else:
        st.warning("Nothing was generated. Please try again.")


def _render_results():
    """Display whatever caption, image prompt and images are stored in session state."""
    state = st.session_state
    if state.meme_caption:
        st.subheader("Generated Meme Caption")
        st.write(state.meme_caption)
    if state.image_prompt:
        st.subheader("Image Prompt")
        st.write(state.image_prompt)
    if state.image_paths:
        st.subheader("Generated Image")
        for image_path in state.image_paths:
            st.image(image_path, caption="Generated Pepe Meme Image")
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()