"""
Preprocessing script for The Vault dataset to adapt it for GLEN training.

This script converts The Vault's code-text samples into the document, query,
and document-ID files that GLEN expects for generative retrieval training.
"""

import argparse
import hashlib
import json
import os
from typing import Any, Dict, List

import pandas as pd
from tqdm import tqdm


def clean_text(text: str) -> str:
    """Clean text by removing characters that would break the TSV output."""
    if not text:
        return ""
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    text = text.replace("``", "")
    text = text.replace('"', "'")
    # Collapse any remaining runs of whitespace into single spaces.
    text = " ".join(text.split())
    return text


def create_document_id(code: str, identifier: str, repo: str) -> str:
    """Create a unique document ID for the code snippet."""
    # Hash the repo, identifier, and code together so identical snippets from
    # different repositories still receive distinct IDs.
    content = f"{repo}:{identifier}:{code}"
    hash_obj = hashlib.md5(content.encode())
    return hash_obj.hexdigest()[:10]


def process_vault_sample(sample: Dict[str, Any], include_comments: bool = True) -> Dict[str, Any]:
    """Process a single sample from The Vault dataset."""
    identifier = sample.get('identifier', '')
    docstring = sample.get('docstring', '')
    short_docstring = sample.get('short_docstring', '')
    code = sample.get('code', '')
    language = sample.get('language', '')
    repo = sample.get('repo', '')
    path = sample.get('path', '')
    comments = sample.get('comment', [])
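
    # Prefer the full docstring; fall back to the short one-line summary.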
    description = docstring if docstring else short_docstring

    # Optionally fold the snippet's extracted comments into the description.
    if include_comments and comments:
        comment_text = " ".join([clean_text(c) for c in comments if c])
        if comment_text:
            description = f"{description} {comment_text}" if description else comment_text

    description = clean_text(description)
    code = clean_text(code)
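
    # Document text is the description followed by the code, separated by a
    # literal "[CODE]" marker; synthesize a description if the sample has none.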
    if description:
        doc_content = f"{description} [CODE] {code}"
    else:
        doc_content = f"Function {identifier} in {language}. [CODE] {code}"

    doc_id = create_document_id(code, identifier, repo)

    # The description doubles as the retrieval query; otherwise synthesize one.
    if description:
        query = description
    else:
        query = f"Find function {identifier} in {language}"

    return {
        'oldid': doc_id,
        'doc_content': doc_content,
        'query': query,
        'identifier': identifier,
        'language': language,
        'repo': repo,
        'path': path
    }


def create_query_document_pairs(processed_samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Create query-document pairs for training."""
    pairs = []

    for sample in processed_samples:
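        # Each document is its own rank-1 positive; no hard negatives or
        # augmented queries are generated here, so those lists stay empty.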
        pairs.append({
            'query': sample['query'],
            'oldid': sample['oldid'],
            'docid': sample['oldid'],
            'rank': 1,
            'neg_docid_list': [],
            'aug_query_list': []
        })

    return pairs


def generate_document_ids(processed_samples: List[Dict[str, Any]], id_class: str = "t5_bm25_truncate_3") -> pd.DataFrame:
    """Generate document IDs in GLEN's expected format."""
    id_data = []
    for sample in processed_samples:
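        # Spell the first five hex characters of the hash as a dash-separated
        # sequence, e.g. 'a1b2c3d4e5' -> 'a-1-b-2-c'. The column name mirrors
        # GLEN's 't5_bm25_truncate_3' ID class, but these IDs are hash
        # characters rather than BM25-derived terms.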
        doc_id = "-".join(list(sample['oldid'][:5]))

        id_data.append({
            'oldid': sample['oldid'],
            id_class: doc_id
        })

    return pd.DataFrame(id_data)


def main():
    parser = argparse.ArgumentParser(description='Preprocess The Vault dataset for GLEN')
    parser.add_argument('--input_dir', type=str, default='the_vault_dataset/',
                        help='Input directory containing The Vault dataset')
    parser.add_argument('--output_dir', type=str, default='data/the_vault/',
                        help='Output directory for processed data')
    parser.add_argument('--include_comments', action='store_true',
                        help='Include code comments in descriptions')
    parser.add_argument('--max_samples', type=int, default=None,
                        help='Maximum number of samples to process (for testing)')
    parser.add_argument('--create_test_set', action='store_true',
                        help='Create test set for evaluation')  # NOTE: not yet used below

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    splits = ['train_small', 'validate', 'test']
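    # Each split is expected as a JSON Lines file at <input_dir>/<split>.json,
    # one sample per line.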

    for split in splits:
        print(f"Processing {split} split...")

        input_file = os.path.join(args.input_dir, f"{split}.json")
        if not os.path.exists(input_file):
            print(f"Warning: {input_file} not found, skipping...")
            continue

        processed_samples = []

        with open(input_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, desc=f"Processing {split}")):
                if args.max_samples and i >= args.max_samples:
                    break

                try:
                    sample = json.loads(line.strip())
                    processed_sample = process_vault_sample(sample, args.include_comments)

                    # Keep only documents with non-trivial content.
                    if len(processed_sample['doc_content']) > 50:
                        processed_samples.append(processed_sample)

                except json.JSONDecodeError as e:
                    print(f"Error parsing line {i}: {e}")
                    continue

        print(f"Processed {len(processed_samples)} valid samples from {split}")
        doc_output = split.replace('train_small', 'train')
        doc_df = pd.DataFrame([{
            'oldid': sample['oldid'],
            'doc_content': sample['doc_content']
        } for sample in processed_samples])

        doc_file = os.path.join(args.output_dir, f"DOC_VAULT_{doc_output}.tsv")
        doc_df.to_csv(doc_file, sep='\t', index=False, encoding='utf-8')
        print(f"Saved document data to {doc_file}")

        if split == 'train_small':
            pairs = create_query_document_pairs(processed_samples)
            gtq_df = pd.DataFrame(pairs)
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_train.tsv")
            gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved training query-document pairs to {gtq_file}")

        elif split in ['validate', 'test']:
            pairs = create_query_document_pairs(processed_samples)

            # NOTE: 'validate' and 'test' both write the same dev file, so the
            # 'test' split overwrites the 'validate' output.
            gtq_df = pd.DataFrame(pairs)
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_dev.tsv")
            gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved evaluation query-document pairs to {gtq_file}")

        # Document IDs are only written for the training split.
        if split == 'train_small':
            id_df = generate_document_ids(processed_samples)
            id_file = os.path.join(args.output_dir, "ID_VAULT_t5_bm25_truncate_3.tsv")
            id_df.to_csv(id_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved document IDs to {id_file}")


if __name__ == "__main__":
    main()