Spaces:
Sleeping
Sleeping
| # File name: evaluate.py | |
| # Script that evaluates the trained model and generates a confusion matrix | |
| import torch | |
| import pandas as pd | |
| # The 'from pyexpat import model' line has been removed entirely. | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments | |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix | |
| import os | |
| import json | |
| import re | |
| import platform | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
# --- Matplotlib Korean font setup (same as before) ---
# Choose a font family per operating system so the Korean axis labels used by
# the confusion-matrix plot can be drawn.
try:
    if platform.system() == 'Windows':
        plt.rc('font', family='Malgun Gothic')
    elif platform.system() == 'Darwin':  # macOS
        plt.rc('font', family='AppleGothic')
    else:  # Linux
        plt.rc('font', family='NanumBarunGothic')
    # Draw the minus sign as an ASCII hyphen instead of U+2212; commonly done
    # because CJK fonts may lack the Unicode minus glyph.
    plt.rcParams['axes.unicode_minus'] = False
except Exception:
    # Font configuration is best-effort: report the problem and keep going.
    # Catching Exception (not a bare ``except``) lets KeyboardInterrupt and
    # SystemExit propagate normally.
    print("ํ๊ธ ํฐํธ ์ค์ ์ ์คํจํ์ต๋๋ค. ๊ทธ๋ํ์ ๋ผ๋ฒจ์ด ๊นจ์ง ์ ์์ต๋๋ค.")
| # --- Helper function and class definitions (same as before) --- | |
class EmotionDataset(torch.utils.data.Dataset):
    """Dataset that pairs tokenizer encodings with their class labels.

    ``encodings`` is a mapping from field name to an indexable tensor and
    ``labels`` is a sequence of class ids; the dataset length is the number
    of labels.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors, including a ``labels`` entry."""
        sample = {}
        for field, values in self.encodings.items():
            # Hand out a detached copy so callers never alias the stored tensors.
            sample[field] = values[idx].clone().detach()
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for a prediction batch.

    Args:
        pred: Object exposing ``label_ids`` (true class ids) and
            ``predictions`` (per-class scores; the arg-max over the last
            axis is taken as the predicted class).

    Returns:
        Dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    y_true = pred.label_ids
    y_hat = pred.predictions.argmax(-1)
    # zero_division=0 reports 0 for a class with no predictions instead of
    # emitting a warning.
    scores = precision_recall_fscore_support(
        y_true, y_hat, average='weighted', zero_division=0
    )
    weighted_precision, weighted_recall, weighted_f1 = scores[:3]
    return {
        'accuracy': accuracy_score(y_true, y_hat),
        'f1': weighted_f1,
        'precision': weighted_precision,
        'recall': weighted_recall,
    }
