from flask import Flask, render_template, request, session, redirect, url_for
from flask_session import Session
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import nltk
from rouge_score import rouge_scorer
from sacrebleu.metrics import BLEU
from datetime import datetime
import os
import math
import logging
import gc
import time
|
print("Loading the AI model and evaluation metrics...")
try:
    nltk_data_path = '/tmp/nltk_data'
    nltk.download('punkt', download_dir=nltk_data_path, quiet=True)
    nltk.data.path.append(nltk_data_path)

    model_name = "EleutherAI/polyglot-ko-1.3b"

    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # fp16 on GPU halves memory use; CPU inference stays in float32.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    model.to(device)
    model.eval()
    if torch.cuda.is_available():
        model.half()  # no-op here: torch_dtype above already loaded fp16 weights

    # use_stemmer only affects English tokens and is harmless for Korean input.
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    # Smoothing is a constructor option in sacrebleu; sentence_score() itself
    # accepts no smooth_method argument. effective_order is recommended for
    # sentence-level BLEU.
    bleu = BLEU(smooth_method='exp', effective_order=True)

    print("AI model loaded and optimized.")
    model_loaded = True

    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

except Exception as e:
    print(f"Fatal error while loading the model: {e}")
    model_loaded = False
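
# Rough footprint estimate: polyglot-ko-1.3b has ~1.3B parameters, so the
# fp16 weights occupy roughly 2.6 GB of GPU memory (about 5.2 GB in float32
# on CPU).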
|
|
|
app = Flask(__name__)

app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
# Falling back to os.urandom(24) means session cookies are invalidated on
# every restart; set SECRET_KEY in the environment for stable sessions.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', os.urandom(24))
Session(app)

log_handler = logging.FileHandler('report_log.txt', encoding='utf-8')
log_handler.setLevel(logging.INFO)
log_formatter = logging.Formatter('%(asctime)s - %(message)s', '%Y-%m-%d %H:%M:%S')
log_handler.setFormatter(log_formatter)
app.logger.addHandler(log_handler)
app.logger.setLevel(logging.INFO)
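
# With the formatter above, entries in report_log.txt look like (illustrative):
#   2025-01-01 12:00:00 - Report generated - target: https://example.com, total score: 18.50/24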
|
|
|
|
|
def validate_ppl_text(text):
    """Pre-flight checks for perplexity input; requires the global tokenizer."""
    text_len = len(text)
    if text_len < 2000:
        return {"valid": False, "message": f"Text is too short: {text_len} characters. Please enter at least 2000 characters."}

    tokens = tokenizer.convert_ids_to_tokens(tokenizer(text, max_length=1024, truncation=True).input_ids)
    quadgrams = [tuple(tokens[i:i+4]) for i in range(len(tokens) - 3)]
    if len(quadgrams) > 0:
        repetition_ratio = 1.0 - (len(set(quadgrams)) / len(quadgrams))
        if repetition_ratio > 0.5:
            return {"valid": False, "message": "The text is too repetitive. Please enter more varied content."}

    word_count = len(text.split())
    return {"valid": True, "message": f"✅ Validation passed: {text_len} characters, {word_count} words"}
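
# Worked example for the 4-gram check: a 12-token text that is one 4-token
# phrase repeated three times yields 9 quadgrams, only 4 of them distinct,
# so repetition_ratio = 1 - 4/9 ≈ 0.56 > 0.5 and the text is rejected.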
|
|
|
|
|
def calculate_perplexity_logic(text, max_tokens=512, use_sliding_window=False):
    """Compute perplexity, inflated by an n-gram repetition penalty."""
    encodings = tokenizer(text, return_tensors="pt", max_length=max_tokens, truncation=True)
    input_ids = encodings.input_ids[0].to(device)

    if len(input_ids) < 10:
        raise ValueError("Too few tokens (minimum 10)")

    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Measure 2- to 5-gram repetition so that repetitive text cannot game the
    # score by lowering its perplexity through copied phrases.
    repetition_penalties = {}
    for n in range(2, 6):
        if len(tokens) >= n:
            ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
            if ngrams:
                unique_ratio = len(set(ngrams)) / len(ngrams)
                repetition_penalties[f'{n}gram'] = 1 - unique_ratio

    avg_repetition = sum(repetition_penalties.values()) / len(repetition_penalties) if repetition_penalties else 0
    penalty_factor = math.exp(avg_repetition * 3.0)

    seq_len = input_ids.size(0)

    with torch.no_grad():
        if not use_sliding_window or seq_len <= 256:
            outputs = model(input_ids.unsqueeze(0), labels=input_ids.unsqueeze(0))
            ppl = torch.exp(outputs.loss).item()
        else:
            # Simplified sliding window: each chunk is scored independently and
            # the mean loss exponentiated. Overlapping tokens are counted more
            # than once, which slightly biases the estimate but saves memory.
            max_length = 256
            stride = 128
            nlls = []
            for begin_loc in range(0, seq_len, stride):
                end_loc = min(begin_loc + max_length, seq_len)
                input_chunk = input_ids[begin_loc:end_loc].unsqueeze(0)
                try:
                    outputs = model(input_chunk, labels=input_chunk)
                    if outputs.loss is not None and torch.isfinite(outputs.loss):
                        nlls.append(outputs.loss)
                except Exception as chunk_error:
                    print(f"Chunk processing error: {chunk_error}")
                    continue
            if not nlls:
                raise RuntimeError("Could not compute a valid loss value")
            ppl = torch.exp(torch.mean(torch.stack(nlls))).item()

    adjusted_ppl = ppl * penalty_factor

    return {
        'base_ppl': ppl,
        'adjusted_ppl': adjusted_ppl,
        'penalty_factor': penalty_factor,
        'token_count': len(input_ids)
    }
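
# Worked example of the repetition penalty: an average repeated-n-gram ratio
# of 0.2 gives penalty_factor = exp(0.2 * 3.0) ≈ 1.82, so a base PPL of 15
# becomes an adjusted PPL of about 27.3.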
|
|
|
def get_ppl_calculation_mode(text_length):
    # Longer inputs get cheaper settings to keep latency bounded.
    if text_length > 2000:
        return "ultra_fast"
    elif text_length > 1000:
        return "fast"
    else:
        return "accurate"

def get_ppl_score(adjusted_ppl):
    # Map adjusted perplexity to the 1.0-3.0 rubric (lower PPL scores higher).
    if adjusted_ppl < 12: return 3.0
    elif adjusted_ppl < 18: return 2.5
    elif adjusted_ppl < 25: return 2.0
    elif adjusted_ppl < 35: return 1.5
    else: return 1.0

def cleanup_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
|
|
|
@app.route('/', methods=['GET'])
def index():
    all_results = session.get('all_results', {})
    input_texts = session.get('input_texts', {})
    return render_template('index.html', model_loaded=model_loaded, all_results=all_results, input_texts=input_texts)
|
|
|
|
|
@app.route('/evaluate', methods=['POST'])
def evaluate_text():
    if 'all_results' not in session: session['all_results'] = {}
    if 'input_texts' not in session: session['input_texts'] = {}

    target_url = request.form.get('target_url')
    if target_url: session['all_results']['target_url'] = target_url

    metric = request.form.get('metric')
    results_to_store = {'metric': metric}
|
|
|
    try:
        if metric == 'perplexity':
            text = request.form.get('ppl_text', '').strip()
            session['input_texts']['ppl_text'] = text

            # validate_ppl_text needs the tokenizer, so confirm the model
            # loaded before validating.
            if not model_loaded:
                results_to_store['error'] = "The model is not loaded."
            elif not (validation_result := validate_ppl_text(text))["valid"]:
                results_to_store['error'] = validation_result["message"]
            else:
                try:
                    cleanup_memory()

                    calc_mode = get_ppl_calculation_mode(len(text))
                    start_time = time.time()

                    if calc_mode == "ultra_fast":
                        ppl_result = calculate_perplexity_logic(text, max_tokens=256, use_sliding_window=False)
                    elif calc_mode == "fast":
                        ppl_result = calculate_perplexity_logic(text, max_tokens=384, use_sliding_window=False)
                    else:
                        ppl_result = calculate_perplexity_logic(text, max_tokens=512, use_sliding_window=True)

                    calc_time = time.time() - start_time
                    adjusted_ppl = ppl_result['adjusted_ppl']

                    results_to_store['score_value'] = adjusted_ppl
                    results_to_store['score_display'] = f"{adjusted_ppl:.4f}"
                    results_to_store['details'] = {
                        'base_ppl': f"{ppl_result['base_ppl']:.4f}",
                        'penalty_factor': f"{ppl_result['penalty_factor']:.4f}",
                        'token_count': ppl_result['token_count'],
                        'calc_time': f"{calc_time:.2f}s",
                        'calc_mode': calc_mode
                    }
                    results_to_store['final_score'] = get_ppl_score(adjusted_ppl)

                    cleanup_memory()

                except Exception as ppl_error:
                    results_to_store['error'] = f"Error during PPL calculation: {ppl_error}"

            session['all_results']['perplexity'] = results_to_store
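            # Example mode selection: a 2,500-character input maps to "ultra_fast"
            # above, i.e. max_tokens=256 with no sliding window, trading accuracy
            # for speed.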
|
|
|
        elif metric == 'rouge':
            gen_text = request.form.get('rouge_generated', '').strip()
            ref_text = request.form.get('rouge_reference', '').strip()
            session['input_texts']['rouge_generated'] = gen_text
            session['input_texts']['rouge_reference'] = ref_text

            if not gen_text or not ref_text:
                results_to_store['error'] = "Please enter both the generated summary and the reference summary."
            else:
                scores = scorer.score(ref_text, gen_text)
                r1, r2, rL = scores['rouge1'].fmeasure, scores['rouge2'].fmeasure, scores['rougeL'].fmeasure

                # Weighted average favors ROUGE-L slightly (0.3/0.3/0.4).
                weighted_avg = (r1 * 0.3 + r2 * 0.3 + rL * 0.4)

                # Penalize summaries much shorter or longer than the reference.
                len_gen = len(gen_text.split()); len_ref = len(ref_text.split())
                length_ratio = len_gen / len_ref if len_ref > 0 else 0
                if 0.8 <= length_ratio <= 1.2: length_penalty = 1.0
                elif length_ratio < 0.5 or length_ratio > 2.0: length_penalty = 0.8
                else: length_penalty = 0.9
                final_rouge_score = weighted_avg * length_penalty

                results_to_store['score_value'] = final_rouge_score
                results_to_store['score_display'] = f"{final_rouge_score:.4f}"
                results_to_store['details'] = {'weighted_avg': f"{weighted_avg:.4f}", 'length_penalty': f"{length_penalty:.2f}"}

                if final_rouge_score >= 0.65: results_to_store['final_score'] = 3.0
                elif final_rouge_score >= 0.55: results_to_store['final_score'] = 2.5
                elif final_rouge_score >= 0.45: results_to_store['final_score'] = 2.0
                elif final_rouge_score >= 0.35: results_to_store['final_score'] = 1.5
                else: results_to_store['final_score'] = 1.0

            session['all_results']['rouge'] = results_to_store
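            # Worked example: r1=0.6, r2=0.4, rL=0.5 gives a weighted average of
            # 0.6*0.3 + 0.4*0.3 + 0.5*0.4 = 0.50; a length ratio of 0.6 applies
            # a 0.9 penalty, so the final score 0.45 maps to 2.0 points.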
|
|
|
        elif metric == 'bleu':
            gen_text = request.form.get('bleu_generated', '').strip()
            ref_text = request.form.get('bleu_reference', '').strip()
            session['input_texts']['bleu_generated'] = gen_text
            session['input_texts']['bleu_reference'] = ref_text

            if not gen_text or not ref_text:
                results_to_store['error'] = "Please enter both the generated sentence and the reference sentence."
            else:
                # One reference translation per line.
                references = [line.strip() for line in ref_text.split('\n') if line.strip()]
                if not references:
                    results_to_store['error'] = "Please enter the reference (gold) translation."
                else:
                    # Smoothing was configured on the BLEU object at load time;
                    # sentence_score() takes only the hypothesis and references.
                    bleu_score = bleu.sentence_score(gen_text, references).score / 100
                    results_to_store['score_value'] = bleu_score
                    results_to_store['score_display'] = f"{bleu_score:.4f}"

                    if bleu_score >= 0.55: results_to_store['final_score'] = 3.0
                    elif bleu_score >= 0.45: results_to_store['final_score'] = 2.5
                    elif bleu_score >= 0.35: results_to_store['final_score'] = 2.0
                    elif bleu_score >= 0.25: results_to_store['final_score'] = 1.5
                    else: results_to_store['final_score'] = 1.0

            session['all_results']['bleu'] = results_to_store
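            # Worked example: sacrebleu reports BLEU on a 0-100 scale, so a
            # sentence score of 38.5 normalizes to 0.385 and earns 2.0 points
            # (the 0.35-0.45 band).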
|
|
|
        elif metric in ['mmlu', 'truthfulqa', 'drop', 'mbpp_humaneval']:
            generated_text = request.form.get(f'{metric}_generated', '')
            reference_text = request.form.get(f'{metric}_reference', '')
            grade = request.form.get(f'{metric}_grade', '')

            session['input_texts'][f'{metric}_generated'] = generated_text
            session['input_texts'][f'{metric}_reference'] = reference_text

            max_scores = {'mmlu': 4, 'truthfulqa': 4, 'drop': 4, 'mbpp_humaneval': 3}
            max_score = max_scores[metric]
            # Korean letter grades: 수 (excellent), 우 (good), 미 (fair),
            # 양 (poor), 가 (fail). These keys must match the form values.
            score_map = {'수': 1.0, '우': 0.9, '미': 0.8, '양': 0.7, '가': 0.6}

            if grade and grade in score_map:
                final_score = max_score * score_map[grade]
                results_to_store['grade'] = grade
                results_to_store['final_score'] = final_score
            else:
                results_to_store['grade'] = None
                results_to_store['final_score'] = 0
                if not grade:
                    results_to_store['error'] = "Please select an evaluation grade."

            session['all_results'][metric] = results_to_store
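            # Worked example: MMLU (max 4 points) graded '우' (0.9) scores
            # 4 * 0.9 = 3.6 points.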
|
|
|
    except Exception as e:
        results_to_store['error'] = f"Error during calculation: {e}"
        session['all_results'][metric] = results_to_store
        app.logger.error(f"Evaluation error - metric: {metric}, error: {e}")

    session.modified = True
    return redirect(url_for('index', _anchor=metric))
|
|
|
|
|
@app.route('/report')
def report():
    all_results = session.get('all_results', {})
    input_texts = session.get('input_texts', {})
    try:
        target_url = all_results.get('target_url', 'N/A')
        # 'target_url' is stored as a plain string, so only sum dict entries.
        total_score = sum(res.get('final_score', 0) for res in all_results.values() if isinstance(res, dict))
        log_message = f"Report generated - target: {target_url}, total score: {total_score:.2f}/24"
        app.logger.info(log_message)
    except Exception as e:
        app.logger.error(f"Error while writing the log: {e}")
    return render_template('report.html', all_results=all_results, input_texts=input_texts)
|
|
|
|
|
@app.route('/reset')
def reset():
    session.pop('all_results', None)
    session.pop('input_texts', None)
    cleanup_memory()
    return redirect(url_for('index'))
|
|
|
|
|
@app.route('/memory_status')
def memory_status():
    status = {}
    if torch.cuda.is_available():
        status['gpu_allocated'] = f"{torch.cuda.memory_allocated() / 1024**3:.2f} GB"
        status['gpu_reserved'] = f"{torch.cuda.memory_reserved() / 1024**3:.2f} GB"
    # psutil is imported lazily so the rest of the app works without it.
    import psutil
    process = psutil.Process()
    status['ram_usage'] = f"{process.memory_info().rss / 1024**3:.2f} GB"
    return status
|
|
|
|
|
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)