fantaxy committed on
Commit
9185b08
·
verified ·
1 Parent(s): 7c37716

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -37
app.py CHANGED
@@ -16,24 +16,40 @@ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %
16
  hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus-08-2024", token=os.getenv("HF_TOKEN"))
17
  IMAGE_API_URL = "http://211.233.58.201:7896"
18
 
19
- def generate_image(prompt: str) -> tuple:
20
- """์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ•จ์ˆ˜"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  try:
22
  client = Client(IMAGE_API_URL)
23
- # ํ”„๋กฌํ”„ํŠธ ์•ž์— "fantasy style," ์ถ”๊ฐ€
24
- enhanced_prompt = f"fantasy style, {prompt}"
25
- result = client.predict(
26
- prompt=enhanced_prompt,
27
- width=768,
28
- height=768,
29
- guidance=7.5,
30
- inference_steps=30,
31
- seed=3,
32
- do_img2img=False,
33
- init_image=None,
34
- image2image_strength=0.8,
35
- resize_img=True,
36
- api_name="/generate_image"
 
37
  )
38
  return result[0], result[1]
39
  except Exception as e:
@@ -49,6 +65,7 @@ def respond(
49
  top_p=0.9,
50
  ):
51
  system_prefix = """
 
52
  [์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ ๋‚ด์šฉ...]
