File size: 5,767 Bytes
19cef8f
b23b2e1
74e0f7a
707e3ee
19cef8f
5a59b16
 
 
f593040
 
 
 
db63f1a
74e0f7a
fe28822
19cef8f
7249a81
efdf463
 
b23b2e1
 
7249a81
efdf463
46a4fa5
 
707e3ee
 
 
 
 
0833433
db63f1a
b23b2e1
 
 
 
 
 
 
 
db63f1a
74e0f7a
db63f1a
 
 
707e3ee
db63f1a
 
 
b23b2e1
1dc7336
74e0f7a
 
 
1dc7336
b1a6c07
74e0f7a
a73ddf4
 
1dc7336
cfe386b
 
 
74e0f7a
f49bf68
 
74e0f7a
 
db63f1a
7249a81
19cef8f
7249a81
 
 
f593040
 
 
 
7249a81
 
 
 
 
77a9b4d
 
 
 
 
 
7249a81
77a9b4d
 
 
 
 
 
 
7249a81
77a9b4d
 
 
 
 
7249a81
77a9b4d
 
 
 
 
 
 
 
 
 
7249a81
77a9b4d
 
 
 
 
 
 
 
 
 
 
 
db63f1a
 
ab9dfa6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import streamlit as st
from huggingface_hub import InferenceClient
from gradio_client import Client
import re

# Use Streamlit's wide page layout for the two-column UI built in main().
st.set_page_config(layout="wide")