| # --- The full data-processing logic, identical to train_final.py, is added here --- | |
def map_ecode_to_major_emotion(ecode):
    """Map an emotion code to its major emotion category.

    The first character of ``ecode`` is discarded and the remainder is parsed
    as an integer; codes 10-69 are grouped by their tens digit into six
    categories.

    Returns:
        The category label, or ``None`` when the code cannot be parsed or
        falls outside 10-69.
    """
    try:
        number = int(ecode[1:])
    except (ValueError, TypeError):
        return None

    # This table must stay exactly identical to the one in train_final.py.
    # NOTE(review): the label literals look mis-encoded (mojibake) in this copy
    # of the file -- confirm them against train_final.py.
    emotion_by_tens = {
        1: '๋ถ๋ ธ',
        2: '์ฌํ',
        3: '๋ถ์',
        4: '์์ฒ',
        5: '๋นํฉ',
        6: '๊ธฐ์จ',
    }
    # Floor division sends 10-19 -> 1, ..., 60-69 -> 6; every other value
    # (including negatives) misses the table and yields None.
    return emotion_by_tens.get(number // 10)
def load_and_process_validation_data(file_path='./data/'):
    """Load ``test.json`` and turn it into a labelled, cleaned DataFrame.

    Each record contributes the space-joined values of ``talk.content`` as the
    text and ``profile.emotion.type`` as the raw emotion code. The code is
    mapped to a major emotion, rows without a mapping are dropped, and a
    cleaned copy of the text is added.

    Args:
        file_path: Directory that contains ``test.json`` (despite the name,
            this is a directory, not a file).

    Returns:
        A DataFrame with the columns ``text``, ``emotion``, ``major_emotion``
        and ``cleaned_text``; an empty DataFrame with those columns when the
        file holds no records; ``None`` when the file does not exist.

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record lacks the expected nested keys.
    """
    # Note: the file name is hard-coded; change it to the actual test file name.
    test_label_path = os.path.join(file_path, 'test.json')
    try:
        with open(test_label_path, 'r', encoding='utf-8') as f:
            test_data_raw = json.load(f)
    except FileNotFoundError:
        print(f"์ค๋ฅ: ํ ์คํธ์ฉ ๋ผ๋ฒจ ํ์ผ '{test_label_path}'๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค.")
        return None

    records = [
        {
            'text': " ".join(d['talk']['content'].values()),
            'emotion': d['profile']['emotion']['type'],
        }
        for d in test_data_raw
    ]
    if not records:
        # A DataFrame built from an empty list has no columns, so indexing
        # 'emotion' below would raise KeyError. Return an empty frame with the
        # expected schema so callers can simply test ``.empty``.
        return pd.DataFrame(
            columns=['text', 'emotion', 'major_emotion', 'cleaned_text']
        )

    df_test = pd.DataFrame(records)
    df_test['major_emotion'] = df_test['emotion'].apply(map_ecode_to_major_emotion)
    df_test.dropna(subset=['major_emotion'], inplace=True)

    # Strip every character outside the allowed set.
    # NOTE(review): the first range in this character class looks mis-encoded
    # (mojibake) in this copy of the file -- confirm the pattern against
    # train_final.py before relying on it.
    disallowed = re.compile(r'[^๊ฐ-ํฃa-zA-Z0-9 ]')
    df_test['cleaned_text'] = df_test['text'].apply(
        lambda text: disallowed.sub('', text)
    )
    return df_test
| # --- Main evaluation logic --- | |
def evaluate_saved_model(model_path="E:/Emotion/results/emotion_model_v2_manual"):
    """Evaluate a saved classifier on the test data and plot its confusion matrix.

    Loads the tokenizer and model from ``model_path``, evaluates them on the
    frame returned by ``load_and_process_validation_data()``, prints the
    metrics, and writes ``final_test_results.json`` and
    ``confusion_matrix.png`` into ``model_path``.

    Args:
        model_path: Directory holding the saved model and tokenizer; also the
            output directory for the result files. Defaults to the path this
            script originally hard-coded.

    Returns:
        None. The function returns early, after printing a message, when the
        model cannot be loaded or no evaluation rows remain.
    """
    print(f"'{model_path}' ๊ฒฝ๋ก์ ๋ชจ๋ธ์ ํ๊ฐํฉ๋๋ค.")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(model_path)
        # Presumably set so the model uses its single-label loss -- confirm
        # against the training script.
        loaded_model.config.problem_type = "single_label_classification"
    except OSError:
        print(f"์ค๋ฅ: '{model_path}' ๊ฒฝ๋ก์์ ๋ชจ๋ธ ๋๋ ํ ํฌ๋์ด์ ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค.")
        return

    df_val = load_and_process_validation_data()
    if df_val is None or df_val.empty:
        print("์ฒ๋ฆฌ ํ ํ๊ฐ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.")
        return

    label2id = loaded_model.config.label2id
    id2label = loaded_model.config.id2label
    df_val['label_id'] = df_val['major_emotion'].map(label2id)
    df_val.dropna(subset=['label_id'], inplace=True)
    if df_val.empty:
        # Every row carried a label the model does not know; tokenizing an
        # empty list would fail, so stop with the same "no data" message.
        print("์ฒ๋ฆฌ ํ ํ๊ฐ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.")
        return

    # map() yields floats whenever any label was unmapped (NaN); cast back to
    # int so the labels are genuine class ids.
    val_labels = df_val['label_id'].astype(int).tolist()
    val_texts = df_val['cleaned_text'].tolist()
    # Inputs are padded to a common length and truncated to at most 128 tokens.
    val_encodings = tokenizer(
        val_texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )
    val_dataset = EmotionDataset(val_encodings, val_labels)

    training_args = TrainingArguments(
        output_dir="./results1/temp_eval",
        report_to="none",
    )
    trainer = Trainer(
        model=loaded_model,
        args=training_args,
        compute_metrics=compute_metrics,
    )

    print("ํ๊ฐ๋ฅผ ์์ํฉ๋๋ค...")
    results = trainer.evaluate(eval_dataset=val_dataset)
    print("\n--- ์ต์ข ํ๊ฐ ๊ฒฐ๊ณผ ---")
    print(results)

    # Save the final evaluation results as a JSON file.
    results_to_save = {
        "accuracy": results.get("eval_accuracy"),
        "f1": results.get("eval_f1"),
        "loss": results.get("eval_loss"),  # loss value included as well
    }
    results_path = os.path.join(model_path, "final_test_results.json")
    with open(results_path, "w", encoding='utf-8') as f:
        json.dump(results_to_save, f, indent=4, ensure_ascii=False)
    print(f"์ต์ข ํ๊ฐ ๊ฒฐ๊ณผ๊ฐ {results_path}์ ์ ์ฅ๋์์ต๋๋ค.")

    print("\n--- ํผ๋ ํ๋ ฌ ์์ฑ ---")
    predictions = trainer.predict(val_dataset)
    y_pred = predictions.predictions.argmax(-1)
    y_true = predictions.label_ids

    # Pass the full list of class ids so the matrix always has one row/column
    # per model class, even when a class is absent from the test data;
    # otherwise the matrix shrinks and the tick labels no longer line up.
    # Assumes id2label is keyed by integer class ids -- TODO confirm.
    class_ids = sorted(id2label.keys())
    class_names = [id2label[i] for i in class_ids]
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names,
    )
    plt.xlabel('์์ธก ๋ผ๋ฒจ (Predicted Label)')
    plt.ylabel('์ค์ ๋ผ๋ฒจ (True Label)')
    plt.title('Confusion Matrix')
    cm_path = os.path.join(model_path, "confusion_matrix.png")
    plt.savefig(cm_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"ํผ๋ ํ๋ ฌ์ด {cm_path}์ ์ ์ฅ๋์์ต๋๋ค.")
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    evaluate_saved_model()