| |
| import json |
| import os |
| from pathlib import Path |
| import pandas as pd |
| from typing import List, Dict, Tuple |
| import numpy as np |
| from tqdm import tqdm |
| from sklearn.model_selection import train_test_split |
|
|
class KokoroChatProcessor:
    """Convert the KokoroChat counseling dataset into instruction-tuning examples.

    Expected per-conversation JSON schema (inferred from usage — confirm
    against the dataset): ``dialogue`` is a list of ``{'role', 'utterance'}``
    dicts with roles ``'client'`` / ``'counselor'``, ``topic`` is a dict with
    a ``main_jp`` key, and ``review_by_client_jp`` is a dict whose ``点数``
    key holds a 0-100 client satisfaction score.
    """

    def __init__(self, data_path: str):
        # Root directory searched recursively for *.json conversation files.
        self.data_path = Path(data_path)
        self.conversations: List[Dict] = []   # raw parsed JSON objects
        self.processed_data: List[Dict] = []  # flattened training examples

    def load_all_conversations(self) -> List[Dict]:
        """Load every JSON file under ``data_path`` (recursive glob).

        Best-effort: a file that fails to open or parse is reported and
        skipped so one corrupt file cannot abort the whole run.

        Returns:
            The accumulated list of parsed conversation dicts.
        """
        json_files = list(self.data_path.glob("**/*.json"))
        print(f"Found {len(json_files)} conversation files")

        for json_file in tqdm(json_files, desc="Loading conversations"):
            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    self.conversations.append(json.load(f))
            except Exception as e:
                print(f"Error loading {json_file}: {e}")

        return self.conversations

    def create_training_examples(self) -> List[Dict]:
        """Convert loaded conversations into (instruction, input, output) examples.

        Each example pairs a client utterance (model input) with the
        counselor reply that *follows* it (model output).

        BUG FIX vs. the previous version, which stepped two messages at a
        time assuming counselor-then-client order: (a) it paired a client
        message with the counselor turn that came *before* it, so the model
        would be trained to answer before the question; (b) it built the
        context from ``dialogue[:i+1]``, leaking the target counselor reply
        into its own context.  We now scan every adjacent
        (client, counselor) pair — no strict-alternation assumption — and
        build context only from messages preceding the client turn.

        Returns:
            The accumulated list of training-example dicts.
        """
        for conv_data in tqdm(self.conversations, desc="Processing conversations"):
            dialogue = conv_data.get('dialogue', [])
            topic = conv_data.get('topic', {})
            review = conv_data.get('review_by_client_jp', {})

            for i in range(len(dialogue) - 1):
                client_msg = dialogue[i]
                counselor_msg = dialogue[i + 1]

                # Emit an example only where a client turn is directly
                # answered by a counselor turn.
                if client_msg.get('role') == 'client' and counselor_msg.get('role') == 'counselor':
                    # History strictly before the client message, so the
                    # target reply never appears in its own context.
                    context = self._build_context(dialogue[:i])

                    self.processed_data.append({
                        'instruction': "あなたは共感的で専門的な心理カウンセラーです。クライアントの悩みに寄り添い、適切なサポートを提供してください。",
                        'input': f"クライアント: {client_msg['utterance']}",
                        'output': counselor_msg['utterance'],
                        'context': context,
                        'topic': topic.get('main_jp', ''),
                        'quality_score': self._calculate_quality_score(review),
                    })

        return self.processed_data

    def _build_context(self, dialogue_history: List[Dict], max_turns: int = 5) -> str:
        """Render the tail of a dialogue as newline-joined "role: utterance" lines.

        Keeps at most ``max_turns`` exchanges (two messages per turn) from
        the end of ``dialogue_history``, labeling roles in Japanese.
        """
        context_parts = []
        # One turn = client + counselor message, hence the factor of 2.
        start_idx = max(0, len(dialogue_history) - max_turns * 2)

        for msg in dialogue_history[start_idx:]:
            role = "カウンセラー" if msg['role'] == 'counselor' else "クライアント"
            context_parts.append(f"{role}: {msg['utterance']}")

        return "\n".join(context_parts)

    def _calculate_quality_score(self, review: Dict) -> float:
        """Normalize the client review score ``点数`` (0-100) to [0, 1].

        Returns a neutral 0.5 when no review or score is available.
        (The previous ``review.get('点数', 50)`` fallback was unreachable
        after the ``is None`` guard and has been removed.)
        """
        if not review or review.get('点数') is None:
            return 0.5
        return review['点数'] / 100.0

    def prepare_for_finetuning(self, test_size: float = 0.1, val_size: float = 0.1) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """Filter, split, and format examples for fine-tuning.

        Keeps only examples with quality_score > 0.6, splits them into
        train/val/test (fixed ``random_state=42`` for reproducibility),
        and renders each as a single Alpaca-style prompt under a ``text`` key.

        Note: ``val_size`` is a fraction of the *post-test-split* training
        set, so the effective validation share is (1 - test_size) * val_size.

        Returns:
            (train, val, test) lists of ``{'text': prompt}`` dicts.
        """
        # Quality gate: drop examples whose client review score is <= 0.6.
        high_quality = [ex for ex in self.processed_data if ex['quality_score'] > 0.6]
        print(f"Selected {len(high_quality)} high-quality examples")

        train_data, test_data = train_test_split(high_quality, test_size=test_size, random_state=42)
        train_data, val_data = train_test_split(train_data, test_size=val_size, random_state=42)

        def format_example(ex):
            # Alpaca-style Japanese prompt template; section order is
            # instruction -> context -> input -> response.
            prompt = f"""### 指示:
{ex['instruction']}

### コンテキスト:
{ex['context']}

### 入力:
{ex['input']}

### 応答:
{ex['output']}"""
            return {'text': prompt}

        train_formatted = [format_example(ex) for ex in train_data]
        val_formatted = [format_example(ex) for ex in val_data]
        test_formatted = [format_example(ex) for ex in test_data]

        return train_formatted, val_formatted, test_formatted
|
|
| |
def main() -> None:
    """Run the full preprocessing pipeline.

    Loads KokoroChat JSON conversations, builds training examples,
    splits/formats them for fine-tuning, and persists the splits to
    ``processed_data.pkl`` in the current working directory.
    """
    processor = KokoroChatProcessor('KokoroChat/data')
    processor.load_all_conversations()
    processor.create_training_examples()
    train_data, val_data, test_data = processor.prepare_for_finetuning()

    # Pickle preserves the exact Python structures.  NOTE: pickle files
    # must only be loaded from trusted sources (arbitrary-code-execution
    # risk on load).
    import pickle
    with open('processed_data.pkl', 'wb') as f:
        pickle.dump({
            'train': train_data,
            'val': val_data,
            'test': test_data,
        }, f)

    print(f"Training examples: {len(train_data)}")
    print(f"Validation examples: {len(val_data)}")
    print(f"Test examples: {len(test_data)}")


# Entry-point guard (previously missing): without it, importing this module
# would immediately crawl the dataset and write processed_data.pkl.
if __name__ == "__main__":
    main()
|
|