|
import gradio as gr |
|
import os |
|
import time |
|
import re |
|
import json |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
from openai import OpenAI |
|
import random |
|
from concurrent.futures import TimeoutError as FuturesTimeoutError |
|
from openai import APIStatusError, APITimeoutError, APIConnectionError |
|
import traceback |
|
from dotenv import load_dotenv |
|
from prompts import ( |
|
USER_PROMPT, |
|
WRAPPER_PROMPT, |
|
CALL_1_SYSTEM_PROMPT, |
|
CALL_2_SYSTEM_PROMPT, |
|
CALL_3_SYSTEM_PROMPT, |
|
) |
|
import difflib |
|
import csv |
|
from threading import Lock |
|
import threading |
|
|
|
# Load environment variables (notably OPENAI_API_KEY) from a local .env file.
load_dotenv()

# Upstage's OpenAI-compatible endpoint; the key is read from the environment.
BASE_URL = "https://api.upstage.ai/v1"
API_KEY = os.getenv("OPENAI_API_KEY")

# Shared API client used by every model call in this module.
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
|
|
|
|
|
import re |
|
import re |
|
import re |
|
|
|
def postprocess_pronoun(text: str) -> str:
    """Replace every occurrence of the phrase '이재명대하' with '이재명대통령'.

    When the phrase is immediately followed by a Korean particle (조사),
    the particle is rewritten as well so that it agrees with the final
    consonant of the replacement word (e.g. '는' -> '은', '가' -> '이').

    Args:
        text: Input text to post-process.

    Returns:
        The text with the phrase and any trailing particle corrected.
    """
    # Particles that must change form after the replacement word.
    correction_map = {
        '는': '은', '가': '이', '를': '을', '와': '과', '로': '으로',
        '사': '이사', '라': '이라', '나': '이나',
        '다': '이다', '였다': '이였다', '라면': '이라면', '라서': '이라서'
    }

    # Particles listed here but absent from correction_map are kept as-is
    # (correction_map.get falls back to the particle itself below).
    all_target_particles = list(correction_map.keys()) + ['로부터', '만', '도', '께서']

    particle_pattern = "|".join(re.escape(p) for p in all_target_particles)

    # Raw f-string so that \s is a regex whitespace class rather than an
    # invalid string escape (plain f"...\s..." raises a SyntaxWarning on
    # modern CPython).
    regex = re.compile(rf"(이재명\s*대하)({particle_pattern})?")

    def replace_func(match):
        # group(2) is the trailing particle, or None when absent.
        particle = match.group(2)
        new_phrase = "이재명대통령"

        if particle:
            new_phrase += correction_map.get(particle, particle)

        return new_phrase

    return regex.sub(replace_func, text)
|
|
|
|
|
|
|
def extract_json_from_text(text):
    """Extract and parse the first JSON payload embedded in *text*.

    Tries several increasingly permissive strategies, in order:
      1. a ```json fenced code block,
      2. any fenced code block,
      3. the widest {...} object span,
      4. the widest [...] array span,
      5. the raw text itself, if it starts with '{' or '['.

    Args:
        text: Text that may contain JSON.

    Returns:
        The parsed JSON value (dict or list), or None if nothing parses.
    """
    if not text or not text.strip():
        return None

    text = text.strip()

    # (regex pattern, match-group index holding the candidate JSON)
    candidate_patterns = [
        (r'```json\s*(.*?)\s*```', 1),
        (r'```\s*(.*?)\s*```', 1),
        (r'\{.*\}', 0),
        (r'\[.*\]', 0),
    ]

    for pattern, group in candidate_patterns:
        match = re.search(pattern, text, re.DOTALL)
        if not match:
            continue
        candidate = match.group(group).strip()
        if not candidate:
            continue
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            # Fall through to the next, more permissive pattern.
            pass

    # Last resort: the whole text may itself be a JSON document.
    try:
        if text.startswith('{') or text.startswith('['):
            return json.loads(text)
    except json.JSONDecodeError:
        pass

    return None
