randeom commited on
Commit
db63f1a
1 Parent(s): 1d33274

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -36
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  from huggingface_hub import InferenceClient
 
3
 
 
4
  client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
5
 
6
  def format_prompt(message, history, system_prompt=""):
@@ -14,11 +16,7 @@ def format_prompt(message, history, system_prompt=""):
14
  return prompt
15
 
16
  def generate(prompt, history, system_prompt="", temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
17
- temperature = float(temperature)
18
- if temperature < 1e-2:
19
- temperature = 1e-2
20
- top_p = float(top_p)
21
-
22
  generate_kwargs = dict(
23
  temperature=temperature,
24
  max_new_tokens=max_new_tokens,
@@ -29,41 +27,48 @@ def generate(prompt, history, system_prompt="", temperature=0.9, max_new_tokens=
29
  )
30
 
31
  formatted_prompt = format_prompt(prompt, history, system_prompt)
32
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
33
- output = ""
34
-
35
- for response in stream:
36
- output += response.token.text
37
- st.session_state.generated_text = output
 
 
 
38
 
39
- return output
 
40
 
41
- # Streamlit UI
42
- st.title("Waifu Character Generator")
 
 
 
 
43
 
44
- # User inputs
45
- name = st.text_input("Name of the Waifu")
46
- hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
47
- personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
48
- outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
49
- system_prompt = st.text_input("Optional System Prompt", "")
50
 
51
- # Advanced settings
52
- temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
53
- max_new_tokens = st.slider("Max new tokens", 0, 8192, 512, step=64)
54
- top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
55
- repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)
56
 
57
- # Initialize session state for generated text
58
- if "generated_text" not in st.session_state:
59
- st.session_state.generated_text = ""
 
 
 
60
 
61
- # Generate button
62
- if st.button("Generate Waifu"):
63
- history = []
64
- prompt = f"Create a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}."
65
- generate(prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
66
-
67
  # Display the generated character
68
- st.subheader("Generated Waifu Character")
69
- st.write(st.session_state.generated_text)
 
 
 
 
 
1
  import streamlit as st
2
  from huggingface_hub import InferenceClient
3
+ import time
4
 
5
+ # Initialize the HuggingFace Inference Client
6
  client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
7
 
8
  def format_prompt(message, history, system_prompt=""):
 
16
  return prompt
17
 
18
  def generate(prompt, history, system_prompt="", temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
19
+ temperature = max(temperature, 1e-2)
 
 
 
 
20
  generate_kwargs = dict(
21
  temperature=temperature,
22
  max_new_tokens=max_new_tokens,
 
27
  )
28
 
29
  formatted_prompt = format_prompt(prompt, history, system_prompt)
30
+ try:
31
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
32
+ output = ""
33
+ for response in stream:
34
+ output += response.token.text
35
+ return output
36
+ except Exception as e:
37
+ st.error(f"Error generating text: {e}")
38
+ return ""
39
 
40
+ def main():
41
+ st.title("Waifu Character Generator")
42
 
43
+ # User inputs
44
+ name = st.text_input("Name of the Waifu")
45
+ hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
46
+ personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
47
+ outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
48
+ system_prompt = st.text_input("Optional System Prompt", "")
49
 
50
+ # Advanced settings
51
+ with st.expander("Advanced Settings"):
52
+ temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
53
+ max_new_tokens = st.slider("Max new tokens", 0, 8192, 512, step=64)
54
+ top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
55
+ repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)
56
 
57
+ # Initialize session state for generated text
58
+ if "generated_text" not in st.session_state:
59
+ st.session_state.generated_text = ""
 
 
60
 
61
+ # Generate button
62
+ if st.button("Generate Waifu"):
63
+ with st.spinner("Generating waifu character..."):
64
+ history = []
65
+ prompt = f"Create a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}."
66
+ st.session_state.generated_text = generate(prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
67
 
 
 
 
 
 
 
68
  # Display the generated character
69
+ if st.session_state.generated_text:
70
+ st.subheader("Generated Waifu Character")
71
+ st.write(st.session_state.generated_text)
72
+
73
+ if __name__ == "__main__":
74
+ main()