| | |
| | """ |
| | Fix batch5 by correctly converting environment role to observation. |
| | """ |
| |
|
| | import json |
| | from datasets import load_dataset |
| | from tqdm import tqdm |
| |
|
def convert_to_llamafactory_format(sample):
    """
    Convert a sample from Dolci format to LlamaFactory format.

    Dolci format (``messages``):
        - role: system/user/assistant/environment
        - content: text content
        - function_calls: function call string (in assistant messages)
        - functions: available functions JSON string (in system message)

    LlamaFactory format (``conversations``):
        - from: human/gpt/function_call/observation
        - value: text or JSON

    Returns:
        dict with a 'conversations' list, plus 'system' and 'tools' keys
        when the source sample provides them.
    """
    conversations = []
    tools = None
    system_prompt = None

    # Iterate directly; the previous enumerate() index was never used.
    for msg in sample.get('messages', []):
        role = msg.get('role', '')
        content = msg.get('content', '')
        function_calls = msg.get('function_calls')
        functions = msg.get('functions')

        if role == 'system':
            # First non-empty 'functions' payload wins as tools; the last
            # non-empty system content wins as the system prompt.
            if functions and not tools:
                tools = functions
            if content:
                system_prompt = content
            continue

        if role == 'user':
            conversations.append({
                'from': 'human',
                'value': content
            })
        elif role == 'assistant':
            # An assistant turn becomes a function_call entry when it calls a
            # tool, otherwise a plain gpt reply.
            # NOTE(review): a turn carrying BOTH function_calls and content
            # drops the content — preserved from the original logic; confirm
            # this is intended.
            if function_calls:
                conversations.append({
                    'from': 'function_call',
                    'value': function_calls
                })
            elif content:
                conversations.append({
                    'from': 'gpt',
                    'value': content
                })
        elif role == 'environment':
            # Tool output maps to an observation turn.
            conversations.append({
                'from': 'observation',
                'value': content
            })

    result = {'conversations': conversations}

    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools

    return result
| |
|
def get_sample_hash(sample):
    """Hash the first user message so duplicate samples can be detected.

    Returns None when the sample contains no user turn. NOTE: built-in
    hash() of a str is salted per process, so the value is only comparable
    within a single run (which is how this script uses it).
    """
    user_turns = (m for m in sample.get('messages', []) if m.get('role') == 'user')
    first = next(user_turns, None)
    if first is None:
        return None
    return hash(first.get('content', ''))
| |
|
def has_tool_calling(messages):
    """Return True when any message carries a truthy 'function_calls' field."""
    return any(msg.get('function_calls') for msg in messages)
| |
|
def _load_existing_hashes(batch_numbers):
    """Hash the first human turn of every sample already saved in the given
    batch files, so new samples can be de-duplicated against them.

    Missing batch files are skipped with a warning rather than failing.
    """
    existing_hashes = set()
    for batch_num in batch_numbers:
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
        except FileNotFoundError:
            print(f" Warning: {batch_file} not found, skipping...")
            continue
        for sample in batch_data:
            # Only the first human turn identifies a sample (mirrors
            # get_sample_hash on the raw Dolci side).
            for conv in sample.get('conversations', []):
                if conv.get('from') == 'human':
                    existing_hashes.add(hash(conv.get('value', '')))
                    break
        print(f" Loaded batch{batch_num}: {len(batch_data)} samples")
    return existing_hashes


def _filter_new_tool_samples(samples, existing_hashes):
    """Convert samples to LlamaFactory format, keeping only unseen ones that
    contain both a function_call and an observation turn."""
    kept = []
    for sample in tqdm(samples, desc="Filtering tool calling samples"):
        if not has_tool_calling(sample.get('messages', [])):
            continue
        # get_sample_hash may return None (no user turn); None is never in
        # existing_hashes, so such samples pass through, as before.
        if get_sample_hash(sample) in existing_hashes:
            continue
        converted = convert_to_llamafactory_format(sample)
        roles = [c['from'] for c in converted.get('conversations', [])]
        if 'function_call' in roles and 'observation' in roles:
            kept.append(converted)
    return kept


def _report_role_patterns(selected_samples):
    """Print the most common role-sequence patterns of the first 100 samples
    and one full sample entry as a quick sanity check of the output format."""
    print("\n=== Verifying format ===")
    role_patterns = {}
    for sample in selected_samples[:100]:
        pattern = ' -> '.join(c['from'] for c in sample['conversations'])
        role_patterns[pattern] = role_patterns.get(pattern, 0) + 1

    print("Top patterns in first 100 samples:")
    for pattern, count in sorted(role_patterns.items(), key=lambda x: -x[1])[:5]:
        print(f" [{count:3d}] {pattern}")

    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:1000] + "...")


def main():
    """Build batch5: scan the last 20k Dolci samples for tool-calling
    conversations not present in batches 1-4, convert them to LlamaFactory
    format, and save up to 10,000 of them as JSON."""
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = _load_existing_hashes(range(1, 5))
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    # Only the dataset tail is scanned; earlier batches presumably came from
    # earlier slices — TODO confirm against the batch1-4 scripts.
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    tool_calling_samples = _filter_new_tool_samples(last_20k, existing_hashes)
    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples with proper format")

    # Cap the batch at 10k samples.
    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print("Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        print("\n❌ No new tool calling samples found!")
        return

    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)

    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    _report_role_patterns(selected_samples)
| |
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
| |
|