|
|
|
|
|
|
|
|
|
def load_vocabulary():
    """Load the original -> corrected word map from Vocabulary.csv."""
    mapping = {}
    with open("Vocabulary.csv", "r", encoding="utf-8-sig") as handle:
        for row_index, record in enumerate(csv.DictReader(handle)):
            if row_index == 0:
                # One-time debug aid: show which columns the CSV provides.
                print("CSV columns:", list(record.keys()))
            mapping[record["original"]] = record["corrected"]
    return mapping
|
|
|
|
|
# Word-replacement table loaded once at import time from Vocabulary.csv.
VOCABULARY = load_vocabulary()


# Progress bookkeeping shared across worker threads; counter_lock guards
# the two counters below.
counter_lock = Lock()
processed_count = 0  # bulks finished so far in the current run
total_bulks = 0  # bulks scheduled for the current run
|
|
|
|
|
def apply_vocabulary_correction(text):
    """Apply every VOCABULARY substitution (original -> corrected) to *text*."""
    corrected_text = text
    for wrong_form, right_form in VOCABULARY.items():
        corrected_text = corrected_text.replace(wrong_form, right_form)
    return corrected_text
|
|
|
|
|
def create_bulk_paragraphs(text, max_chars=500):
    """Split *text* into bulks of at most *max_chars* characters.

    The text is first broken into non-empty, stripped paragraphs (one per
    line).  Consecutive paragraphs are greedily packed into one bulk while
    the joined length (paragraph characters plus the newlines between
    them) stays within the limit; a paragraph longer than the limit
    becomes a bulk of its own.

    Args:
        text: Input text.
        max_chars: Maximum characters per bulk (default 500).

    Returns:
        List[str]: The bulk strings, in original order.
    """
    paragraphs = [line.strip() for line in text.split("\n") if line.strip()]
    if not paragraphs:
        return []

    bulks = []
    pending = []       # paragraphs accumulated for the next bulk
    pending_chars = 0  # total characters in `pending`, newlines excluded

    def flush():
        # Emit the accumulated paragraphs as one newline-joined bulk.
        if pending:
            bulks.append("\n".join(pending))
            pending.clear()

    for paragraph in paragraphs:
        size = len(paragraph)
        if size > max_chars:
            # Oversized paragraph: close out what is pending, emit it alone.
            flush()
            pending_chars = 0
            bulks.append(paragraph)
        elif pending and pending_chars + size + len(pending) > max_chars:
            # Joining would overflow (len(pending) counts the newlines
            # needed to join one more paragraph); start a fresh bulk.
            flush()
            pending.append(paragraph)
            pending_chars = size
        else:
            pending.append(paragraph)
            pending_chars += size

    flush()
    return bulks
