# Spaces: Sleeping
import re

import streamlit as st
from gradio_client import Client
from huggingface_hub import InferenceClient
# Set the page config: use the wide layout.
st.set_page_config(layout="wide")

# Load custom CSS and inject it into the page as an inline <style> element.
# The encoding is explicit so the stylesheet is decoded the same way on every
# platform instead of depending on the locale's default encoding.
with open('style.css', encoding='utf-8') as f:
    st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)

# Initialize the remote clients: a Hugging Face inference client for text
# generation and a Gradio client for image generation.
text_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
image_client = Client("phenixrhyder/nsfw-waifu-gradio")
def format_prompt_for_description(caption_text):
    """Build the text-model prompt that asks for a Pepe the Frog meme caption.

    Args:
        caption_text: The user's caption or meme idea; interpolated verbatim.

    Returns:
        The prompt string to send to the text-generation model.
    """
    return f"Generate a funny and relatable meme caption for Pepe the Frog: {caption_text}"
def format_prompt_for_image(caption_text):
    """Build the text-model prompt that asks for an image-generation prompt.

    Args:
        caption_text: The user's caption or meme idea; interpolated verbatim.

    Returns:
        The prompt string to send to the text-generation model.
    """
    return f"Generate an image prompt for a Pepe the Frog meme with the following caption: {caption_text}"
def clean_generated_text(text):
    """Strip surrounding whitespace and trailing ``</s>`` end-of-sequence tags.

    Whitespace is trimmed before the tag check so that a tag followed by
    spaces or newlines is still removed, and the check repeats so that several
    consecutive trailing tags are all dropped. Tags that appear in the middle
    of the text are left untouched.

    Args:
        text: Raw text produced by the model.

    Returns:
        The cleaned text.
    """
    end_tag = "</s>"
    cleaned = text.strip()
    while cleaned.endswith(end_tag):
        cleaned = cleaned[:-len(end_tag)].rstrip()
    return cleaned
def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Generate text for ``prompt`` with the module-level ``text_client``.

    The response is streamed token by token, concatenated, and passed through
    ``clean_generated_text``.

    Args:
        prompt: Prompt sent to the text-generation model.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens; values below 1 are
            raised to 1 so a zero from the UI does not produce an invalid
            request.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The cleaned generated text, or ``""`` if the request failed (the error
        is reported in the Streamlit UI).
    """
    generate_kwargs = dict(
        temperature=max(temperature, 1e-2),
        max_new_tokens=max(max_new_tokens, 1),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # NOTE(review): the seed is fixed, so identical prompts and settings
        # presumably yield identical output -- confirm this is intended.
        seed=42,
    )
    try:
        stream = text_client.text_generation(
            prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        )
        # Skip special tokens (such as the end-of-sequence marker) so they
        # never reach the displayed text; join once instead of repeated +=.
        pieces = [
            response.token.text
            for response in stream
            if not getattr(response.token, "special", False)
        ]
        return clean_generated_text("".join(pieces))
    except Exception as e:
        # UI boundary: surface the failure to the user and fall back to "".
        st.error(f"Error generating text: {e}")
        return ""
# Updated part for the new API | |
def generate_image(prompt):
    """Request an image for ``prompt`` from the module-level ``image_client``.

    Args:
        prompt: Text prompt forwarded to the Gradio endpoint ``/predict``.

    Returns:
        A one-element list wrapping the API result when the result is truthy,
        otherwise ``None`` (an error is reported in the Streamlit UI).
    """
    try:
        result = image_client.predict(
            param_0=prompt,
            api_name="/predict"
        )
        # Process the result
        if result:
            return [result]  # Assuming the API returns a single image path as a result
        else:
            st.error("Unexpected result format from the Gradio API.")
            return None
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        st.error(f"Error generating image: {e}")
        st.write("Full error details:", e)
        return None
def main():
    """Render the Pepe Meme Generator page and run the generation pipeline.

    The left column collects the caption idea and sampling settings and holds
    the generate button; the right column shows the most recent results, which
    are kept in ``st.session_state``.
    """
    st.title("Pepe Meme Generator")

    # Initialize session state for generated text, image prompt and images
    if "meme_caption" not in st.session_state:
        st.session_state.meme_caption = ""
    if "image_prompt" not in st.session_state:
        st.session_state.image_prompt = ""
    if "image_paths" not in st.session_state:
        st.session_state.image_paths = []

    # User inputs
    col1, col2 = st.columns(2)

    with col1:
        caption_text = st.text_input("Enter a caption or meme idea for Pepe")

        # Advanced settings
        with st.expander("Advanced Settings"):
            temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
            # The minimum is one step (64) rather than 0: a request for zero
            # new tokens cannot produce any text.
            max_new_tokens = st.slider("Max new tokens", 64, 8192, 512, step=64)
            top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
            repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)

        # Generate button
        if st.button("Generate Pepe Meme"):
            if not caption_text.strip():
                # Do not spend remote calls on an empty idea.
                st.warning("Please enter a caption or meme idea first.")
            else:
                with st.spinner("Generating Pepe meme..."):
                    description_prompt = format_prompt_for_description(caption_text)
                    image_prompt = format_prompt_for_image(caption_text)

                    # Generate meme caption
                    meme_caption = generate_text(
                        description_prompt, temperature, max_new_tokens, top_p, repetition_penalty
                    )
                    # Generate image prompt
                    generated_image_prompt = generate_text(
                        image_prompt, temperature, max_new_tokens, top_p, repetition_penalty
                    )
                    # Generate image from image prompt; skipped when the
                    # prompt is empty because text generation failed.
                    image_paths = (
                        generate_image(generated_image_prompt)
                        if generated_image_prompt
                        else None
                    )

                    st.session_state.meme_caption = meme_caption
                    st.session_state.image_prompt = generated_image_prompt
                    # generate_image returns None on failure; keep the stored
                    # value a list, matching its initial value.
                    st.session_state.image_paths = image_paths or []

                # Only announce success when every stage produced output; the
                # helpers have already reported any failure.
                if meme_caption and generated_image_prompt and image_paths:
                    st.success("Pepe meme generated!")

    with col2:
        # Display the generated meme caption, image prompt and image
        if st.session_state.meme_caption:
            st.subheader("Generated Meme Caption")
            st.write(st.session_state.meme_caption)
        if st.session_state.image_prompt:
            st.subheader("Image Prompt")
            st.write(st.session_state.image_prompt)
        if st.session_state.image_paths:
            st.subheader("Generated Image")
            for image_path in st.session_state.image_paths:
                st.image(image_path, caption="Generated Pepe Meme Image")
# Run the app only when this file is executed as the main module.
if __name__ == "__main__":
    main()