| """ |
| Fixed Optimized Japanese Counseling Model Benchmark with proper DataParallel handling |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn.parallel import DataParallel |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import numpy as np |
| from typing import List, Dict, Tuple, Optional, Any |
| import json |
| from tqdm import tqdm |
| import os |
| import gc |
| import warnings |
| from datetime import datetime |
| import pandas as pd |
| from collections import defaultdict |
| import MeCab |
| from rouge_score import rouge_scorer |
| from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction |
| import re |
| import wandb |
| from concurrent.futures import ThreadPoolExecutor |
| import time |
|
|
| |
import logging

# Quiet, reproducible console output for long benchmark runs:
#   * disable HuggingFace tokenizers' internal parallelism via its env switch,
#   * suppress every Python warning process-wide,
#   * only let pydantic log records at ERROR level or above through.
os.environ.update(TOKENIZERS_PARALLELISM='false')
warnings.filterwarnings('ignore')
logging.getLogger('pydantic').setLevel('ERROR')
|
|
class TestDataset(Dataset):
    """Map-style dataset over an in-memory list of example dicts.

    Each item is handed back exactly as stored; grouping items into batches
    is left to the DataLoader's collate function.
    """

    def __init__(self, data: List[Dict]):
        # The caller's list is referenced directly, not copied.
        self.data = data

    def __len__(self):
        records = self.data
        return len(records)

    def __getitem__(self, idx):
        records = self.data
        return records[idx]
|
|
def custom_collate_fn(batch):
    """Identity collate: pass the DataLoader's list of example dicts through.

    Returning the list unchanged keeps every example dict intact instead of
    letting the default collate merge them field by field.
    """
    collated = batch
    return collated
|
|
class OptimizedJapaneseBenchmark:
    """
    Benchmark comparing a base and a fine-tuned causal LM on Japanese
    counseling prompts.

    Generated responses are scored against reference responses with
    BLEU-1..4 and ROUGE-1/2/L (F-measure). Progress and results can
    optionally be logged to WandB.
    """

    # Prompt format used for generation. It mirrors the "### Input:" /
    # "### Response:" markers parsed from the test data (presumably the same
    # format the model was fine-tuned on — confirm against training code).
    _PROMPT_TEMPLATE = (
        "### Instruction:\n"
        "あなたは思いやりのある心理カウンセラーです。\n"
        "\n"
        "### Input:\n"
        "{prompt}\n"
        "\n"
        "### Response:\n"
    )

    class _WhitespaceTokenizer:
        """Tokenizer for ``rouge_score`` that only splits on whitespace.

        ``rouge_score``'s default tokenizer lowercases text and discards every
        character outside ``[a-z0-9]``, which erases Japanese text and forces
        every ROUGE score to 0. This benchmark pre-segments text (MeCab or per
        character) and joins the tokens with spaces before scoring, so a plain
        whitespace split recovers exactly those tokens.
        """

        def tokenize(self, text: str) -> List[str]:
            return text.split()

    def __init__(self,
                 base_model_name: str = "LiquidAI/LFM2-1.2B",
                 finetuned_model_path: str = "./merged_counselor_model",
                 test_data_path: str = "./processed_data_score80/test.jsonl",
                 batch_size: int = 16,
                 num_workers: int = 0,
                 use_wandb: bool = True):
        """
        Detect devices, optionally start WandB, and build the text tokenizer
        and metric scorers. Models are loaded separately by
        ``load_models_optimized``.

        Args:
            base_model_name: HF hub id or path of the base model.
            finetuned_model_path: Local directory of the fine-tuned model.
            test_data_path: JSONL file whose ``text`` field contains
                "### Input:" and "### Response:" sections.
            batch_size: Number of prompts generated per batch.
            num_workers: DataLoader worker processes used when iterating data.
            use_wandb: Attempt WandB logging (silently disabled on failure).
        """
        self.base_model_name = base_model_name
        self.finetuned_model_path = finetuned_model_path
        self.test_data_path = test_data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.wandb_enabled = False

        self.setup_devices()

        if use_wandb:
            self.init_wandb()

        self.setup_tokenizers_and_scorers()

        self.results = {}
        self.detailed_results = []

    def setup_devices(self):
        """Detect CUDA devices; use ``cuda:0`` as the primary device, else CPU."""
        if not torch.cuda.is_available():
            self.num_gpus = 0
            self.device_ids = []  # defined on every path for consistency
            self.device = torch.device("cpu")
            print("⚠️ No GPU found, using CPU")
            return

        self.num_gpus = torch.cuda.device_count()
        print(f"🚀 Found {self.num_gpus} GPUs")

        self.device_ids = list(range(self.num_gpus))
        self.device = torch.device("cuda:0")

        for i in self.device_ids:
            print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
            print(f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")

    def init_wandb(self):
        """Start a WandB run; on any failure, continue with WandB disabled."""
        try:
            run_name = f"benchmark-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

            wandb.init(
                project="japanese-counseling-benchmark",
                name=run_name,
                config={
                    "base_model": self.base_model_name,
                    "finetuned_model": self.finetuned_model_path,
                    "batch_size": self.batch_size,
                    "num_gpus": self.num_gpus,
                    "timestamp": datetime.now().isoformat()
                },
                tags=["benchmark", "japanese", "counseling", "multi-gpu"]
            )

            self.wandb_enabled = True
            print(f"✅ WandB initialized: {wandb.run.name}")
            print(f"📊 View at: {wandb.run.get_url()}")
        except Exception as e:
            print(f"⚠️ WandB initialization failed: {e}")
            self.wandb_enabled = False

    def setup_tokenizers_and_scorers(self):
        """Create the MeCab word segmenter (if available) and BLEU/ROUGE scorers."""
        try:
            self.mecab = MeCab.Tagger("-Owakati")
            print("✅ MeCab initialized")
        except Exception:
            print("⚠️ MeCab not available, using character tokenization")
            self.mecab = None

        rouge_types = ['rouge1', 'rouge2', 'rougeL']
        try:
            self.rouge_scorer = rouge_scorer.RougeScorer(
                rouge_types,
                use_stemmer=False,
                tokenizer=self._WhitespaceTokenizer(),
            )
        except TypeError:
            # Old rouge_score releases have no ``tokenizer`` argument; their
            # default tokenizer drops non-ASCII text, so Japanese ROUGE will be ~0.
            print("⚠️ rouge_score does not support custom tokenizers; ROUGE scores for Japanese will be unreliable")
            self.rouge_scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=False)

        self.smoothing = SmoothingFunction().method1

    @staticmethod
    def _parse_example(line: str) -> Optional[Dict]:
        """Parse one JSONL line into an example dict, or return None if unusable.

        A usable line is a JSON object whose ``text`` contains both
        "### Input:" and "### Response:".
        """
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict):
            return None

        text = data.get('text', '')
        if not isinstance(text, str) or "### Input:" not in text or "### Response:" not in text:
            return None

        input_part = text.split("### Input:")[1].split("### Response:")[0].strip()
        response_part = text.split("### Response:")[1].strip()
        return {
            'input': input_part,
            'reference': response_part,
            'score': data.get('score', 0),
            'topic': data.get('topic', 'Unknown')
        }

    def load_test_data_fast(self, max_samples: Optional[int] = None) -> List[Dict]:
        """Load examples from the JSONL test file.

        Args:
            max_samples: If given, stop once this many *valid* examples are loaded.

        Returns:
            List of dicts with ``input``, ``reference``, ``score`` and ``topic``.
            Synthetic data is returned when the file is missing, unreadable, or
            yields no valid examples.
        """
        print(f"\n📚 Loading test data from {self.test_data_path}")

        if not os.path.exists(self.test_data_path):
            print("⚠️ Test data not found, using synthetic data")
            return self.create_synthetic_test_data()

        try:
            with open(self.test_data_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error loading data: {e}")
            return self.create_synthetic_test_data()

        test_data = []
        for line in tqdm(lines, desc="Loading data"):
            if max_samples is not None and len(test_data) >= max_samples:
                break
            example = self._parse_example(line)
            if example is not None:
                test_data.append(example)

        if not test_data:
            print("⚠️ No valid data found, using synthetic data")
            return self.create_synthetic_test_data()

        print(f"✅ Loaded {len(test_data)} test examples")

        if self.wandb_enabled:
            wandb.log({"test_data_size": len(test_data)})

        return test_data

    def create_synthetic_test_data(self) -> List[Dict]:
        """Return 10 identical placeholder examples (separate dict objects)."""
        return [
            {
                'input': 'ストレスを感じています。',
                'reference': 'お気持ちわかります。どのような状況でストレスを感じていますか?',
                'score': 75,
                'topic': 'stress'
            }
            for _ in range(10)
        ]

    def load_models_optimized(self):
        """Load the tokenizer and both models onto ``self.device``.

        Fallbacks: GPT-2 when the base tokenizer/model cannot be loaded, and the
        base model itself when the fine-tuned checkpoint is missing or fails to
        load (the comparison is then base vs. base). With more than one GPU the
        models are wrapped in ``DataParallel``; generation unwraps them, so the
        work runs on ``self.device``.
        """
        print("\n🤖 Loading models with optimization...")

        # Half precision is slow / partially unsupported on CPU; use fp32 there.
        dtype = torch.float16 if self.device.type == 'cuda' else torch.float32

        print(" Loading tokenizer...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_name,
                use_fast=True
            )
        except Exception as e:
            print(f" Error loading tokenizer: {e}; falling back to GPT2 tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Batched decoder-only generation needs left padding so all prompts end
        # at the same position; left truncation keeps the trailing
        # "### Response:" marker when a prompt exceeds the length limit.
        self.tokenizer.padding_side = 'left'
        self.tokenizer.truncation_side = 'left'

        print(" Loading base model...")
        try:
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                torch_dtype=dtype,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
        except Exception as e:
            print(f" Error loading base model: {e}")
            print(" Using GPT2 as fallback...")
            base_model = AutoModelForCausalLM.from_pretrained(
                "gpt2",
                torch_dtype=dtype
            )

        print(" Loading fine-tuned model...")
        finetuned_model = base_model
        if not os.path.exists(self.finetuned_model_path):
            print(" Fine-tuned model not found, using base model")
        else:
            try:
                finetuned_model = AutoModelForCausalLM.from_pretrained(
                    self.finetuned_model_path,
                    torch_dtype=dtype,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    local_files_only=True
                )
            except Exception as e:
                print(f" Error loading fine-tuned model: {e}")
                print(" Using base model instead")

        base_model = base_model.to(self.device)
        finetuned_model = finetuned_model.to(self.device)

        if self.num_gpus > 1:
            print(f" Setting up DataParallel for {self.num_gpus} GPUs...")
            self.base_model = DataParallel(base_model, device_ids=self.device_ids)
            self.finetuned_model = DataParallel(finetuned_model, device_ids=self.device_ids)
        else:
            self.base_model = base_model
            self.finetuned_model = finetuned_model

        self.base_model.eval()
        self.finetuned_model.eval()

        print("✅ Models loaded and optimized!")

        if self.wandb_enabled:
            wandb.log({
                "model_loaded": True,
                "num_gpus_used": self.num_gpus
            })

    @staticmethod
    def _clean_generated_text(text: str) -> str:
        """Trim decoded continuation text down to the response itself.

        Anything from the first "###" onward (the model starting another
        template section) is dropped; an empty result becomes a fixed
        placeholder string.
        """
        response = text.split("###", 1)[0].strip()
        return response if response else "応答を生成できませんでした。"

    def generate_batch_responses(self, model, prompts: List[str], max_length: int = 150) -> List[str]:
        """Generate one response per prompt with a single batched ``generate`` call.

        Args:
            model: Causal LM, optionally wrapped in ``DataParallel``.
            prompts: Client utterances inserted into the instruction template.
            max_length: Maximum number of *new* tokens to generate.

        Returns:
            One string per prompt, decoded from the newly generated tokens only.
            If generation fails, every entry is a fixed apology string.
        """
        if not prompts:
            return []

        formatted_prompts = [self._PROMPT_TEMPLATE.format(prompt=prompt) for prompt in prompts]

        try:
            # Relies on the tokenizer being configured for left padding/truncation
            # (done in load_models_optimized).
            inputs = self.tokenizer(
                formatted_prompts,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # generate() is not a forward() call, so DataParallel cannot split it;
            # call the wrapped module directly.
            actual_model = model.module if isinstance(model, DataParallel) else model
            use_amp = self.device.type == 'cuda'

            with torch.no_grad(), torch.autocast(device_type=self.device.type, enabled=use_amp):
                outputs = actual_model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    num_beams=1,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # With left padding every prompt occupies the first prompt_len
            # positions, so the remaining columns are the generated tokens.
            prompt_len = inputs['input_ids'].shape[1]
            responses = self.tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
            return [self._clean_generated_text(response) for response in responses]

        except Exception as e:
            print(f"Error in batch generation: {e}")
            return ["申し訳ございません。応答を生成できませんでした。"] * len(prompts)

    def tokenize_japanese(self, text: str) -> List[str]:
        """Split Japanese text into tokens.

        Uses MeCab word segmentation when available, otherwise one token per
        character (ASCII spaces removed). Empty input yields ``['empty']``.
        """
        if not text:
            return ['empty']

        if self.mecab:
            try:
                tokens = self.mecab.parse(text).strip().split()
                return tokens if tokens else list(text)
            except Exception:
                pass  # fall through to character tokenization

        return list(text.replace(' ', ''))

    @staticmethod
    def _append_zero_scores(metrics: Dict) -> None:
        """Append a 0.0 entry for every metric (used for unscorable pairs)."""
        for n in range(1, 5):
            metrics[f'BLEU-{n}'].append(0.0)
        for name in ('ROUGE-1', 'ROUGE-2', 'ROUGE-L'):
            metrics[name].append(0.0)

    def calculate_metrics_batch(self, references: List[str], hypotheses: List[str]) -> Dict:
        """Score each (reference, hypothesis) pair.

        Returns:
            Dict mapping ``BLEU-1``..``BLEU-4``, ``ROUGE-1``, ``ROUGE-2`` and
            ``ROUGE-L`` to lists with exactly one score per scored pair. Empty
            strings and scoring failures contribute 0.0.
        """
        metrics = defaultdict(list)

        for ref, hyp in zip(references, hypotheses):
            if not ref or not hyp:
                self._append_zero_scores(metrics)
                continue

            try:
                ref_tokens = self.tokenize_japanese(ref)
                hyp_tokens = self.tokenize_japanese(hyp)
            except Exception:
                self._append_zero_scores(metrics)
                continue

            for n in range(1, 5):
                # Cumulative BLEU-n: uniform weight over 1..n-grams.
                weights = tuple([1 / n] * n + [0] * (4 - n))
                try:
                    score = sentence_bleu(
                        [ref_tokens],
                        hyp_tokens,
                        weights=weights,
                        smoothing_function=self.smoothing
                    )
                except Exception:
                    score = 0.0
                metrics[f'BLEU-{n}'].append(score)

            try:
                # Space-joined tokens; the scorer splits them back on whitespace.
                rouge_scores = self.rouge_scorer.score(' '.join(ref_tokens), ' '.join(hyp_tokens))
                rouge_values = (
                    rouge_scores['rouge1'].fmeasure,
                    rouge_scores['rouge2'].fmeasure,
                    rouge_scores['rougeL'].fmeasure,
                )
            except Exception:
                rouge_values = (0.0, 0.0, 0.0)
            for name, value in zip(('ROUGE-1', 'ROUGE-2', 'ROUGE-L'), rouge_values):
                metrics[name].append(value)

        return dict(metrics)

    def _log_progress(self, batch_idx: int, num_batches: int, num_samples: int,
                      base_metrics: Dict, finetuned_metrics: Dict) -> None:
        """Log running progress and running BLEU-4 / ROUGE-L means to WandB."""
        wandb.log({
            "progress": (batch_idx + 1) / num_batches * 100,
            "batches_processed": batch_idx + 1,
            "samples_processed": min((batch_idx + 1) * self.batch_size, num_samples),
            "current_bleu4_base": np.mean(base_metrics.get('BLEU-4', [0])),
            "current_bleu4_finetuned": np.mean(finetuned_metrics.get('BLEU-4', [0])),
            "current_rouge_l_base": np.mean(base_metrics.get('ROUGE-L', [0])),
            "current_rouge_l_finetuned": np.mean(finetuned_metrics.get('ROUGE-L', [0]))
        })

    def _record_examples(self, inputs, references, base_responses, finetuned_responses, limit: int = 3) -> None:
        """Keep up to ``limit`` qualitative examples and print the first one."""
        for i in range(min(limit, len(inputs))):
            self.detailed_results.append({
                'input': inputs[i],
                'reference': references[i],
                'base_response': base_responses[i] if i < len(base_responses) else "",
                'finetuned_response': finetuned_responses[i] if i < len(finetuned_responses) else ""
            })

        print(f"\n📝 Sample Example:")
        print(f"Input: {inputs[0][:100]}...")
        print(f"Reference: {references[0][:100]}...")
        print(f"Base response: {base_responses[0][:100]}...")
        print(f"Fine-tuned response: {finetuned_responses[0][:100]}...")

    def run_fast_benchmark(self, num_samples: Optional[int] = None):
        """Generate and score responses from both models over the test set.

        Args:
            num_samples: Optional cap on the number of test examples.

        Returns:
            The aggregate results dict (also stored in ``self.results``).

        Raises:
            ValueError: If no test data is available.
        """
        print("\n" + "=" * 80)
        print("🚀 Running Fast Multi-GPU Benchmark")
        print("=" * 80)

        start_time = time.time()

        test_data = self.load_test_data_fast(max_samples=num_samples)
        if not test_data:
            raise ValueError("No test data available!")

        dataloader = DataLoader(
            TestDataset(test_data),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=custom_collate_fn,
            # Batches are lists of plain dicts of strings: nothing to pin.
            pin_memory=False
        )
        num_batches = len(dataloader)

        all_base_metrics = defaultdict(list)
        all_finetuned_metrics = defaultdict(list)
        # Fresh examples per run so repeated runs do not accumulate.
        self.detailed_results = []

        print(f"\n📊 Evaluating {len(test_data)} examples in {num_batches} batches...")
        print(f" Batch size: {self.batch_size}")
        print(f" Using {self.num_gpus} GPU(s)")

        successful_batches = 0
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
            try:
                inputs = [item['input'] for item in batch]
                references = [item['reference'] for item in batch]

                base_responses = self.generate_batch_responses(self.base_model, inputs)
                finetuned_responses = self.generate_batch_responses(self.finetuned_model, inputs)

                # Score both sides before merging, so a failure cannot leave the
                # base and fine-tuned lists with different lengths.
                base_metrics = self.calculate_metrics_batch(references, base_responses)
                finetuned_metrics = self.calculate_metrics_batch(references, finetuned_responses)

                for key, values in base_metrics.items():
                    all_base_metrics[key].extend(values)
                for key, values in finetuned_metrics.items():
                    all_finetuned_metrics[key].extend(values)

                successful_batches += 1

                if self.wandb_enabled and batch_idx % 5 == 0:
                    self._log_progress(batch_idx, num_batches, len(test_data),
                                       all_base_metrics, all_finetuned_metrics)

                # Capture examples from the first batch that succeeds.
                if not self.detailed_results and inputs:
                    self._record_examples(inputs, references, base_responses, finetuned_responses)

            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
                continue

        print(f"\n✅ Successfully processed {successful_batches}/{num_batches} batches")

        self.results = self.calculate_final_statistics(all_base_metrics, all_finetuned_metrics)

        total_time = time.time() - start_time
        samples_per_second = len(test_data) / total_time if total_time > 0 else 0

        print(f"\n⏱️ Benchmark completed in {total_time:.2f} seconds")
        print(f" Processing speed: {samples_per_second:.2f} samples/second")

        if self.wandb_enabled:
            wandb.log({
                "total_time_seconds": total_time,
                "samples_per_second": samples_per_second,
                "total_samples": len(test_data),
                "successful_batches": successful_batches,
                **{f"final_{k}": v for k, v in self.results['summary'].items()},
                **{f"improvement_{k}": v for k, v in self.results['improvements'].items()}
            })

            if self.results['metrics']:
                self.create_wandb_visualizations()

        self.print_results()

        return self.results

    def create_wandb_visualizations(self):
        """Log a comparison table and an improvement bar chart to WandB."""
        if not self.wandb_enabled or not self.results.get('metrics'):
            return

        try:
            metrics = self.results['metrics']
            improvements = self.results['improvements']

            data = [
                [name, stats['base']['mean'], stats['finetuned']['mean'], improvements[name]]
                for name, stats in metrics.items()
            ]
            table = wandb.Table(
                columns=["Metric", "Base", "Fine-tuned", "Improvement (%)"],
                data=data
            )
            wandb.log({"results_comparison": table})

            wandb.log({
                "improvements_chart": wandb.plot.bar(
                    wandb.Table(
                        data=[[name, value] for name, value in improvements.items()],
                        columns=["Metric", "Improvement (%)"]
                    ),
                    "Metric", "Improvement (%)",
                    title="Model Improvements"
                )
            })
        except Exception as e:
            print(f"Error creating visualizations: {e}")

    @staticmethod
    def _describe(values: List[float]) -> Dict[str, float]:
        """Return mean/std/min/max of a non-empty list as plain floats."""
        arr = np.asarray(values, dtype=float)
        return {
            'mean': float(arr.mean()),
            'std': float(arr.std()),
            'min': float(arr.min()),
            'max': float(arr.max())
        }

    def calculate_final_statistics(self, base_metrics: Dict, finetuned_metrics: Dict) -> Dict:
        """Aggregate per-sample scores into summary statistics.

        Returns:
            Dict with:
              * ``metrics``: per metric (sorted by name), mean/std/min/max for
                ``base`` and ``finetuned``;
              * ``improvements``: relative change of the mean in percent
                (100 when the base mean is 0 and the fine-tuned mean is not);
              * ``summary``: average improvement over BLEU, ROUGE, and all metrics.
            A metric with no (non-None) values is treated as ``[0]``.
        """
        results = {
            'metrics': {},
            'improvements': {},
            'summary': {}
        }

        for metric in sorted(set(base_metrics) | set(finetuned_metrics)):
            base_values = [v for v in base_metrics.get(metric, []) if v is not None] or [0]
            finetuned_values = [v for v in finetuned_metrics.get(metric, []) if v is not None] or [0]

            base_stats = self._describe(base_values)
            finetuned_stats = self._describe(finetuned_values)
            results['metrics'][metric] = {'base': base_stats, 'finetuned': finetuned_stats}

            base_mean = base_stats['mean']
            finetuned_mean = finetuned_stats['mean']
            if base_mean > 0:
                improvement = (finetuned_mean - base_mean) / base_mean * 100
            else:
                improvement = 0.0 if finetuned_mean == 0 else 100.0
            results['improvements'][metric] = float(improvement)

        improvements = results['improvements']
        bleu = [v for m, v in improvements.items() if 'BLEU' in m]
        rouge = [v for m, v in improvements.items() if 'ROUGE' in m]

        results['summary'] = {
            'bleu_avg_improvement': float(np.mean(bleu)) if bleu else 0,
            'rouge_avg_improvement': float(np.mean(rouge)) if rouge else 0,
            'overall_improvement': float(np.mean(list(improvements.values()))) if improvements else 0
        }

        return results

    def print_results(self):
        """Print base vs. fine-tuned means per metric and the improvement summary."""
        print("\n" + "=" * 80)
        print("📊 BENCHMARK RESULTS")
        print("=" * 80)

        if not self.results or 'metrics' not in self.results:
            print("No results to display")
            return

        metrics = self.results['metrics']
        header = f"{'Metric':<15} {'Base':<15} {'Fine-tuned':<15} {'Improvement':<15}"

        for title, family in (("\n📘 BLEU Scores:", 'BLEU'), ("\n📕 ROUGE Scores:", 'ROUGE')):
            print(title)
            print("-" * 60)
            print(header)
            print("-" * 60)
            for metric in sorted(m for m in metrics if family in m):
                base = metrics[metric]['base']['mean']
                finetuned = metrics[metric]['finetuned']['mean']
                improvement = self.results['improvements'][metric]
                # Column widths match the header so values line up.
                print(f"{metric:<15} {base:<15.4f} {finetuned:<15.4f} {improvement:+.1f}%")

        summary = self.results['summary']
        print("\n" + "=" * 80)
        print("📈 SUMMARY")
        print("=" * 80)
        print(f"BLEU Average Improvement: {summary['bleu_avg_improvement']:+.1f}%")
        print(f"ROUGE Average Improvement: {summary['rouge_avg_improvement']:+.1f}%")
        print(f"Overall Improvement: {summary['overall_improvement']:+.1f}%")
        print("=" * 80)

    def save_results(self, output_dir: str = "./benchmark_results"):
        """Write ``results.json`` and ``examples.json``; upload them as a WandB artifact if enabled."""
        os.makedirs(output_dir, exist_ok=True)

        with open(os.path.join(output_dir, "results.json"), 'w', encoding='utf-8') as f:
            json.dump(self.results, f, ensure_ascii=False, indent=2, default=str)

        with open(os.path.join(output_dir, "examples.json"), 'w', encoding='utf-8') as f:
            json.dump(self.detailed_results, f, ensure_ascii=False, indent=2)

        if self.wandb_enabled:
            try:
                artifact = wandb.Artifact(
                    name=f"benchmark-results-{wandb.run.id}",
                    type="benchmark_results",
                    description="Japanese counseling model benchmark results"
                )
                artifact.add_dir(output_dir)
                wandb.log_artifact(artifact)
            except Exception as e:
                print(f"Error saving to WandB: {e}")

        print(f"✅ Results saved to {output_dir}/")

    def cleanup(self):
        """Finish the WandB run and release model memory.

        Safe to call more than once. The models are dropped, so the instance
        cannot run further benchmarks afterwards.
        """
        if self.wandb_enabled:
            wandb.finish()
            self.wandb_enabled = False

        # Drop model references first; empty_cache() can only return memory
        # that is no longer referenced by live tensors.
        for attr in ('base_model', 'finetuned_model'):
            if hasattr(self, attr):
                delattr(self, attr)

        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
def main():
    """Command-line entry point: load models, run the benchmark, save results.

    Cleanup (WandB run, model memory) always runs, including on
    KeyboardInterrupt. Exits with status 1 if any stage raises, so shell
    scripts and CI can detect a failed benchmark.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description='Optimized Japanese Counseling Benchmark')
    parser.add_argument('--base_model', type=str, default='LiquidAI/LFM2-1.2B')
    parser.add_argument('--finetuned_model', type=str, default='./merged_counselor_model')
    parser.add_argument('--test_data', type=str, default='./processed_data_score80/test.jsonl')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for processing')
    parser.add_argument('--num_samples', type=int, default=None, help='Number of samples to evaluate')
    parser.add_argument('--num_workers', type=int, default=0, help='DataLoader worker processes')
    parser.add_argument('--output_dir', type=str, default='./benchmark_results_fast')
    parser.add_argument('--no_wandb', action='store_true', help='Disable WandB logging')

    args = parser.parse_args()

    benchmark = None
    failed = False
    try:
        print("🚀 Initializing Optimized Benchmark Suite")
        benchmark = OptimizedJapaneseBenchmark(
            base_model_name=args.base_model,
            finetuned_model_path=args.finetuned_model,
            test_data_path=args.test_data,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            use_wandb=not args.no_wandb
        )

        benchmark.load_models_optimized()
        benchmark.run_fast_benchmark(num_samples=args.num_samples)
        benchmark.save_results(args.output_dir)

    except Exception as e:
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        failed = True

    finally:
        # Runs on success, on handled errors, and on interrupts alike.
        if benchmark is not None:
            benchmark.cleanup()

    if failed:
        raise SystemExit(1)

    print("\n✅ Benchmark completed successfully!")
|
|
|
|
# Script entry point; importing this module does not start a benchmark.
# main() returns None on success, so the process exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())
|
|