#!/usr/bin/env python3
"""
Preprocessing script for The Vault dataset to adapt it for GLEN training.

This script converts The Vault's code-text format to GLEN's expected format for
generative retrieval training.
"""

import json
import pandas as pd
import os
import argparse
from typing import Dict, List, Any
from tqdm import tqdm
import hashlib

def clean_text(text: str) -> str:
    """Clean text by removing problematic characters."""
    if not text:
        return ""
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    text = text.replace("``", "")
    text = text.replace('"', "'")
    # Remove excessive whitespace
    text = " ".join(text.split())
    return text

def create_document_id(code: str, identifier: str, repo: str) -> str:
    """Create a unique document ID for the code snippet."""
    # Use a combination of repo, identifier, and hash of code for uniqueness
    content = f"{repo}:{identifier}:{code}"
    hash_obj = hashlib.md5(content.encode())
    return hash_obj.hexdigest()[:10]  # Use first 10 chars of hash

def process_vault_sample(sample: Dict[str, Any], include_comments: bool = True) -> Dict[str, Any]:
    """Process a single sample from The Vault dataset."""
    
    # Extract relevant fields
    identifier = sample.get('identifier', '')
    docstring = sample.get('docstring', '')
    short_docstring = sample.get('short_docstring', '')
    code = sample.get('code', '')
    language = sample.get('language', '')
    repo = sample.get('repo', '')
    path = sample.get('path', '')
    comments = sample.get('comment', [])
    
    # Choose between full docstring and short docstring
    description = docstring if docstring else short_docstring
    
    # Add comments if requested
    if include_comments and comments:
        comment_text = " ".join([clean_text(c) for c in comments if c])
        if comment_text:
            description = f"{description} {comment_text}" if description else comment_text
    
    # Clean the text
    description = clean_text(description)
    code = clean_text(code)
    
    # Create document content (code + description)
    if description:
        doc_content = f"{description} [CODE] {code}"
    else:
        doc_content = f"Function {identifier} in {language}. [CODE] {code}"
    
    # Create unique document ID
    doc_id = create_document_id(code, identifier, repo)
    
    # Create query (description or function signature)
    if description:
        query = description
    else:
        query = f"Find function {identifier} in {language}"
    
    return {
        'oldid': doc_id,
        'doc_content': doc_content,
        'query': query,
        'identifier': identifier,
        'language': language,
        'repo': repo,
        'path': path
    }
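
# For illustration only (a hypothetical record, not drawn from the dataset): a
# sample with identifier "parse_config", language "python", and docstring
# "Parse a YAML config file." would be mapped to
#     doc_content: "Parse a YAML config file. [CODE] def parse_config(path): ..."
#     query:       "Parse a YAML config file."
# with 'oldid' set to the first 10 hex characters of the MD5 hash computed in
# create_document_id above.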

def create_query_document_pairs(processed_samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Create query-document pairs for training."""
    pairs = []
    
    for sample in processed_samples:
        pairs.append({
            'query': sample['query'],
            'oldid': sample['oldid'],
            'docid': sample['oldid'],  # For GLEN, docid is the same as oldid initially
            'rank': 1,
            'neg_docid_list': [],  # Will be populated during training
            'aug_query_list': []   # Will be populated during training
        })
    
    return pairs

def generate_document_ids(processed_samples: List[Dict[str, Any]], id_class: str = "t5_bm25_truncate_3") -> pd.DataFrame:
    """Generate document IDs in GLEN's expected format."""
    
    id_data = []
    for sample in processed_samples:
        # For now, use a simple ID assignment based on the oldid
        # In practice, you might want to use BM25 or other methods
        doc_id = "-".join(list(sample['oldid'][:5]))  # Convert to GLEN's token format
        
        id_data.append({
            'oldid': sample['oldid'],
            id_class: doc_id
        })
    
    return pd.DataFrame(id_data)
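
# A concrete example of the simple ID scheme above (the oldid value below is
# illustrative, not a real hash):
#
#     >>> "-".join(list("3f2a9b1c4d"[:5]))
#     '3-f-2-a-9'
#
# For real experiments the t5_bm25_truncate_3 IDs would typically come from a
# BM25-based assignment, as the column name suggests.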

def main():
    parser = argparse.ArgumentParser(description='Preprocess The Vault dataset for GLEN')
    parser.add_argument('--input_dir', type=str, default='the_vault_dataset/', 
                        help='Input directory containing The Vault dataset')
    parser.add_argument('--output_dir', type=str, default='data/the_vault/', 
                        help='Output directory for processed data')
    parser.add_argument('--include_comments', action='store_true', 
                        help='Include code comments in descriptions')
    parser.add_argument('--max_samples', type=int, default=None,
                        help='Maximum number of samples to process (for testing)')
    parser.add_argument('--create_test_set', action='store_true',
                        help='Create test set for evaluation (currently unused)')
    
    args = parser.parse_args()
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Process each split
    splits = ['train_small', 'validate', 'test']
    
    for split in splits:
        print(f"Processing {split} split...")
        
        input_file = os.path.join(args.input_dir, f"{split}.json")
        if not os.path.exists(input_file):
            print(f"Warning: {input_file} not found, skipping...")
            continue
        
        processed_samples = []
        
        # Read and process the JSON file line by line
        with open(input_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, desc=f"Processing {split}")):
                if args.max_samples and i >= args.max_samples:
                    break
                
                try:
                    sample = json.loads(line.strip())
                    processed_sample = process_vault_sample(sample, args.include_comments)
                    
                    # Filter out samples whose combined description + code content is very short
                    if len(processed_sample['doc_content']) > 50:
                        processed_samples.append(processed_sample)
                        
                except json.JSONDecodeError as e:
                    print(f"Error parsing line {i}: {e}")
                    continue
        
        print(f"Processed {len(processed_samples)} valid samples from {split}")
        
        # Create documents file (DOC_VAULT_*)
        doc_output = split.replace('train_small', 'train')
        doc_df = pd.DataFrame([{
            'oldid': sample['oldid'],
            'doc_content': sample['doc_content']
        } for sample in processed_samples])
        
        doc_file = os.path.join(args.output_dir, f"DOC_VAULT_{doc_output}.tsv")
        doc_df.to_csv(doc_file, sep='\t', index=False, encoding='utf-8')
        print(f"Saved document data to {doc_file}")
        
        # Create query-document pairs for training data
        if split == 'train_small':
            pairs = create_query_document_pairs(processed_samples)
            gtq_df = pd.DataFrame(pairs)
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_train.tsv")
            gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved training query-document pairs to {gtq_file}")
        
        # Create query-document pairs for evaluation data
        elif split in ['validate', 'test']:
            pairs = create_query_document_pairs(processed_samples)
            # Always write to 'dev' to match GLEN's expectations; if both validate
            # and test splits are present, the later split overwrites this file
            gtq_df = pd.DataFrame(pairs)
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_dev.tsv")
            gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved evaluation query-document pairs to {gtq_file}")
        
        # Generate and save document IDs (only needed for the training split)
        if split == 'train_small':
            id_df = generate_document_ids(processed_samples)
            id_file = os.path.join(args.output_dir, "ID_VAULT_t5_bm25_truncate_3.tsv")
            id_df.to_csv(id_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved document IDs to {id_file}")

if __name__ == "__main__":
    main()