Kims12 commited on
Commit
ea07ddd
·
verified ·
1 Parent(s): b80a15f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -113
app.py CHANGED
@@ -4,6 +4,7 @@ import openai
4
  import requests
5
  import logging
6
 
 
7
  # 로깅 설정
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
@@ -115,7 +116,6 @@ with gr.Blocks() as iface:
115
  output_box = gr.Textbox(label=category, visible=True)
116
  output_boxes[category] = output_box
117
 
118
- # 기존 동기 함수 (수정/삭제 금지)
119
  def validate_and_generate(topic):
120
  try:
121
  results = generate_copywriting(CATEGORIES, topic)
@@ -132,127 +132,112 @@ with gr.Blocks() as iface:
132
  except Exception as e:
133
  logger.error(f"Error during copywriting generation: {str(e)}")
134
  return [gr.update(value=f"오류 발생: {str(e)}")] + [gr.update(value="") for _ in CATEGORIES]
135
-
136
- ##########################################
137
- # 추가된 비동기/병렬 처리용 코드 시작
138
- ##########################################
139
- import asyncio
140
- import aiohttp
141
 
142
- async def call_api_async(content, system_message, max_tokens, temperature, top_p):
143
- """
144
- 비동기적으로 OpenAI API를 호출하는 함수
145
- """
146
- url = "https://api.openai.com/v1/chat/completions"
147
- headers = {"Authorization": f"Bearer {openai.api_key}"}
148
- payload = {
149
- "model": "gpt-4o-mini",
150
- "messages": [
151
- {"role": "system", "content": system_message},
152
- {"role": "user", "content": content},
153
- ],
154
- "max_tokens": max_tokens,
155
- "temperature": temperature,
156
- "top_p": top_p,
157
- }
158
- async with aiohttp.ClientSession() as session:
159
- async with session.post(url, headers=headers, json=payload) as resp:
160
- resp_json = await resp.json()
161
- return resp_json["choices"][0]["message"]["content"]
162
 
163
- async def generate_copywriting_async(categories, topic):
164
- """
165
- 여러 카테고리에 대해 비동기적으로 카피라이팅을 생성하는 함수
166
- """
167
- max_tokens = 1000
168
- temperature = 0.8
169
- top_p = 0.95
170
 
171
- tasks = []
172
- for category in categories:
173
- prompt = get_category_prompt(category)
174
- user_content = f"주제: {topic}"
175
- tasks.append(
176
- asyncio.create_task(
177
- call_api_async(user_content, prompt, max_tokens, temperature, top_p)
178
- )
179
- )
180
-
181
- # 병렬 실행
182
- results = await asyncio.gather(*tasks)
183
-
184
- # category 순서에 맞춰 딕셔너리화
185
- result_dict = {}
186
- for category, copywriting in zip(categories, results):
187
- result_dict[category] = copywriting
188
-
189
- return result_dict
190
 
191
- async def validate_and_generate_async(topic):
192
- """
193
- Gradio 스트리밍(yield)을 통해
194
- 카테고리의 결과가 나오면 즉시 전달하도록 하는 함수
195
- """
196
- try:
197
- # 우선 상태창 업데이트
198
- yield [gr.update(value="카피라이팅 생성 중...")] + [gr.update() for _ in CATEGORIES]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- # 비동기로 카피라이팅 생성
201
- # 각 카테고리별 결과를 기다리지 않고, 완료되는 순서대로 전달
202
- async def async_run():
203
- # as_completed로 각 결과가 나올 때마다 받기
204
- tasks = {}
205
- for category in CATEGORIES:
206
- prompt = get_category_prompt(category)
207
- user_content = f"주제: {topic}"
208
- tasks[category] = asyncio.create_task(
209
- call_api_async(user_content, prompt, 1000, 0.8, 0.95)
210
- )
211
-
212
- # 카테고리별 결과가 끝나는 순서대로 반환
213
- for finished_task in asyncio.as_completed(tasks.values()):
214
- # finished_task가 어떤 카테고리에 해당되는지 찾는다
215
- for cat, tsk in tasks.items():
216
- if tsk == finished_task:
217
- # 해당 카테고리의 결과
218
- try:
219
- result_text = await finished_task
220
- except Exception as e:
221
- result_text = f"오류 발생: {str(e)}"
222
- yield cat, result_text
223
- break
224
 