53
  """
54
 
@@ -62,7 +79,6 @@ def respond(
62
 
63
  response = ""
64
  try:
65
- # ํ…์ŠคํŠธ ์ƒ์„ฑ
66
  for msg in hf_client.chat_completion(
67
  messages,
68
  max_tokens=max_tokens,
@@ -73,23 +89,27 @@ def respond(
73
  token = msg.choices[0].delta.content
74
  if token is not None:
75
  response += token.strip("")
76
- # ์ฑ„ํŒ… ํžˆ์Šคํ† ๋ฆฌ์— ๋งž๋Š” ํ˜•์‹์œผ๋กœ ๋ฐ˜ํ™˜
77
  history = history + [(message, response)]
78
- yield history, None
79
-
80
- # ํ…์ŠคํŠธ ์ƒ์„ฑ์ด ์™„๋ฃŒ๋œ ํ›„ ์ด๋ฏธ์ง€ ์ƒ์„ฑ
81
- try:
82
- image, seed = generate_image(response[:200])
83
- history = history[:-1] + [(message, response)] # ๋งˆ์ง€๋ง‰ ๋ฉ”์‹œ์ง€ ์—…๋ฐ์ดํŠธ
84
- yield history, image
85
- except Exception as e:
86
- logging.error(f"Image generation failed: {str(e)}")
87
- yield history, None
88
 
89
  except Exception as e:
90
  error_message = f"Error: {str(e)}"
91
  history = history + [(message, error_message)]
92
- yield history, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
95
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as interface:
@@ -103,11 +123,14 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as interface:
103
  label="Chat History",
104
  height=500
105
  )
106
- msg = gr.Textbox(
107
- label="Enter your message",
108
- placeholder="Type your message here...",
109
- lines=2
110
- )
 
 
 
111
  system_msg = gr.Textbox(
112
  label="System Message",
113
  value="Write(output) in ํ•œ๊ตญ์–ด.",
@@ -156,10 +179,16 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as interface:
156
  )
157
 
158
  # ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
 
 
 
 
 
 
159
  msg.submit(
160
- respond,
161
- [msg, chatbot, system_msg, max_tokens, temperature, top_p],
162
- [chatbot, image_output]
163
  )
164
 
165
  # ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ์‹คํ–‰
 
16
  hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus-08-2024", token=os.getenv("HF_TOKEN"))
17
  IMAGE_API_URL = "http://211.233.58.201:7896"
18
 
19
async def generate_image_prompt(text: str) -> str:
    """Build an image-generation prompt from the user's text.

    Asks the chat model to rewrite *text* as a detailed visual prompt and
    prefixes it with "fantasy style,".  Falls back to the raw text (with
    the same prefix) if the model call fails for any reason.

    Args:
        text: Source text (typically the user's chat message).

    Returns:
        A prompt string for the image-generation API.
    """
    try:
        prompt_messages = [
            {"role": "system", "content": "You are an expert at creating image generation prompts. Convert the input text into a detailed image generation prompt that describes the visual elements, style, and atmosphere."},
            {"role": "user", "content": f"Create an image generation prompt for: {text}"}
        ]

        # hf_client.chat_completion is a blocking HTTP call; run it in a
        # worker thread so it does not stall the event loop (same
        # run_in_executor pattern as generate_image_async).
        response = await asyncio.get_event_loop().run_in_executor(
            None,
            lambda: hf_client.chat_completion(prompt_messages, max_tokens=200),
        )
        image_prompt = response.choices[0].message.content
        return f"fantasy style, {image_prompt}"
    except Exception as e:
        # Best-effort: never block the chat flow because prompt
        # enhancement failed — fall back to the raw text.
        logging.error(f"Image prompt generation failed: {str(e)}")
        return f"fantasy style, {text}"
34
+ async def generate_image_async(prompt: str) -> tuple:
35
+ """๋น„๋™๊ธฐ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ•จ์ˆ˜"""
36
  try:
37
  client = Client(IMAGE_API_URL)
38
+ result = await asyncio.get_event_loop().run_in_executor(
39
+ ThreadPoolExecutor(),
40
+ lambda: client.predict(
41
+ prompt=prompt,
42
+ width=768,
43
+ height=768,
44
+ guidance=7.5,
45
+ inference_steps=30,
46
+ seed=3,
47
+ do_img2img=False,
48
+ init_image=None,
49
+ image2image_strength=0.8,
50
+ resize_img=True,
51
+ api_name="/generate_image"
52
+ )
53
  )
54
  return result[0], result[1]
55
  except Exception as e:
 
65
  top_p=0.9,
66
  ):
67
  system_prefix = """
68
+ You are 'FantasyAIโœจ', an advanced AI storyteller specialized in creating immersive fantasy narratives.
69
  [์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ ๋‚ด์šฉ...]
70
  """
71
 
 
79
 
80
  response = ""
81
  try:
 
82
  for msg in hf_client.chat_completion(
83
  messages,
84
  max_tokens=max_tokens,
 
89
  token = msg.choices[0].delta.content
90
  if token is not None:
91
  response += token.strip("")
 
92
  history = history + [(message, response)]
93
+ yield history
 
 
 
 
 
 
 
 
 
94
 
95
  except Exception as e:
96
  error_message = f"Error: {str(e)}"
97
  history = history + [(message, error_message)]
98
+ yield history
99
+
100
async def process_input(message, history, system_message, max_tokens, temperature, top_p):
    """Handle one user turn: produce the text reply and a matching image.

    Starts image generation concurrently, streams/drains the text
    response, then waits for the image before returning.

    Args:
        message: The user's chat message.
        history: Current chat history as a list of (user, bot) tuples.
        system_message: System prompt text from the UI.
        max_tokens / temperature / top_p: Sampling parameters forwarded
            to ``respond``.

    Returns:
        (history, image): the updated chat history and the generated image.
    """
    # Derive an image prompt from the user message and kick off image
    # generation in the background while the text response is produced.
    image_prompt = await generate_image_prompt(message)
    image_task = asyncio.create_task(generate_image_async(image_prompt))

    # respond() is a generator that yields successively longer histories;
    # drain it so Gradio receives the final history list, not the raw
    # generator object (the previous code returned the generator itself,
    # so the model was never actually invoked).
    final_history = history
    for final_history in respond(message, history, system_message, max_tokens, temperature, top_p):
        pass

    # Wait for the background image generation to finish.
    image, _seed = await image_task

    return final_history, image
113
 
114
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
115
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as interface:
 
123
  label="Chat History",
124
  height=500
125
  )
126
+ with gr.Row():
127
+ msg = gr.Textbox(
128
+ label="Enter your message",
129
+ placeholder="Type your message here...",
130
+ lines=2
131
+ )
132
+ submit_btn = gr.Button("Submit", variant="primary")
133
+
134
  system_msg = gr.Textbox(
135
  label="System Message",
136
  value="Write(output) in ํ•œ๊ตญ์–ด.",
 
179
  )
180
 
181
  # ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
182
+ submit_btn.click(
183
+ fn=process_input,
184
+ inputs=[msg, chatbot, system_msg, max_tokens, temperature, top_p],
185
+ outputs=[chatbot, image_output]
186
+ )
187
+
188
  msg.submit(
189
+ fn=process_input,
190
+ inputs=[msg, chatbot, system_msg, max_tokens, temperature, top_p],
191
+ outputs=[chatbot, image_output]
192
  )
193
 
194
  # ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ์‹คํ–‰