#!/usr/bin/env python3
"""
Preprocessing script for The Vault dataset to adapt it for GLEN training.
This script converts The Vault's code-text format to GLEN's expected format for
generative retrieval training.
"""
import json
import pandas as pd
import os
import argparse
from typing import Dict, List, Any
from tqdm import tqdm
import hashlib

def clean_text(text: str) -> str:
    """Clean text by removing problematic characters."""
    if not text:
        return ""
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    text = text.replace("``", "")
    text = text.replace('"', "'")
    # Remove excessive whitespace
    text = " ".join(text.split())
    return text
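
# Illustrative behaviour of clean_text (an example, not an exhaustive spec):
# newlines, tabs and backtick pairs are removed, double quotes become single
# quotes, and runs of whitespace collapse to single spaces, e.g.
#   clean_text('Parses the\n\t``config`` "file"')  ->  "Parses the config 'file'"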

def create_document_id(code: str, identifier: str, repo: str) -> str:
    """Create a unique document ID for the code snippet."""
    # Use a combination of repo, identifier, and hash of code for uniqueness
    content = f"{repo}:{identifier}:{code}"
    hash_obj = hashlib.md5(content.encode())
    return hash_obj.hexdigest()[:10]  # Use first 10 chars of hash
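
# For illustration: the ID is the first 10 hex characters of
# md5(f"{repo}:{identifier}:{code}"), so it is deterministic for a given
# (repo, identifier, code) triple, though a truncated hash is not guaranteed
# to be collision-free.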

def process_vault_sample(sample: Dict[str, Any], include_comments: bool = True) -> Dict[str, Any]:
    """Process a single sample from The Vault dataset."""
    # Extract relevant fields
    identifier = sample.get('identifier', '')
    docstring = sample.get('docstring', '')
    short_docstring = sample.get('short_docstring', '')
    code = sample.get('code', '')
    language = sample.get('language', '')
    repo = sample.get('repo', '')
    path = sample.get('path', '')
    comments = sample.get('comment', [])

    # Prefer the full docstring, falling back to the short docstring
    description = docstring if docstring else short_docstring

    # Add comments if requested
    if include_comments and comments:
        comment_text = " ".join([clean_text(c) for c in comments if c])
        if comment_text:
            description = f"{description} {comment_text}" if description else comment_text

    # Clean the text
    description = clean_text(description)
    code = clean_text(code)

    # Create the document content (description + code)
    if description:
        doc_content = f"{description} [CODE] {code}"
    else:
        doc_content = f"Function {identifier} in {language}. [CODE] {code}"

    # Create a unique document ID
    doc_id = create_document_id(code, identifier, repo)

    # Create the query (description, or a synthetic lookup query as a fallback)
    if description:
        query = description
    else:
        query = f"Find function {identifier} in {language}"

    return {
        'oldid': doc_id,
        'doc_content': doc_content,
        'query': query,
        'identifier': identifier,
        'language': language,
        'repo': repo,
        'path': path
    }
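
# Rough sketch of the transformation (field values abbreviated; the input keys
# are exactly those read above from The Vault's function-level records):
#
#   {"identifier": "add", "docstring": "Add two numbers.",
#    "code": "def add(a, b): return a + b",
#    "language": "python", "repo": "org/repo", "path": "src/util.py", "comment": []}
#
# becomes
#
#   {"oldid": "<10-char md5 prefix>",
#    "doc_content": "Add two numbers. [CODE] def add(a, b): return a + b",
#    "query": "Add two numbers.",
#    "identifier": "add", "language": "python", "repo": "org/repo", "path": "src/util.py"}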

def create_query_document_pairs(processed_samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Create query-document pairs for training."""
    pairs = []
    for sample in processed_samples:
        pairs.append({
            'query': sample['query'],
            'oldid': sample['oldid'],
            'docid': sample['oldid'],  # For GLEN, docid is the same as oldid initially
            'rank': 1,
            'neg_docid_list': [],  # Will be populated during training
            'aug_query_list': []   # Will be populated during training
        })
    return pairs
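
# The resulting dicts map one-to-one onto the columns of the GTQ_VAULT_*.tsv
# files written in main(); 'neg_docid_list' and 'aug_query_list' start out empty
# and are assumed to be filled in later in the GLEN training pipeline.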

def generate_document_ids(processed_samples: List[Dict[str, Any]], id_class: str = "t5_bm25_truncate_3") -> pd.DataFrame:
    """Generate document IDs in GLEN's expected format."""
    id_data = []
    for sample in processed_samples:
        # For now, use a simple ID assignment based on the oldid.
        # In practice, you might want to use BM25 or other methods.
        doc_id = "-".join(list(sample['oldid'][:5]))  # Convert to GLEN's token format
        id_data.append({
            'oldid': sample['oldid'],
            id_class: doc_id
        })
    return pd.DataFrame(id_data)
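
# Illustrative mapping (based purely on the code above, not on GLEN's original
# BM25-based ID construction): an oldid of "8f3a1c0d2e" becomes the document ID
# "8-f-3-a-1", i.e. its first five characters joined by '-'.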

def main():
    parser = argparse.ArgumentParser(description='Preprocess The Vault dataset for GLEN')
    parser.add_argument('--input_dir', type=str, default='the_vault_dataset/',
                        help='Input directory containing The Vault dataset')
    parser.add_argument('--output_dir', type=str, default='data/the_vault/',
                        help='Output directory for processed data')
    parser.add_argument('--include_comments', action='store_true',
                        help='Include code comments in descriptions')
    parser.add_argument('--max_samples', type=int, default=None,
                        help='Maximum number of samples to process (for testing)')
    parser.add_argument('--create_test_set', action='store_true',
                        help='Create test set for evaluation')  # NOTE: not used below yet
    args = parser.parse_args()

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Process each split
    splits = ['train_small', 'validate', 'test']
    for split in splits:
        print(f"Processing {split} split...")
        input_file = os.path.join(args.input_dir, f"{split}.json")
        if not os.path.exists(input_file):
            print(f"Warning: {input_file} not found, skipping...")
            continue

        processed_samples = []
        # Read and process the JSON file line by line (one JSON object per line)
        with open(input_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, desc=f"Processing {split}")):
                if args.max_samples and i >= args.max_samples:
                    break
                try:
                    sample = json.loads(line.strip())
                    processed_sample = process_vault_sample(sample, args.include_comments)
                    # Filter out samples with very short code or descriptions
                    if len(processed_sample['doc_content']) > 50:
                        processed_samples.append(processed_sample)
                except json.JSONDecodeError as e:
                    print(f"Error parsing line {i}: {e}")
                    continue

        print(f"Processed {len(processed_samples)} valid samples from {split}")

        # Create the documents file (DOC_VAULT_*)
        doc_output = split.replace('train_small', 'train')
        doc_df = pd.DataFrame([{
            'oldid': sample['oldid'],
            'doc_content': sample['doc_content']
        } for sample in processed_samples])
        doc_file = os.path.join(args.output_dir, f"DOC_VAULT_{doc_output}.tsv")
        doc_df.to_csv(doc_file, sep='\t', index=False, encoding='utf-8')
        print(f"Saved document data to {doc_file}")

        # Create query-document pairs for the training data
        if split == 'train_small':
            pairs = create_query_document_pairs(processed_samples)
            gtq_df = pd.DataFrame(pairs)
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_train.tsv")
            gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved training query-document pairs to {gtq_file}")
        # Create query-document pairs for the evaluation data
        elif split in ['validate', 'test']:
            pairs = create_query_document_pairs(processed_samples)
            # Always use 'dev' in the file name to match GLEN's expectations
            gtq_df = pd.DataFrame(pairs)
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_dev.tsv")
            gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved evaluation query-document pairs to {gtq_file}")

        # Generate and save document IDs (a single ID file, built from the training split)
        if split == 'train_small':
            id_df = generate_document_ids(processed_samples)
            id_file = os.path.join(args.output_dir, "ID_VAULT_t5_bm25_truncate_3.tsv")
            id_df.to_csv(id_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved document IDs to {id_file}")

if __name__ == "__main__":
    main()