225
- # 스트리밍 처리
226
- # 카테고리가 끝날 때마다 yield로 결과 갱신
227
- # outputs 순서는 [status] + [테고리1, 카테고리2, ...]
228
- # 따라서 index를 찾아서 부분 업데이트
229
- # status(0)번, 카테고리별 1~N
230
- current_values = ["카피라이팅 생성 중..."] + ["" for _ in CATEGORIES]
231
 
232
- async for cat, result_text in async_run():
233
- # cat의 인덱스를 찾아서 갱신
234
- idx = CATEGORIES.index(cat) + 1 # status가 0번이므로 +1
235
- current_values[idx] = result_text
236
- yield current_values # 부분 업데이트
237
 
238
- # 모든 카테고 끝난 최종 상태메
239
- current_values[0] = "카피라이팅 생성이 모두 완료되었습니다."
240
- yield current_values
241
 
242
- except Exception as e:
243
- logger.error(f"Error during copywriting generation: {str(e)}")
244
- yield [f"오류 발생: {str(e)}"] + ["" for _ in CATEGORIES]
245
 
246
- # 비동기 함수Gradio 이벤트에 연결
247
- generate_btn.click(
248
- fn=validate_and_generate_async,
249
- inputs=[topic],
250
- outputs=[status] + [output_boxes[category] for category in CATEGORIES],
251
- api_name="generate_copy_async" # 임의의 api_name
 
252
  )
253
- ##########################################
254
- # 추가된 비동기/병렬 처리용 코드 끝
255
- ##########################################
256
 
257
- # 인터페이스 실행
258
- iface.launch()
 
4
  import requests
5
  import logging
6
 
7
+ # ===== (기본코드 시작) =====
8
  # 로깅 설정
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
 
116
  output_box = gr.Textbox(label=category, visible=True)
117
  output_boxes[category] = output_box
118
 
 
119
  def validate_and_generate(topic):
120
  try:
121
  results = generate_copywriting(CATEGORIES, topic)
 
132
  except Exception as e:
133
  logger.error(f"Error during copywriting generation: {str(e)}")
134
  return [gr.update(value=f"오류 발생: {str(e)}")] + [gr.update(value="") for _ in CATEGORIES]
 
 
 
 
 
 
135
 
136
+ generate_btn.click(
137
+ fn=validate_and_generate,
138
+ inputs=[topic],
139
+ outputs=[status] + [output_boxes[category] for category in CATEGORIES]
140
+ )
141
+
142
+ # 인터페이스 실행
143
+ iface.launch()
144
+ # ===== (기본코드 끝) =====
 
 
 
 
 
 
 
 
 
 
 
145
 
 
 
 
 
 
 
 
146
 