|
|
|
|
|
|
|
def process_bulk(bulk_text, bulk_index, max_retries=3, article_info=""):
    """
    Process a single bulk through the full correction pipeline.

    Stages: vocabulary pre-pass -> proofread model -> LLM step1 (wrapped
    prompt comparing original vs. proofread) -> LLM step2 -> LLM step3 ->
    vocabulary post-pass -> pronoun post-processing.  Each LLM step is
    expected to return JSON with an 'output' field (and optionally
    'explanation'); a parsing failure falls back to the previous stage's
    text.  On an API error the whole attempt is retried up to max_retries
    times; if every attempt fails, the output of the last stage that
    succeeded is returned instead of crashing.
    """
    global processed_count  # shared progress counter, guarded by counter_lock
    thread_id = threading.get_ident()
    start = time.time()

    # Per-stage intermediate results (kept named for debuggability).
    step0, proofread_result, step1, step1_explanation, step2, step2_explanation, step3, step4, step5 = (None,) * 9

    # Fallback value: the most recent stage output that succeeded.
    last_successful_output = bulk_text

    for attempt in range(max_retries):
        try:
            # Stage 0: deterministic vocabulary substitutions.
            step0 = apply_vocabulary_correction(bulk_text)
            last_successful_output = step0

            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling proofread...")
            proofread_result = call_proofread(step0)

            # Stage 1: wrapped prompt sees both the original and the
            # proofread text.
            system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
            user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)

            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step1...")
            step1_json = call_solar_pro2(system_step1, user_step1)
            try:
                parsed_json = json.loads(step1_json)
                step1 = parsed_json.get('output', step0)
                step1_explanation = parsed_json.get('explanation', '')
                last_successful_output = step1
            except json.JSONDecodeError:
                # Model did not return clean JSON; try to dig it out of
                # the surrounding text.
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step1 JSON 파싱 실패. 추출 시도...")
                extracted_json = extract_json_from_text(step1_json)
                if extracted_json and 'output' in extracted_json:
                    step1 = extracted_json['output']
                    step1_explanation = extracted_json.get('explanation', '')
                    last_successful_output = step1
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON 추출 성공")
                else:
                    # Give up on this step's output; reuse the previous stage.
                    step1 = step0
                    step1_explanation = ""
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON 추출 실패")

            # Stage 2: second LLM pass over step1's output.
            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step2...")
            step2_json = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
            try:
                parsed_json = json.loads(step2_json)
                step2 = parsed_json.get('output', step1)
                step2_explanation = parsed_json.get('explanation', '')
                last_successful_output = step2
            except json.JSONDecodeError:
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step2 JSON 파싱 실패. 추출 시도...")
                extracted_json = extract_json_from_text(step2_json)
                if extracted_json and 'output' in extracted_json:
                    step2 = extracted_json['output']
                    step2_explanation = extracted_json.get('explanation', '')
                    last_successful_output = step2
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON 추출 성공")
                else:
                    step2 = step1
                    step2_explanation = ""
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON 추출 실패")

            # Stage 3: third LLM pass (no explanation field is kept here).
            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step3...")
            step3_json = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
            try:
                parsed_json = json.loads(step3_json)
                step3 = parsed_json.get('output', step2)
                last_successful_output = step3
            except json.JSONDecodeError:
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step3 JSON 파싱 실패. 추출 시도...")
                extracted_json = extract_json_from_text(step3_json)
                if extracted_json and 'output' in extracted_json:
                    step3 = extracted_json['output']
                    last_successful_output = step3
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON 추출 성공")
                else:
                    step3 = step2
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON 추출 실패")

            # Stage 4: vocabulary substitutions again, in case the LLM
            # passes reintroduced known misspellings.
            step4 = apply_vocabulary_correction(step3)

            # Stage 5: rule-based pronoun/particle post-processing.
            step5 = postprocess_pronoun(step4)
            last_successful_output = step5

            elapsed = time.time() - start

            with counter_lock:
                processed_count += 1

            return {
                "bulk_index": bulk_index,
                "original": bulk_text,
                "final": last_successful_output,
                "processing_time": elapsed,
                "character_count": len(bulk_text),
                "attempts": attempt + 1,
            }

        except Exception as e:
            if attempt < max_retries - 1:
                # Transient failure: back off linearly and retry.
                print(
                    f"{article_info}[Thread-{thread_id}] 벌크 {bulk_index+1} 시도 {attempt+1} 실패, 재시도: {type(e).__name__}"
                )
                time.sleep(1 * (attempt + 1))
                continue
            else:
                # Final failure: return the best partial result we have.
                print(f"🔥🔥🔥 {article_info}[Thread-{thread_id}] 벌크 {bulk_index+1} 최종실패! 마지막 성공 결과물을 사용합니다. 🔥🔥🔥")
                traceback.print_exc()

                return {
                    "bulk_index": bulk_index,
                    "original": bulk_text,
                    "final": last_successful_output,
                    "processing_time": time.time() - start,
                    "character_count": len(bulk_text),
                    "error": traceback.format_exc(),
                    "attempts": max_retries,
                }

    # Defensive: only reachable if max_retries <= 0.
    return {"bulk_index": bulk_index, "final": bulk_text, "error": "unknown_flow_error"}
|
|
|
|
|
|
|
def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2"):
    """Send one non-streaming chat completion request to solar-pro2.

    Args:
        system: System prompt text.
        user: User message text.
        temperature: Sampling temperature (default 0.0 for determinism).
        model_name: Model identifier (default "solar-pro2").

    Returns:
        The assistant message content as a string.
    """
    chat_messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    completion = client.chat.completions.create(
        model=model_name,
        messages=chat_messages,
        stream=False,
        temperature=temperature,
    )
    return completion.choices[0].message.content
