"""Streamlit app that generates a waifu character description, an image
prompt, and an image from user-selected traits.

Text comes from a Hugging Face ``InferenceClient`` (Mistral-7B-Instruct); the
image comes from a Gradio Space called through ``gradio_client``.
"""

import re
from pathlib import Path

import streamlit as st
from gradio_client import Client
from huggingface_hub import InferenceClient

TEXT_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
IMAGE_SPACE = "phenixrhyder/nsfw-waifu-gradio"
STYLE_PATH = Path("style.css")

# Mistral's end-of-sequence marker (plus any whitespace after it) at the very
# end of the generated text.
_TRAILING_EOS = re.compile(r"</s>\s*$")

# Must be the first Streamlit command executed by the script.
st.set_page_config(layout="wide")


def load_css(path=STYLE_PATH):
    """Inject the stylesheet at *path* into the page; do nothing if it is missing."""
    try:
        css = Path(path).read_text(encoding="utf-8")
    except FileNotFoundError:
        # Styling is cosmetic; a missing file should not take the app down.
        return
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


load_css()

# Initialize the HuggingFace Inference Client (constructing it does not
# contact the server, so a module-level instance is fine across reruns).
text_client = InferenceClient(model=TEXT_MODEL)


@st.cache_resource
def get_image_client():
    """Return the Gradio client for the image Space, created once per process.

    ``gradio_client.Client`` contacts the Space when constructed. Caching it
    avoids repeating that on every Streamlit rerun. Creating it lazily lets a
    connection failure be reported in the UI instead of crashing the import.
    """
    return Client(IMAGE_SPACE)


def format_prompt_for_description(name, hair_color, personality, outfit_style,
                                  hobbies, favorite_food, background_story):
    """Build the instruction that asks the LLM for a full character description."""
    prompt = (
        f"Create a waifu character named {name} with {hair_color} hair, "
        f"a {personality} personality, and wearing a {outfit_style}. "
    )
    prompt += (
        f"Her hobbies include {hobbies}. Her favorite food is {favorite_food}. "
        f"Here is her background story: {background_story}."
    )
    return prompt


def format_prompt_for_image(name, hair_color, personality, outfit_style):
    """Build the instruction that asks the LLM to write an image-generation prompt."""
    return (
        f"Generate an image prompt for a waifu character named {name} with "
        f"{hair_color} hair, a {personality} personality, and wearing a {outfit_style}."
    )


def apply_system_prompt(prompt, system_prompt=""):
    """Prefix *prompt* with the user's optional *system_prompt*.

    The two are separated by a blank line. *prompt* is returned unchanged
    when *system_prompt* is empty or whitespace-only.
    """
    system_prompt = (system_prompt or "").strip()
    return f"{system_prompt}\n\n{prompt}" if system_prompt else prompt


def clean_generated_text(text):
    """Remove a trailing ``</s>`` EOS marker and surrounding whitespace from *text*."""
    return _TRAILING_EOS.sub("", text).strip()


def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95,
                  repetition_penalty=1.0):
    """Generate text for *prompt* with the Mistral model.

    Returns the cleaned generated text. On any error, shows it in the UI and
    returns ``""``.
    """
    generate_kwargs = dict(
        # Sampling needs a strictly positive temperature; the UI slider allows 0.
        temperature=max(temperature, 1e-2),
        # Asking for zero new tokens can never produce output.
        max_new_tokens=max(int(max_new_tokens), 1),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed: identical inputs reproduce identical output.
        seed=42,
    )
    try:
        stream = text_client.text_generation(
            prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        # Special tokens (such as the final EOS) are control markers, not text.
        output = "".join(
            chunk.token.text for chunk in stream if not chunk.token.special
        )
    except Exception as e:
        st.error(f"Error generating text: {e}")
        return ""
    return clean_generated_text(output)


def generate_image(prompt):
    """Request an image for *prompt* from the Gradio Space's ``/predict`` endpoint.

    Returns a one-element list holding whatever the endpoint returned
    (presumably a local image file path — verify against the Space's API), or
    an empty list when the prompt is empty or the call fails.
    """
    if not prompt:
        st.warning("No image prompt available; skipping image generation.")
        return []
    try:
        result = get_image_client().predict(param_0=prompt, api_name="/predict")
    except Exception as e:
        st.error(f"Error generating image: {e}")
        st.exception(e)
        return []
    if not result:
        st.error("Unexpected result format from the Gradio API.")
        return []
    return [result]


def main():
    """Render the generator UI, run generation on demand, and show the results."""
    st.title("Enhanced Waifu Character Generator")

    # Results live in session state so they survive the rerun that every
    # widget interaction triggers.
    for key, default in (
        ("character_description", ""),
        ("image_prompt", ""),
        ("image_paths", []),
    ):
        if key not in st.session_state:
            st.session_state[key] = default

    col1, col2 = st.columns(2)

    with col1:
        # User inputs
        name = st.text_input("Name of the Waifu")
        hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
        personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
        outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
        hobbies = st.text_input("Hobbies")
        favorite_food = st.text_input("Favorite Food")
        background_story = st.text_area("Background Story")
        system_prompt = st.text_input("Optional System Prompt", "")

        with st.expander("Advanced Settings"):
            temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
            max_new_tokens = st.slider("Max new tokens", 64, 8192, 512, step=64)
            top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
            repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)

        if st.button("Generate Waifu"):
            sampling = dict(
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            with st.spinner("Generating waifu character..."):
                description_prompt = apply_system_prompt(
                    format_prompt_for_description(
                        name, hair_color, personality, outfit_style,
                        hobbies, favorite_food, background_story,
                    ),
                    system_prompt,
                )
                image_prompt = apply_system_prompt(
                    format_prompt_for_image(name, hair_color, personality, outfit_style),
                    system_prompt,
                )
                st.session_state.character_description = generate_text(description_prompt, **sampling)
                # The LLM writes the prompt that is then fed to the image model.
                st.session_state.image_prompt = generate_text(image_prompt, **sampling)
                st.session_state.image_paths = generate_image(st.session_state.image_prompt)
            # Failures were already reported by the helpers; only claim success
            # when a description was actually produced.
            if st.session_state.character_description:
                st.success("Waifu character generated!")

    with col2:
        if st.session_state.character_description:
            st.subheader("Generated Waifu Character")
            st.write(st.session_state.character_description)
        if st.session_state.image_prompt:
            st.subheader("Image Prompt")
            st.write(st.session_state.image_prompt)
        if st.session_state.image_paths:
            st.subheader("Generated Image")
            for image_path in st.session_state.image_paths:
                st.image(image_path, caption="Generated Waifu Image")


if __name__ == "__main__":
    main()