#!/usr/bin/env python3
"""
Preprocessing script for The Vault dataset to adapt it for GLEN training.

This script converts The Vault's code-text format to GLEN's expected format
for generative retrieval training. For each split file found in
``--input_dir`` (``train_small.json``, ``validate.json``, ``test.json``;
JSON-lines, one object per line) it writes into ``--output_dir``:

* ``DOC_VAULT_<split>.tsv`` -- ``oldid`` / ``doc_content`` rows
  (``train_small`` is written as ``DOC_VAULT_train.tsv``).
* ``GTQ_VAULT_train.tsv`` -- query/document pairs from ``train_small``.
* ``GTQ_VAULT_dev.tsv`` -- query/document pairs from ``validate``.
* ``GTQ_VAULT_test.tsv`` -- query/document pairs from ``test``
  (only when ``--create_test_set`` is given).
* ``ID_VAULT_t5_bm25_truncate_3.tsv`` -- one ID row per distinct document
  across all processed splits, so every document referenced by a GTQ file
  has an ID.
"""

import argparse
import hashlib
import json
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple

import pandas as pd

try:
    from tqdm import tqdm
except ImportError:  # progress bars are cosmetic; run without them if tqdm is absent

    def tqdm(iterable, **_kwargs):
        """Fallback used when tqdm is not installed: return *iterable* unchanged."""
        return iterable


# Name of the semantic-ID column in the ID file (also part of its filename).
DEFAULT_ID_CLASS = "t5_bm25_truncate_3"
# Samples whose document text is not longer than this many characters are dropped.
MIN_DOC_CONTENT_LENGTH = 50
# Input split names, in processing order.
SPLITS = ("train_small", "validate", "test")
# Column layouts of the output TSV files (explicit so empty splits still get headers).
DOC_COLUMNS = ["oldid", "doc_content"]
GTQ_COLUMNS = ["query", "oldid", "docid", "rank", "neg_docid_list", "aug_query_list"]


def clean_text(text: Optional[str]) -> str:
    """Normalise free text so it fits in a single TSV cell.

    Removes double backticks, turns double quotes into single quotes and
    collapses every run of whitespace (spaces, newlines, tabs, ...) into a
    single space. Falsy input (``None`` or ``""``) yields ``""``.
    """
    if not text:
        return ""
    text = text.replace("``", "").replace('"', "'")
    # str.split() with no argument splits on any whitespace, so this also
    # removes the newlines/tabs that would otherwise break the TSV layout.
    return " ".join(text.split())


def create_document_id(code: str, identifier: str, repo: str) -> str:
    """Return a 10-hex-character document ID for a code snippet.

    The ID is the first 10 hex digits of the MD5 of ``"repo:identifier:code"``.
    Identical (repo, identifier, code) triples therefore share an ID.
    MD5 is used only as a content fingerprint, not for security.
    """
    content = f"{repo}:{identifier}:{code}"
    return hashlib.md5(content.encode("utf-8")).hexdigest()[:10]


def process_vault_sample(sample: Dict[str, Any], include_comments: bool = True) -> Dict[str, Any]:
    """Convert one The Vault record into GLEN document/query fields.

    Args:
        sample: One decoded JSON record. Reads the keys ``identifier``,
            ``docstring``, ``short_docstring``, ``code``, ``language``,
            ``repo``, ``path`` and ``comment``; missing keys default to empty.
        include_comments: Append the record's code comments to the description.

    Returns:
        A dict with ``oldid`` (content hash), ``doc_content``
        (``"<description> [CODE] <code>"``, or a generic sentence when there
        is no description), ``query`` (the description, or a generic
        "Find function ..." sentence) plus ``identifier``, ``language``,
        ``repo`` and ``path`` copied from the sample.
    """
    identifier = sample.get('identifier', '')
    docstring = sample.get('docstring', '')
    short_docstring = sample.get('short_docstring', '')
    code = sample.get('code', '')
    language = sample.get('language', '')
    repo = sample.get('repo', '')
    path = sample.get('path', '')
    comments = sample.get('comment', [])
    if isinstance(comments, str):
        # A bare string is one comment; iterating it would yield single characters.
        comments = [comments]

    # Prefer the full docstring and fall back to the short one.
    description = docstring if docstring else short_docstring

    if include_comments and comments:
        comment_text = " ".join(filter(None, (clean_text(c) for c in comments if c)))
        if comment_text:
            description = f"{description} {comment_text}" if description else comment_text

    description = clean_text(description)
    code = clean_text(code)

    if description:
        doc_content = f"{description} [CODE] {code}"
        query = description
    else:
        doc_content = f"Function {identifier} in {language}. [CODE] {code}"
        query = f"Find function {identifier} in {language}"

    return {
        'oldid': create_document_id(code, identifier, repo),
        'doc_content': doc_content,
        'query': query,
        'identifier': identifier,
        'language': language,
        'repo': repo,
        'path': path,
    }