147
+ # ===== (아래부터 비동기/병렬 구조 추가 코드) =====
148
+ import asyncio
149
+ import aiohttp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ """
152
+ [요청사항]
153
+ 1) 현재는 한번출력되는 형태이므로 병렬 구조로 만들어줌.
154
+ 2) 지피티를 각각 항목에 대해 동작.
155
+ 3) 결과가 모두 완성된 뒤에 한꺼번에 나오지 말고, '비동기식'으로 바로바로 출력.
156
+ 4) asyncio와 aiohttp를 사용하여 비동기 API 호출을 구현하고,
157
+ Gradio의 스트리밍 기능을 활용하여 결과를 비동기적으로 업데이트.
158
+ """
159
+
160
+ # (기본코드와 동일 기능: 프롬프트 생성 함수 get_category_prompt 유지)
161
+
162
+ async def call_api_async(session, content, system_message, max_tokens=1000, temperature=0.8, top_p=0.95):
163
+ """
164
+ aiohttp를 사용하여 OpenAI API를 비동기로 호출하는 함수
165
+ """
166
+ url = "https://api.openai.com/v1/chat/completions"
167
+ headers = {"Authorization": f"Bearer {openai.api_key}"}
168
+ json_data = {
169
+ "model": "gpt-4o-mini",
170
+ "messages": [
171
+ {"role": "system", "content": system_message},
172
+ {"role": "user", "content": content},
173
+ ],
174
+ "max_tokens": max_tokens,
175
+ "temperature": temperature,
176
+ "top_p": top_p,
177
+ }
178
+ async with session.post(url, headers=headers, json=json_data) as resp:
179
+ resp_json = await resp.json()
180
+ return resp_json["choices"][0]["message"]["content"]
181
+
182
+ async def fetch_copywriting(session, category, topic):
183
+ """
184
+ 각 카테고리에 대한 카피라이팅 결과를 비동기로 가져오는 코루틴
185
+ """
186
+ prompt = get_category_prompt(category)
187
+ user_content = f"주제: {topic}"
188
+ result = await call_api_async(session, user_content, prompt)
189
+ return category, result
190
+
191
+
192
+ def generate_copywriting_stream(topic):
193
+ """
194
+ Gradio의 스트리밍 기능을 위해 generator를 리턴하는 함수.
195
+ 카테고리별로 비동기 호출을 하고, 완료되는 순서대로 yield하여 즉시 결과를 보냄.
196
+ """
197
+ async def async_gen():
198
+ # '병렬' 실행을 위해 asyncio.gather 사용
199
+ async with aiohttp.ClientSession() as session:
200
+ tasks = []
201
+ for category in CATEGORIES:
202
+ tasks.append(fetch_copywriting(session, category, topic))
203
+
204
+ # asyncio.as_completed를 사용하면 '완료되는 순서'대로 결과를 가져올 수 있음
205
+ for coro in asyncio.as_completed(tasks):
206
+ try:
207
+ category, result = await coro
208
+ # 카테고리 결과를 실시간 스트리밍으로 전달
209
+ yield f"### [{category}] 결과\n{result}\n\n"
210
+ except Exception as e:
211
+ yield f"### 오류 발생\n{str(e)}\n\n"
212
+
213
+ # Gradio가 stream=True 옵션으로 사용할 수 있는 비동기 제너레이터 반환
214
+ return async_gen()
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ # 아래는 비동기 버튼(추가) - 기존 코드 수정 없이 새로운 버튼과 출력만 추가
218
+ with gr.Blocks() as async_interface:
219
+ gr.Markdown("## (비동기 버전) AI피라이팅 생성기")
220
+ with gr.Column():
221
+ topic_async = gr.Textbox(lines=1, label="주제를 입력하세요(비동기)")
 
222
 
223
+ # 비동기 요청용 버튼
224
+ async_generate_btn = gr.Button("비동기로 카피라이팅 생성하기")
 
 
 
225
 
226
+ # 스트 출력을 받을 Textbox (하나의 텍스트 박스에 순차적으로 표)
227
+ async_output = gr.Textbox(label="비동기 결과 스트리밍", lines=20)
 
228
 
229
+ def start_async_generation(topic):
230
+ # Gradio에서 stream=True 옵션을 사용하기 위해서는 generator(또는 async generator)가 필요
231
+ return generate_copywriting_stream(topic)
232
 
233
+ # 클릭 시, 비동기 제너레이터stream=True로 연결
234
+ async_generate_btn.click(
235
+ fn=start_async_generation,
236
+ inputs=topic_async,
237
+ outputs=async_output,
238
+ queue=False, # 비동기 즉시 응답
239
+ stream=True # 스���리밍으로 결과 전달
240
  )
 
 
 
241
 
242
+ # 비동기 인터페이스 실행 (포트 다르게 설정하거나 원하는 대로 사용)
243
+ async_interface.launch(server_name="0.0.0.0", server_port=7861)