| |
"""
Extract batch5 from the last 20k samples of the allenai/Dolci-Instruct-SFT-Tool-Use dataset.
Select samples with tool calling (both function_call and observation turns).
"""
|
|
import json
import os

from datasets import load_dataset
from tqdm import tqdm
|
|
def has_tool_calling(conversations):
    """Return True when a conversation has both a tool call and a tool result.

    Args:
        conversations: Iterable of turn dicts. Each turn's role is read from
            the ``'from'`` key, falling back to ``'role'`` when ``'from'`` is
            missing or falsy. ``None`` is treated as a conversation with no
            turns.

    Returns:
        bool: True only if at least one turn has the role ``'function_call'``
        and at least one turn has the role ``'observation'``.
    """
    # A sample may carry an explicit null instead of a list of turns.
    if conversations is None:
        return False

    # A set gives O(1) membership for the two lookups below.
    roles = {turn.get('from') or turn.get('role') for turn in conversations}
    return 'function_call' in roles and 'observation' in roles
|
|
def convert_to_llamafactory_format(sample):
    """Convert a dataset sample into a dict with ``'from'``/``'value'`` turns.

    Args:
        sample: Mapping that may contain ``'conversations'`` (a list of turn
            dicts), ``'system'`` and ``'tools'``. Each turn's role is read
            from ``'from'`` (falling back to ``'role'`` when missing or
            falsy). Each turn's text is read from ``'value'``, falling back
            to ``'content'`` only when ``'value'`` is missing or ``None``.

    Returns:
        dict: ``{'conversations': [{'from': ..., 'value': ...}, ...]}`` plus
        the ``'system'`` and ``'tools'`` entries of ``sample`` when they are
        present and truthy.
    """
    converted_conversations = []
    # ``or []`` covers both a missing key and an explicit null.
    for turn in sample.get('conversations') or []:
        role = turn.get('from') or turn.get('role')

        # Compare against None rather than relying on truthiness: an empty
        # string is a legitimate turn body and must not be turned into None
        # (which would be serialized as JSON null).
        value = turn.get('value')
        if value is None:
            value = turn.get('content')

        converted_conversations.append({'from': role, 'value': value})

    result = {'conversations': converted_conversations}

    # Optional top-level fields are copied only when they carry content.
    for key in ('system', 'tools'):
        extra = sample.get(key)
        if extra:
            result[key] = extra

    return result
|
|
def main():
    """Build the batch5 JSON file from the tail of the Dolci tool-use dataset.

    Steps:
        1. Load the ``train`` split of ``allenai/Dolci-Instruct-SFT-Tool-Use``.
        2. Take the last 20,000 rows (or all rows if there are fewer).
        3. Keep rows for which ``has_tool_calling`` is true and convert them
           with ``convert_to_llamafactory_format``.
        4. Keep at most the first 10,000 converted rows.
        5. Write them as indented UTF-8 JSON to
           ``data/dolci_10k_with_tool_call_batch5.json`` and print a preview
           of the first entry.
    """
    window_size = 20000   # number of trailing rows to scan
    max_samples = 10000   # cap on the number of rows written out

    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Clamp at 0 so a dataset smaller than the window is used in full.
    start_idx = max(0, total_samples - window_size)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"Processing last 20k samples (from index {start_idx} to {total_samples})")

    tool_calling_samples = []
    for sample in tqdm(last_20k, desc="Filtering samples with tool calling"):
        # ``or []`` guards against a null conversations field.
        conversations = sample.get('conversations') or []
        if has_tool_calling(conversations):
            tool_calling_samples.append(convert_to_llamafactory_format(sample))

    print(f"\nFound {len(tool_calling_samples)} samples with tool calling")

    if len(tool_calling_samples) > max_samples:
        selected_samples = tool_calling_samples[:max_samples]
        print("Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    # Create the target directory up front; otherwise open() raises
    # FileNotFoundError after all the loading and filtering work is done.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)

    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Show a truncated preview (first 500 characters) of the first entry.
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:500] + "...")


# Run the extraction only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|