def create_query_document_pairs(processed_samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build one GLEN query/document pair per processed sample.

    ``docid`` is initially the same as ``oldid``. ``neg_docid_list`` and
    ``aug_query_list`` are left empty; this script never fills them.
    """
    return [
        {
            'query': sample['query'],
            'oldid': sample['oldid'],
            'docid': sample['oldid'],
            'rank': 1,
            'neg_docid_list': [],
            'aug_query_list': [],
        }
        for sample in processed_samples
    ]


def generate_document_ids(processed_samples: List[Dict[str, Any]],
                          id_class: str = DEFAULT_ID_CLASS,
                          id_length: int = 5) -> pd.DataFrame:
    """Assign each document a GLEN-style token ID derived from its ``oldid``.

    The ID is the first ``id_length`` characters of ``oldid`` joined with
    ``-`` (e.g. ``"abcdef1234"`` -> ``"a-b-c-d-e"``). This is a placeholder
    scheme, not BM25-based. Short prefixes can collide across documents;
    raise ``id_length`` to reduce collisions.

    Returns:
        A DataFrame with columns ``["oldid", id_class]`` (headers are present
        even when ``processed_samples`` is empty).

    Raises:
        ValueError: If ``id_length`` is less than 1.
    """
    if id_length < 1:
        raise ValueError(f"id_length must be >= 1, got {id_length}")
    rows = [
        {'oldid': sample['oldid'], id_class: "-".join(sample['oldid'][:id_length])}
        for sample in processed_samples
    ]
    return pd.DataFrame(rows, columns=['oldid', id_class])


def _load_split(input_file: str, split: str, max_samples: Optional[int],
                include_comments: bool) -> List[Dict[str, Any]]:
    """Read one JSON-lines split file and return its processed, length-filtered samples."""
    processed_samples = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for i, line in enumerate(tqdm(f, desc=f"Processing {split}")):
            if max_samples is not None and i >= max_samples:
                break
            line = line.strip()
            if not line:
                continue  # blank lines (e.g. a trailing newline) are not records
            try:
                sample = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Error parsing line {i}: {e}")
                continue
            if not isinstance(sample, dict):
                print(f"Skipping line {i}: expected a JSON object, got {type(sample).__name__}")
                continue
            processed_sample = process_vault_sample(sample, include_comments)
            # Filter out samples with very short code or descriptions.
            if len(processed_sample['doc_content']) > MIN_DOC_CONTENT_LENGTH:
                processed_samples.append(processed_sample)
    return processed_samples


def _gtq_target(split: str, create_test_set: bool) -> Optional[Tuple[str, str]]:
    """Return (GTQ filename, label for log messages) for *split*, or None to skip it."""
    if split == 'train_small':
        return "GTQ_VAULT_train.tsv", "training"
    if split == 'validate':
        # GLEN evaluates on the "dev" file.
        return "GTQ_VAULT_dev.tsv", "evaluation"
    if split == 'test' and create_test_set:
        # Written to its own file so it no longer overwrites the dev queries.
        return "GTQ_VAULT_test.tsv", "evaluation"
    return None


def _write_tsv(df: pd.DataFrame, path: str) -> None:
    """Write *df* as a UTF-8, tab-separated file without the index column."""
    df.to_csv(path, sep='\t', index=False, encoding='utf-8')


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point; ``argv`` defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(description='Preprocess The Vault dataset for GLEN')
    parser.add_argument('--input_dir', type=str, default='the_vault_dataset/',
                        help='Input directory containing The Vault dataset')
    parser.add_argument('--output_dir', type=str, default='data/the_vault/',
                        help='Output directory for processed data')
    parser.add_argument('--include_comments', action='store_true',
                        help='Include code comments in descriptions')
    parser.add_argument('--max_samples', type=int, default=None,
                        help='Maximum number of samples to process (for testing)')
    parser.add_argument('--create_test_set', action='store_true',
                        help='Create test set for evaluation')
    args = parser.parse_args(argv)

    os.makedirs(args.output_dir, exist_ok=True)

    # ID frames of every processed split. They are merged after the loop so
    # that evaluation documents get IDs too, not only training documents.
    id_frames = []

    for split in SPLITS:
        print(f"Processing {split} split...")
        input_file = os.path.join(args.input_dir, f"{split}.json")
        if not os.path.exists(input_file):
            print(f"Warning: {input_file} not found, skipping...")
            continue

        processed_samples = _load_split(input_file, split, args.max_samples, args.include_comments)
        print(f"Processed {len(processed_samples)} valid samples from {split}")

        # Documents file (DOC_VAULT_*); train_small is published as "train".
        doc_output = split.replace('train_small', 'train')
        doc_file = os.path.join(args.output_dir, f"DOC_VAULT_{doc_output}.tsv")
        _write_tsv(pd.DataFrame(processed_samples, columns=DOC_COLUMNS), doc_file)
        print(f"Saved document data to {doc_file}")

        target = _gtq_target(split, args.create_test_set)
        if target is None:
            print(f"Skipping {split} query-document pairs (pass --create_test_set to write them)")
        else:
            gtq_name, label = target
            gtq_file = os.path.join(args.output_dir, gtq_name)
            pairs = create_query_document_pairs(processed_samples)
            _write_tsv(pd.DataFrame(pairs, columns=GTQ_COLUMNS), gtq_file)
            print(f"Saved {label} query-document pairs to {gtq_file}")

        id_frames.append(generate_document_ids(processed_samples))

    if id_frames:
        # Identical snippets in several splits share an oldid; keep one ID row each.
        id_df = pd.concat(id_frames, ignore_index=True).drop_duplicates(subset='oldid')
        id_file = os.path.join(args.output_dir, f"ID_VAULT_{DEFAULT_ID_CLASS}.tsv")
        _write_tsv(id_df, id_file)
        print(f"Saved document IDs to {id_file}")


if __name__ == "__main__":
    main()