import gradio as gr
import os
import time
import re
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import OpenAI
import traceback
from dotenv import load_dotenv
from prompts import (
    USER_PROMPT,
    WRAPPER_PROMPT,
    CALL_1_SYSTEM_PROMPT,
    CALL_2_SYSTEM_PROMPT,
    CALL_3_SYSTEM_PROMPT,
)
import difflib
import csv
from threading import Lock
import threading
load_dotenv()
BASE_URL = "https://api.upstage.ai/v1"
API_KEY = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=API_KEY, base_url=BASE_URL, timeout=60)  # 60-second hard timeout per request
def postprocess_pronoun(text: str) -> str:
    """
    Replaces every occurrence of '์ด์žฌ๋ช… ๋Œ€ํ‘œ' with '์ด์žฌ๋ช… ๋Œ€ํ†ต๋ น',
    adjusting any particle that follows to agree with the new final consonant.
    """
    # Only the particles/endings whose form must change are listed here
    correction_map = {
        '๋Š”': '์€', '๊ฐ€': '์ด', '๋ฅผ': '์„', '์™€': '๊ณผ', '๋กœ': '์œผ๋กœ',
        '์—ฌ': '์ด์—ฌ', '๋ผ': '์ด๋ผ', '๋ž‘': '์ด๋ž‘',
        '๋‹ค': '์ด๋‹ค', '์˜€๋‹ค': '์ด์—ˆ๋‹ค', '๋ผ๋ฉด': '์ด๋ผ๋ฉด', '๋ผ์„œ': '์ด๋ผ์„œ'
    }
    # Every particle/ending the regex should look for
    all_target_particles = list(correction_map.keys()) + ['๋กœ๋ถ€ํ„ฐ', '๋งŒ', '๋„', '๊ป˜์„œ']
    # Sort longest-first so e.g. '๋ผ๋ฉด' is tried before '๋ผ' in the alternation
    all_target_particles.sort(key=len, reverse=True)
    particle_pattern = "|".join(re.escape(p) for p in all_target_particles)
    # Final regex: the (?!...) guard is intentionally omitted so every occurrence is matched
    regex = re.compile(rf"(์ด์žฌ๋ช…\s*๋Œ€ํ‘œ)({particle_pattern})?")

    def replace_func(match):
        particle = match.group(2)
        new_phrase = "์ด์žฌ๋ช… ๋Œ€ํ†ต๋ น"
        if particle:
            new_phrase += correction_map.get(particle, particle)
        return new_phrase

    return regex.sub(replace_func, text)
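
# Illustrative behavior of postprocess_pronoun (hypothetical input, not from the source):
#   postprocess_pronoun("์ด์žฌ๋ช… ๋Œ€ํ‘œ๊ฐ€ ๋ฐœ์–ธํ–ˆ๋‹ค.")
#   -> "์ด์žฌ๋ช… ๋Œ€ํ†ต๋ น์ด ๋ฐœ์–ธํ–ˆ๋‹ค."
# The subject particle '๊ฐ€' is remapped to '์ด' because '๋Œ€ํ†ต๋ น' ends in a final consonant.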
def extract_json_from_text(text):
    """
    Extracts the JSON portion of a text, trying several patterns in turn.
    Args:
        text: text that may contain JSON
    Returns:
        dict: the parsed JSON object, or None
    """
    if not text or not text.strip():
        return None
    # Trim surrounding whitespace
    text = text.strip()
    # Pattern 1: ```json ... ``` fenced block
    json_code_block_pattern = r'```json\s*(.*?)\s*```'
    match = re.search(json_code_block_pattern, text, re.DOTALL)
    if match:
        try:
            extracted = match.group(1).strip()
            if extracted:
                return json.loads(extracted)
        except json.JSONDecodeError:
            pass
    # Pattern 2: ``` ... ``` fenced block without the json tag
    code_block_pattern = r'```\s*(.*?)\s*```'
    match = re.search(code_block_pattern, text, re.DOTALL)
    if match:
        try:
            extracted = match.group(1).strip()
            if extracted:
                return json.loads(extracted)
        except json.JSONDecodeError:
            pass
    # Pattern 3: a JSON object starting with { and ending with }
    json_object_pattern = r'\{.*\}'
    match = re.search(json_object_pattern, text, re.DOTALL)
    if match:
        try:
            extracted = match.group(0).strip()
            if extracted:
                return json.loads(extracted)
        except json.JSONDecodeError:
            pass
    # Pattern 4: a JSON array starting with [ and ending with ]
    json_array_pattern = r'\[.*\]'
    match = re.search(json_array_pattern, text, re.DOTALL)
    if match:
        try:
            extracted = match.group(0).strip()
            if extracted:
                return json.loads(extracted)
        except json.JSONDecodeError:
            pass
    # Pattern 5: the whole text is JSON
    try:
        if text.startswith('{') or text.startswith('['):
            return json.loads(text)
    except json.JSONDecodeError:
        pass
    return None
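
# Illustrative behavior of extract_json_from_text (hypothetical model output):
#   extract_json_from_text('```json\n{"output": "๊ต์ • ๊ฒฐ๊ณผ"}\n```')
#   -> {'output': '๊ต์ • ๊ฒฐ๊ณผ'}   (matched by pattern 1, the fenced ```json block)
# Bare objects/arrays and unfenced JSON fall through to patterns 3-5.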
# Load the vocabulary for rule-based correction
def load_vocabulary():
    vocabulary = {}
    with open("Vocabulary.csv", "r", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row in reader:
            vocabulary[row["original"]] = row["corrected"]
    return vocabulary
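
# Expected Vocabulary.csv layout, inferred from the DictReader keys above
# (the placeholder row is hypothetical, not taken from the real file):
#   original,corrected
#   <term as written>,<term as it should appear>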
VOCABULARY = load_vocabulary()
# Thread-safe progress counters
counter_lock = Lock()
processed_count = 0
total_bulks = 0
def apply_vocabulary_correction(text):
    for original, corrected in VOCABULARY.items():
        text = text.replace(original, corrected)
    return text
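
# Note: this is a plain substring pass; entries are applied in CSV row order,
# so an earlier replacement's output can itself be matched by a later entry.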
def create_bulk_paragraphs(text, max_chars=500):
    """
    Splits the text into bulk units of roughly 500 characters.
    Args:
        text: input text
        max_chars: maximum characters per bulk (default: 500)
    Returns:
        List[str]: the text split into bulk units
    """
    paragraphs = [p.strip() for p in text.split("\n") if p.strip()]
    if not paragraphs:
        return []
    bulks = []
    current_bulk = []
    current_length = 0
    for para in paragraphs:
        para_length = len(para)
        # The paragraph alone exceeds the 500-character limit
        if para_length > max_chars:
            # Flush the current bulk first, if any
            if current_bulk:
                bulks.append("\n".join(current_bulk))
                current_bulk = []
                current_length = 0
            # A long paragraph becomes a bulk of its own
            bulks.append(para)
        else:
            # Would adding this paragraph (plus joining newlines) exceed the limit?
            if (
                current_length + para_length + len(current_bulk) > max_chars
                and current_bulk
            ):
                # Close the current bulk and start a new one
                bulks.append("\n".join(current_bulk))
                current_bulk = [para]
                current_length = para_length
            else:
                # Append to the current bulk
                current_bulk.append(para)
                current_length += para_length
    # Append the final bulk
    if current_bulk:
        bulks.append("\n".join(current_bulk))
    return bulks
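
# Illustrative behavior with a small limit (hypothetical input; real calls use max_chars=500):
#   create_bulk_paragraphs("aaaa\nbbbb\ncccc", max_chars=10)
#   -> ["aaaa\nbbbb", "cccc"]
# "aaaa" + "bbbb" plus one joining newline is 9 <= 10; adding "cccc" would exceed the limit.
# len(current_bulk) enters the size check because it equals the number of newlines
# that joining one more paragraph would add.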
def process_bulk(bulk_text, bulk_index, max_retries=3, article_info=""):
    """
    Runs a single bulk through the correction pipeline.
    If an API error occurs, returns the output of the last step that succeeded.
    """
    global processed_count
    thread_id = threading.get_ident()
    start = time.time()
    # Holders for each step's result
    (step0, proofread_result, step1, step1_explanation,
     step2, step2_explanation, step3, step4, step5) = (None,) * 9
    # Last successful output to return on error (initialized to the original text)
    last_successful_output = bulk_text
    for attempt in range(max_retries):
        try:
            # Step 0: vocabulary-based correction
            step0 = apply_vocabulary_correction(bulk_text)
            last_successful_output = step0
            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling proofread...")
            proofread_result = call_proofread(step0)
            # Step 1: Solar API call 1
            system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
            user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)
            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step1...")
            step1_json = call_solar_pro2(system_step1, user_step1)
            try:
                parsed_json = json.loads(step1_json)
                step1 = parsed_json.get('output', step0)  # use the 'output' field on success; fall back to the previous step
                step1_explanation = parsed_json.get('explanation', '')
                last_successful_output = step1
            except json.JSONDecodeError:
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step1 JSON parse failed. Trying extraction...")
                extracted_json = extract_json_from_text(step1_json)
                if extracted_json and 'output' in extracted_json:
                    step1 = extracted_json['output']
                    step1_explanation = extracted_json.get('explanation', '')
                    last_successful_output = step1
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction succeeded")
                else:
                    step1 = step0  # if extraction also fails, use the previous step's result
                    step1_explanation = ""
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction failed")
            # Step 2: Solar API call 2
            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step2...")
            step2_json = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
            try:
                parsed_json = json.loads(step2_json)
                step2 = parsed_json.get('output', step1)
                step2_explanation = parsed_json.get('explanation', '')
                last_successful_output = step2
            except json.JSONDecodeError:
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step2 JSON parse failed. Trying extraction...")
                extracted_json = extract_json_from_text(step2_json)
                if extracted_json and 'output' in extracted_json:
                    step2 = extracted_json['output']
                    step2_explanation = extracted_json.get('explanation', '')
                    last_successful_output = step2
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction succeeded")
                else:
                    step2 = step1
                    step2_explanation = ""
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction failed")
            # Step 3: Solar API call 3
            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step3...")
            step3_json = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
            try:
                parsed_json = json.loads(step3_json)
                step3 = parsed_json.get('output', step2)
                last_successful_output = step3
            except json.JSONDecodeError:
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step3 JSON parse failed. Trying extraction...")
                extracted_json = extract_json_from_text(step3_json)
                if extracted_json and 'output' in extracted_json:
                    step3 = extracted_json['output']
                    last_successful_output = step3
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction succeeded")
                else:
                    step3 = step2
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction failed")
            # Step 4: vocabulary-based correction
            step4 = apply_vocabulary_correction(step3)
            # Step 5: pronoun postprocessing
            step5 = postprocess_pronoun(step4)
            last_successful_output = step5
            elapsed = time.time() - start
            with counter_lock:
                processed_count += 1
            # Every step completed: leave the loop and return the result
            return {
                "bulk_index": bulk_index,
                "original": bulk_text,
                "final": last_successful_output,
                "processing_time": elapsed,
                "character_count": len(bulk_text),
                "attempts": attempt + 1,
            }
        except Exception as e:
            if attempt < max_retries - 1:
                print(
                    f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} attempt {attempt+1} failed, retrying: {type(e).__name__}"
                )
                time.sleep(1 * (attempt + 1))  # brief backoff before retrying
                continue
            else:
                # On final failure, return the last successful output as 'final'
                print(f"๐Ÿ”ฅ๐Ÿ”ฅ๐Ÿ”ฅ {article_info}[Thread-{thread_id}] Bulk {bulk_index+1} failed for good! Falling back to the last successful output. ๐Ÿ”ฅ๐Ÿ”ฅ๐Ÿ”ฅ")
                traceback.print_exc()  # log the full error trace
                return {
                    "bulk_index": bulk_index,
                    "original": bulk_text,
                    "final": last_successful_output,  # key point: the last successful output instead of the raw original
                    "processing_time": time.time() - start,
                    "character_count": len(bulk_text),
                    "error": traceback.format_exc(),
                    "attempts": max_retries,
                }
    # Fallback return in case the loop exits unexpectedly
    return {"bulk_index": bulk_index, "final": bulk_text, "error": "unknown_flow_error"}
def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2"):
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        stream=False,
        temperature=temperature,
    )
    return response.choices[0].message.content
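
# Minimal sketch of a direct call (assumes OPENAI_API_KEY holds an Upstage key,
# since the client's base_url points at api.upstage.ai):
#   raw = call_solar_pro2(CALL_2_SYSTEM_PROMPT, "๊ต์ •ํ•  ๋ฌธ๋‹จ")
#   parsed = json.loads(raw)   # or extract_json_from_text(raw) as a fallback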
def call_proofread(paragraph):
    # System prompt (Korean): "Please generate a proofreading result for the input document."
    prompt = "์ž…๋ ฅ๋œ ๋ฌธ์„œ์— ๋Œ€ํ•œ ๊ต์—ด ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•ด ์ฃผ์„ธ์š”."
    response = client.chat.completions.create(
        model="ft:solar-news-correction-dev",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": paragraph},
        ],
        stream=False,
        temperature=0.0,
    )
    return response.choices[0].message.content
def highlight_diff(original, corrected):
    matcher = difflib.SequenceMatcher(None, original, corrected)
    result_html = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            result_html.append(f"<span>{original[i1:i2]}</span>")
        elif tag == "replace":
            result_html.append(
                f'<span style="background:#ffecec;text-decoration:line-through;">{original[i1:i2]}</span>'
            )
            result_html.append(
                f'<span style="background:#e6ffec;">{corrected[j1:j2]}</span>'
            )
        elif tag == "delete":
            result_html.append(
                f'<span style="background:#ffecec;text-decoration:line-through;">{original[i1:i2]}</span>'
            )
        elif tag == "insert":
            result_html.append(
                f'<span style="background:#e6ffec;">{corrected[j1:j2]}</span>'
            )
    return "".join(result_html)
def process_text_parallel(input_text, max_workers=10):
    """Processes the text in parallel, one bulk per task."""
    global processed_count, total_bulks
    # Build the bulks
    bulks = create_bulk_paragraphs(input_text)
    total_bulks = len(bulks)
    processed_count = 0
    if not bulks:
        return []
    results = []
    # Parallel processing
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit every bulk in parallel
        future_to_bulk = {
            executor.submit(process_bulk, bulk, i): i for i, bulk in enumerate(bulks)
        }
        # Collect results as they complete
        for future in as_completed(future_to_bulk):
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                bulk_index = future_to_bulk[future]
                print(f"Exception while processing bulk {bulk_index+1}: {e}")
                results.append(
                    {
                        "bulk_index": bulk_index,
                        "original": bulks[bulk_index],
                        "final": bulks[bulk_index],
                        "processing_time": 0,
                        "character_count": len(bulks[bulk_index]),
                        "error": str(e),
                    }
                )
    # Sort the results back into bulk-index order
    results.sort(key=lambda x: x["bulk_index"])
    return results
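
# Minimal sketch of standalone use, bypassing the Gradio UI
# ("article.txt" is a hypothetical input file):
#   with open("article.txt", encoding="utf-8") as f:
#       results = process_text_parallel(f.read())
#   corrected = "\n".join(r["final"] for r in results)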
def demo_fn(input_text):
    # Process the text bulk-by-bulk in parallel
    bulk_results = process_text_parallel(input_text, max_workers=10)
    if not bulk_results:
        return input_text, input_text
    # Merge the results
    final_texts = [r["final"] for r in bulk_results]
    final_result = "\n".join(final_texts)
    # Build the diff highlight
    highlighted = highlight_diff(input_text, final_result)
    return final_result, highlighted
with gr.Blocks() as demo:
    gr.Markdown("# ๊ต์—ด ๋ชจ๋ธ ๋ฐ๋ชจ")
    input_text = gr.Textbox(
        label="์›๋ฌธ ์ž…๋ ฅ", lines=10, placeholder="๋ฌธ๋‹จ ๋‹จ์œ„๋กœ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”."
    )
    btn = gr.Button("๊ต์—ดํ•˜๊ธฐ")
    output_corrected = gr.Textbox(label="๊ต์—ด ๊ฒฐ๊ณผ", lines=10)
    output_highlight = gr.HTML(label="์ˆ˜์ •๋œ ๋ถ€๋ถ„ ๊ฐ•์กฐ")
    btn.click(
        fn=demo_fn, inputs=input_text, outputs=[output_corrected, output_highlight]
    )
if __name__ == "__main__":
    demo.launch()