#!/usr/bin/env python3
"""
Preprocessing script for The Vault dataset to adapt it for GLEN training.
This script converts The Vault's code-text format to GLEN's expected format for
generative retrieval training.
"""
import json
import pandas as pd
import os
import argparse
from typing import Dict, List, Any
from tqdm import tqdm
import hashlib
def clean_text(text: str) -> str:
"""Clean text by removing problematic characters."""
if not text:
return ""
text = text.replace("\n", " ")
text = text.replace("\t", " ")
text = text.replace("``", "")
text = text.replace('"', "'")
# Remove excessive whitespace
text = " ".join(text.split())
return text
def create_document_id(code: str, identifier: str, repo: str) -> str:
"""Create a unique document ID for the code snippet."""
# Use a combination of repo, identifier, and hash of code for uniqueness
content = f"{repo}:{identifier}:{code}"
hash_obj = hashlib.md5(content.encode())
return hash_obj.hexdigest()[:10] # Use first 10 chars of hash
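# Illustrative usage (the hash value shown is made up; the real output is the first
# 10 hex characters of md5("repo:identifier:code")):
#   create_document_id("def f(x): return x", "f", "org/repo")  -> e.g. "3f2a9c01bd"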
def process_vault_sample(sample: Dict[str, Any], include_comments: bool = True) -> Dict[str, Any]:
"""Process a single sample from The Vault dataset."""
# Extract relevant fields
identifier = sample.get('identifier', '')
docstring = sample.get('docstring', '')
short_docstring = sample.get('short_docstring', '')
code = sample.get('code', '')
language = sample.get('language', '')
repo = sample.get('repo', '')
path = sample.get('path', '')
comments = sample.get('comment', [])
# Choose between full docstring and short docstring
description = docstring if docstring else short_docstring
# Add comments if requested
if include_comments and comments:
comment_text = " ".join([clean_text(c) for c in comments if c])
if comment_text:
description = f"{description} {comment_text}" if description else comment_text
# Clean the text
description = clean_text(description)
code = clean_text(code)
# Create document content (code + description)
if description:
doc_content = f"{description} [CODE] {code}"
else:
doc_content = f"Function {identifier} in {language}. [CODE] {code}"
# Create unique document ID
doc_id = create_document_id(code, identifier, repo)
# Create query (description or function signature)
if description:
query = description
else:
query = f"Find function {identifier} in {language}"
return {
'oldid': doc_id,
'doc_content': doc_content,
'query': query,
'identifier': identifier,
'language': language,
'repo': repo,
'path': path
}
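# Sketch of the mapping performed above (all field values are hypothetical):
#   {'identifier': 'add', 'docstring': 'Add two numbers.', 'code': 'def add(a, b): return a + b',
#    'language': 'python', 'repo': 'org/repo', 'path': 'math_utils.py', 'comment': []}
# becomes
#   {'oldid': '<10-char md5 prefix>',
#    'doc_content': 'Add two numbers. [CODE] def add(a, b): return a + b',
#    'query': 'Add two numbers.',
#    'identifier': 'add', 'language': 'python', 'repo': 'org/repo', 'path': 'math_utils.py'}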
def create_query_document_pairs(processed_samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Create query-document pairs for training."""
pairs = []
for sample in processed_samples:
pairs.append({
'query': sample['query'],
'oldid': sample['oldid'],
'docid': sample['oldid'], # For GLEN, docid is the same as oldid initially
'rank': 1,
'neg_docid_list': [], # Will be populated during training
'aug_query_list': [] # Will be populated during training
})
return pairs
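# Each returned pair becomes one row of the GTQ_VAULT_*.tsv files written in main(), e.g.
#   {'query': 'Add two numbers.', 'oldid': 'ab12cd34ef', 'docid': 'ab12cd34ef',
#    'rank': 1, 'neg_docid_list': [], 'aug_query_list': []}
# (values are illustrative; negatives and augmented queries are left empty here).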
def generate_document_ids(processed_samples: List[Dict[str, Any]], id_class: str = "t5_bm25_truncate_3") -> pd.DataFrame:
"""Generate document IDs in GLEN's expected format."""
id_data = []
for sample in processed_samples:
# For now, use a simple ID assignment based on the oldid
# In practice, you might want to use BM25 or other methods
doc_id = "-".join(list(sample['oldid'][:5])) # Convert to GLEN's token format
id_data.append({
'oldid': sample['oldid'],
id_class: doc_id
})
return pd.DataFrame(id_data)
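# Example: an oldid of "a1b2c3d4e5" is truncated to its first 5 characters and hyphen-joined,
# so its t5_bm25_truncate_3 entry is "a-1-b-2-c". This is a simple stand-in for the
# BM25-based IDs mentioned in the comment above.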
def main():
parser = argparse.ArgumentParser(description='Preprocess The Vault dataset for GLEN')
parser.add_argument('--input_dir', type=str, default='the_vault_dataset/',
help='Input directory containing The Vault dataset')
parser.add_argument('--output_dir', type=str, default='data/the_vault/',
help='Output directory for processed data')
parser.add_argument('--include_comments', action='store_true',
help='Include code comments in descriptions')
parser.add_argument('--max_samples', type=int, default=None,
help='Maximum number of samples to process (for testing)')
    parser.add_argument('--create_test_set', action='store_true',
                        help='Create test set for evaluation')  # note: parsed but not used anywhere else in this script
args = parser.parse_args()
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Process each split
splits = ['train_small', 'validate', 'test']
for split in splits:
print(f"Processing {split} split...")
input_file = os.path.join(args.input_dir, f"{split}.json")
if not os.path.exists(input_file):
print(f"Warning: {input_file} not found, skipping...")
continue
processed_samples = []
# Read and process the JSON file line by line
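        # Expected input: one JSON object per line; the fields read in process_vault_sample()
        # are a subset of The Vault's schema, roughly:
        #   {"identifier": "...", "docstring": "...", "short_docstring": "...", "code": "...",
        #    "language": "...", "repo": "...", "path": "...", "comment": ["..."]}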
with open(input_file, 'r', encoding='utf-8') as f:
for i, line in enumerate(tqdm(f, desc=f"Processing {split}")):
if args.max_samples and i >= args.max_samples:
break
try:
sample = json.loads(line.strip())
processed_sample = process_vault_sample(sample, args.include_comments)
# Filter out samples with very short code or descriptions
if len(processed_sample['doc_content']) > 50:
processed_samples.append(processed_sample)
except json.JSONDecodeError as e:
print(f"Error parsing line {i}: {e}")
continue
print(f"Processed {len(processed_samples)} valid samples from {split}")
# Create documents file (DOC_VAULT_*)
doc_output = split.replace('train_small', 'train')
doc_df = pd.DataFrame([{
'oldid': sample['oldid'],
'doc_content': sample['doc_content']
} for sample in processed_samples])
doc_file = os.path.join(args.output_dir, f"DOC_VAULT_{doc_output}.tsv")
doc_df.to_csv(doc_file, sep='\t', index=False, encoding='utf-8')
print(f"Saved document data to {doc_file}")
# Create query-document pairs for training data
if split == 'train_small':
pairs = create_query_document_pairs(processed_samples)
gtq_df = pd.DataFrame(pairs)
gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_train.tsv")
gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
print(f"Saved training query-document pairs to {gtq_file}")
# Create query-document pairs for evaluation data
elif split in ['validate', 'test']:
pairs = create_query_document_pairs(processed_samples)
            # Always write to 'dev' to match the file name GLEN expects for evaluation.
            # Note: if both 'validate' and 'test' are processed, the later split overwrites this file.
gtq_df = pd.DataFrame(pairs)
gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_dev.tsv")
gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
print(f"Saved evaluation query-document pairs to {gtq_file}")
        # Generate document IDs (only the training split needs an ID file)
        if split == 'train_small':
            id_df = generate_document_ids(processed_samples)
            id_file = os.path.join(args.output_dir, "ID_VAULT_t5_bm25_truncate_3.tsv")
            id_df.to_csv(id_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved document IDs to {id_file}")
if __name__ == "__main__":
    main()