|
|
|
|
|
def call_proofread(paragraph):
    """Ask the fine-tuned proofreading model to correct *paragraph*.

    Returns:
        The model's corrected text as a string.
    """
    system_prompt = "입력된 문서에 대한 교열 결과를 생성해 주세요."
    completion = client.chat.completions.create(
        model="ft:solar-news-correction-dev",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": paragraph},
        ],
        stream=False,
        temperature=0.0,
    )
    return completion.choices[0].message.content
|
|
|
|
|
def highlight_diff(original, corrected):
    """Render a character-level HTML diff between two strings.

    Unchanged runs are wrapped in plain <span>s; removed text is shown
    struck through on a red background, inserted text on a green one.

    Args:
        original: The source text.
        corrected: The edited text.

    Returns:
        An HTML string concatenating the styled spans.
    """
    removed_tpl = '<span style="background:#ffecec;text-decoration:line-through;">{}</span>'
    inserted_tpl = '<span style="background:#e6ffec;">{}</span>'

    pieces = []
    opcodes = difflib.SequenceMatcher(None, original, corrected).get_opcodes()
    for op, a1, a2, b1, b2 in opcodes:
        if op == "equal":
            pieces.append(f"<span>{original[a1:a2]}</span>")
            continue
        # "replace" emits both halves; "delete"/"insert" emit one each.
        if op in ("replace", "delete"):
            pieces.append(removed_tpl.format(original[a1:a2]))
        if op in ("replace", "insert"):
            pieces.append(inserted_tpl.format(corrected[b1:b2]))
    return "".join(pieces)
|
|
|
|
|
def process_text_parallel(input_text, max_workers=10):
    """Split *input_text* into bulks and process them concurrently.

    Each bulk is handed to process_bulk on a worker thread.  Results are
    returned sorted by bulk index so callers can rejoin them in document
    order.  A bulk that raises even past process_bulk's own retries falls
    back to its untouched original text.
    """
    global processed_count, total_bulks

    bulks = create_bulk_paragraphs(input_text)
    total_bulks = len(bulks)
    processed_count = 0  # reset the shared progress counter for this run

    if not bulks:
        return []

    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {
            pool.submit(process_bulk, chunk, idx): idx
            for idx, chunk in enumerate(bulks)
        }

        for done in as_completed(pending):
            idx = pending[done]
            try:
                results.append(done.result())
            except Exception as exc:
                print(f"벌크 {idx+1} 처리 중 예외 발생: {exc}")
                results.append(
                    {
                        "bulk_index": idx,
                        "original": bulks[idx],
                        "final": bulks[idx],
                        "processing_time": 0,
                        "character_count": len(bulks[idx]),
                        "error": str(exc),
                    }
                )

    results.sort(key=lambda item: item["bulk_index"])
    return results
|
|
|
|
|
|
|
def demo_fn(input_text):
    """Gradio callback: return (corrected text, HTML diff) for *input_text*.

    Blank/whitespace-only input produces no bulks and is echoed back
    unchanged in both outputs.
    """
    bulk_results = process_text_parallel(input_text, max_workers=10)

    if not bulk_results:
        return input_text, input_text

    # Re-join the per-bulk outputs in document order.
    final_result = "\n".join(item["final"] for item in bulk_results)

    highlighted = highlight_diff(input_text, final_result)
    return final_result, highlighted
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 교열 모델 데모")
    # Source text input; the pipeline splits it into paragraph bulks.
    input_text = gr.Textbox(
        label="원문 입력", lines=10, placeholder="문단 단위로 입력해 주세요."
    )
    btn = gr.Button("교열하기")
    # Corrected text and an HTML view highlighting what changed.
    output_corrected = gr.Textbox(label="교열 결과", lines=10)
    output_highlight = gr.HTML(label="수정된 부분 강조")

    btn.click(
        fn=demo_fn, inputs=input_text, outputs=[output_corrected, output_highlight]
    )

if __name__ == "__main__":
    demo.launch()
|
|