# Inject the app's custom stylesheet. The path is relative, so style.css is
# looked up in the current working directory. Decode explicitly as UTF-8 so
# the result does not depend on the platform's default locale encoding.
with open("style.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

# Remote model clients, created once when the module is loaded:
# - text_client: huggingface_hub InferenceClient for the Mistral instruct model
# - image_client: gradio_client connection to the image-generation Space
text_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
image_client = Client("phenixrhyder/nsfw-waifu-gradio")

def format_prompt_for_description(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story):
    """Build the instruction asking the text model to write a character profile.

    Each argument is interpolated verbatim into its own sentence; the
    sentences are separated by single spaces.
    """
    sentences = (
        f"Create a waifu character named {name} with {hair_color} hair, "
        f"a {personality} personality, and wearing a {outfit_style}.",
        f"Her hobbies include {hobbies}.",
        f"Her favorite food is {favorite_food}.",
        f"Here is her background story: {background_story}.",
    )
    return " ".join(sentences)

def format_prompt_for_image(name, hair_color, personality, outfit_style):
    """Build the instruction asking the text model to write an image prompt.

    The returned string is itself a request to the text model; its reply is
    what later gets sent to the image generator.
    """
    return (
        "Generate an image prompt for a waifu character named "
        f"{name} with {hair_color} hair, a {personality} personality, "
        f"and wearing a {outfit_style}."
    )

def clean_generated_text(text):
    """Remove trailing ``</s>`` end-of-sequence tags and surrounding whitespace.

    Any number of ``</s>`` tags at the end of *text* are removed, including
    whitespace before, between, or after them. Leading and trailing
    whitespace is also trimmed. A ``</s>`` that appears elsewhere in the text
    is left untouched.

    Args:
        text: Raw model output.

    Returns:
        The cleaned text ("" if nothing remains).
    """
    eos_tag = "</s>"
    # Strip first: the old `</s>$` regex missed a tag followed by trailing
    # whitespace. Then peel tags off the end one at a time.
    cleaned = text.strip()
    while cleaned.endswith(eos_tag):
        cleaned = cleaned[: -len(eos_tag)].rstrip()
    return cleaned

def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Generate a completion for *prompt* with ``text_client``.

    Sampling parameters are clamped into ranges the inference backend
    accepts. The streamed tokens are concatenated, and the result is passed
    through ``clean_generated_text``.

    Args:
        prompt: Text sent to the model.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens; values below 1 are
            raised to 1.
        top_p: Nucleus-sampling threshold. None or values >= 1.0 disable
            nucleus filtering (sent as None); values below 0.01 are raised
            to 0.01.
        repetition_penalty: Passed through unchanged.

    Returns:
        The cleaned generated text, or "" if the request failed. On failure
        the error is shown in the Streamlit UI via ``st.error``.
    """
    # The UI sliders can produce 0 / 1.0 extremes. NOTE(review): TGI validates
    # 0 < top_p < 1 and max_new_tokens > 0; confirm against the deployed endpoint.
    if top_p is not None:
        top_p = None if top_p >= 1.0 else max(top_p, 1e-2)
    generate_kwargs = dict(
        temperature=max(temperature, 1e-2),
        max_new_tokens=max(int(max_new_tokens), 1),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed; presumably makes sampling repeatable for identical inputs.
        seed=42,
    )
    try:
        stream = text_client.text_generation(
            prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        output = "".join(response.token.text for response in stream)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the page.
        st.error(f"Error generating text: {e}")
        return ""
    return clean_generated_text(output)

# Image generation via the Gradio Space's "/predict" endpoint
def generate_image(prompt):
    """Ask the Gradio image Space to render *prompt*.

    Returns:
        A one-element list wrapping whatever ``image_client.predict``
        returned, or None if the call raised or gave a falsy result. The
        wrapped value is presumably a local image file path; TODO confirm
        against the Space's API. Errors are reported in the Streamlit UI
        rather than raised.
    """
    try:
        output = image_client.predict(param_0=prompt, api_name="/predict")
    except Exception as exc:
        st.error(f"Error generating image: {exc}")
        st.write("Full error details:", exc)
        return None

    if not output:
        st.error("Unexpected result format from the Gradio API.")
        return None

    # Wrap in a list so callers can iterate over the images uniformly.
    return [output]

def _with_system_prompt(system_prompt, prompt):
    """Return *prompt*, prefixed by *system_prompt* and a blank line when it is non-blank."""
    if not system_prompt or not system_prompt.strip():
        return prompt
    return f"{system_prompt.strip()}\n\n{prompt}"


def main():
    """Render the Streamlit UI and run character/image generation on demand.

    The left column collects the character attributes and sampling settings.
    Pressing "Generate Waifu" stores a description, an image prompt and image
    paths in ``st.session_state``. The right column renders whatever is
    stored, so results persist across Streamlit reruns.
    """
    st.title("Enhanced Waifu Character Generator")

    # User inputs
    col1, col2 = st.columns(2)
    with col1:
        name = st.text_input("Name of the Waifu")
        hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
        personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
        outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
        hobbies = st.text_input("Hobbies")
        favorite_food = st.text_input("Favorite Food")
        background_story = st.text_area("Background Story")
        system_prompt = st.text_input("Optional System Prompt", "")

        # Sampling settings forwarded to generate_text.
        with st.expander("Advanced Settings"):
            temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
            max_new_tokens = st.slider("Max new tokens", 0, 8192, 512, step=64)
            top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
            repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)

        # Seed session state on the first run so the results column can read
        # these keys unconditionally.
        for key, default in (
            ("character_description", ""),
            ("image_prompt", ""),
            ("image_paths", []),
        ):
            if key not in st.session_state:
                st.session_state[key] = default

        if st.button("Generate Waifu"):
            with st.spinner("Generating waifu character..."):
                sampling = (temperature, max_new_tokens, top_p, repetition_penalty)
                # Apply the optional system prompt to both model requests.
                description_prompt = _with_system_prompt(
                    system_prompt,
                    format_prompt_for_description(
                        name, hair_color, personality, outfit_style,
                        hobbies, favorite_food, background_story,
                    ),
                )
                image_request = _with_system_prompt(
                    system_prompt,
                    format_prompt_for_image(name, hair_color, personality, outfit_style),
                )

                st.session_state.character_description = generate_text(description_prompt, *sampling)
                st.session_state.image_prompt = generate_text(image_request, *sampling)

                # generate_text returns "" on failure. Don't send an empty prompt
                # to the image API, and clear images left over from a previous run.
                if st.session_state.image_prompt:
                    st.session_state.image_paths = generate_image(st.session_state.image_prompt)
                else:
                    st.session_state.image_paths = []

                # Failures were already reported via st.error; only announce
                # success when a description was actually produced.
                if st.session_state.character_description:
                    st.success("Waifu character generated!")

    with col2:
        # Render whatever the last generation stored in session state.
        if st.session_state.character_description:
            st.subheader("Generated Waifu Character")
            st.write(st.session_state.character_description)
        if st.session_state.image_prompt:
            st.subheader("Image Prompt")
            st.write(st.session_state.image_prompt)
        # image_paths may be None when generate_image failed; falsy, so skipped.
        if st.session_state.image_paths:
            st.subheader("Generated Image")
            for image_path in st.session_state.image_paths:
                st.image(image_path, caption="Generated Waifu Image")

if __name__ == "__main__":
    main()