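"""Explain CodeBERT token clusters with Gemini.

For each of the first 15 tokens in a layer's predictions CSV, this script asks
the Gemini API how the token (highlighted in its original sentence) relates to
the other words in its predicted cluster. [CLS] tokens get a sentence-level
prompt that also includes context sentences drawn from the cluster. Calls are
rate-limited and retried, and results are saved incrementally so an
interrupted run can resume.
"""
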
import pandas as pd
import os
from google import genai
from google.genai import types
import json
from tqdm import tqdm
from typing import Dict, List, Optional
import time

def configure_genai(api_key: str):
    """Configure the Gemini API with the provided key."""
    os.environ["GEMINI_API_KEY"] = api_key

def load_predictions(task: str, layer: int) -> Optional[pd.DataFrame]:
    """Load predictions from CSV file."""
    predictions_path = os.path.join("src", "codebert", task, f"layer{layer}", f"predictions_layer_{layer}.csv")
    if os.path.exists(predictions_path):
        try:
            df = pd.read_csv(predictions_path, delimiter='\t')
            df['Token'] = df['Token'].astype(str)
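            # 'Top 1' holds the top-ranked cluster id predicted for this token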
            df['predicted_cluster'] = df['Top 1'].astype(str)
            return df
        except Exception as e:
            print(f"Error loading predictions: {str(e)}")
            return None
    return None

def load_clusters(task: str, layer: int) -> Optional[Dict]:
    """Load cluster data from clusters file."""
    clusters_path = os.path.join("src", "codebert", task, f"layer{layer}", "clusters-350.txt")
    if not os.path.exists(clusters_path):
        return None
        
    clusters = {}
    try:
        with open(clusters_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                    
                try:
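                    # Expected format: token ||| occurrence ||| line_num ||| col_num ||| cluster_id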
                    parts = [p.strip() for p in line.split('|||')]
                    if len(parts) == 5:
                        token, occurrence, line_num, col_num, cluster_id = parts
                        cluster_id = cluster_id.split('|')[0].strip()
                        
                        if not cluster_id.isdigit():
                            continue
                            
                        cluster_id = str(int(cluster_id))
                        
                        if cluster_id not in clusters:
                            clusters[cluster_id] = []
                        clusters[cluster_id].append({
                            'token': token,
                            'line_num': int(line_num),
                            'col_num': int(col_num)
                        })
                except Exception:
                    continue
                    
    except Exception as e:
        print(f"Error loading clusters: {str(e)}")
        return None
    
    return clusters

def load_sentences(task: str, layer: int, file_name: str) -> List[str]:
    """Load sentences from specified file."""
    file_path = os.path.join("src", "codebert", task, f"layer{layer}", file_name)
    if not os.path.exists(file_path):
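        # Fall back to the task-level directory when no layer-specific copy exists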
        file_path = os.path.join("src", "codebert", task, file_name)
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.readlines()
    except Exception as e:
        print(f"Error loading sentences from {file_path}: {str(e)}")
        return []

def get_gemini_explanation(sentence: str, highlighted_token: str, cluster_words: List[str]) -> str:
    """Get explanation from Gemini about the relationship between the token and cluster words."""
    # Highlight only the first occurrence of the token in the sentence
    highlighted_sentence = sentence.replace(highlighted_token, f"[[{highlighted_token}]]", 1)
    
    prompt = f"""Do you find any common semantic, structural, lexical and topical relation between the word highlighted in the sentence (enclosed in [[ ]]) and the following list of words? Give a more specific and concise summary about the most prominent relation among these words.

Sentence: {highlighted_sentence}
List of words: {', '.join(cluster_words)}

Answer concisely and to the point."""

    # Create the client
    client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
    model = "gemini-2.0-flash"
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=prompt),
            ],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1.0,
        response_mime_type="text/plain",
    )

    explanation = ""
    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        if chunk.text:  # Some chunks in the stream may carry no text
            explanation += chunk.text

    return explanation.strip()

def is_cls_token(token: str) -> bool:
    """Check if a token is a CLS token."""
    return token.startswith('[CLS]')

def get_gemini_explanation_for_cls(sentence: str, cluster_words: List[str], context_sentences: List[str]) -> str:
    """Get explanation from Gemini about the CLS token and its relationship with the cluster."""
    
    # Include context sentences in the prompt
    context_text = "\n".join(context_sentences) if context_sentences else "No context sentences available."
    
    prompt = f"""[CLS] tokens represent the entire sentence. For this sentence, explain the semantic, structural, lexical, or topical meaning in relation to the list of words from similar contexts. What cohesive meaning does this sentence share with the contextual themes?

Original Sentence: {sentence}
List of cluster words: {', '.join(cluster_words)}

Context sentences in which the cluster words appear:
{context_text}

Answer concisely and to the point about the semantic or topical meaning this sentence shares with the contexts."""

    # Create the client
    client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
    model = "gemini-2.0-flash"
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=prompt),
            ],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1.0,
        response_mime_type="text/plain",
    )

    explanation = ""
    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        if chunk.text:  # Some chunks in the stream may carry no text
            explanation += chunk.text

    return explanation.strip()

def get_gemini_explanation_with_retry(sentence: str, highlighted_token: str, cluster_words: List[str], max_retries: int = 3) -> str:
    """Get explanation from Gemini with retry logic."""
    retry_count = 0
    while retry_count < max_retries:
        try:
            return get_gemini_explanation(sentence, highlighted_token, cluster_words)
        except Exception as e:
            retry_count += 1
            error_type = type(e).__name__
            print(f"\nEncountered {error_type}: {str(e)}")
            if retry_count < max_retries:
                wait_time = 60  # Wait for 60 seconds before retrying
                print(f"Waiting {wait_time} seconds before retry {retry_count}/{max_retries}...")
                time.sleep(wait_time)
            else:
                print(f"Max retries ({max_retries}) reached. Returning error message.")
                return f"Error generating explanation after {max_retries} attempts: {str(e)}"

def get_gemini_explanation_for_cls_with_retry(sentence: str, cluster_words: List[str], context_sentences: List[str], max_retries: int = 3) -> str:
    """Get explanation for CLS tokens with retry logic."""
    retry_count = 0
    while retry_count < max_retries:
        try:
            return get_gemini_explanation_for_cls(sentence, cluster_words, context_sentences)
        except Exception as e:
            retry_count += 1
            error_type = type(e).__name__
            print(f"\nEncountered {error_type}: {str(e)}")
            if retry_count < max_retries:
                wait_time = 60  # Wait for 60 seconds before retrying
                print(f"Waiting {wait_time} seconds before retry {retry_count}/{max_retries}...")
                time.sleep(wait_time)
            else:
                print(f"Max retries ({max_retries}) reached. Returning error message.")
                return f"Error generating explanation after {max_retries} attempts: {str(e)}"

def process_tokens(task: str, layer: int, api_key: str):
    """Process the first 15 tokens for a given task and layer with API rate limiting and error handling."""
    # Configure Gemini
    configure_genai(api_key)
    
    # Load necessary data
    predictions_df = load_predictions(task, layer)
    clusters = load_clusters(task, layer)
    dev_sentences = load_sentences(task, layer, "dev.in")
    input_sentences = load_sentences(task, layer, "input.in")
    
    if predictions_df is None or clusters is None:
        print("Failed to load required data")
        return
    
    # Limit to first 15 tokens
    predictions_df = predictions_df.head(15)
    print(f"Limited processing to first {len(predictions_df)} tokens")
    
    results = []
    batch_size = 15  # API limit of 15 calls per minute
    call_count = 0
    start_time = time.time()
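    # Fixed-window rate limiting: count calls and, once batch_size is reached
    # within a minute, sleep out the rest of the 60-second window (see below).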
    
    # Create output directory if it doesn't exist
    output_dir = os.path.join("src", "codebert", task, f"layer{layer}")
    os.makedirs(output_dir, exist_ok=True)
    
    # Check if there's an interim file to resume from
    interim_file = os.path.join(output_dir, f"token_explanations_layer_{layer}_test15.json")
    if os.path.exists(interim_file):
        try:
            with open(interim_file, 'r', encoding='utf-8') as f:
                results = json.load(f)
                print(f"Resuming from {len(results)} previously processed tokens")
                # Skip tokens we've already processed
                processed_indices = {(result['line_idx'], result['position_idx']) for result in results}
        except Exception as e:
            print(f"Error loading interim file: {str(e)}")
            processed_indices = set()
    else:
        processed_indices = set()
    
    # Process limited number of tokens, showing progress with tqdm
    for _, row in tqdm(predictions_df.iterrows(), total=len(predictions_df), desc="Processing tokens"):
        token = row['Token']
        line_idx = row['line_idx']
        position_idx = row['position_idx']
        predicted_cluster = row['predicted_cluster']
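        # line_idx/position_idx locate this token within the dev sentences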
        
        # Skip if we've already processed this token
        if (line_idx, position_idx) in processed_indices:
            continue
            
        # Get original sentence
        if line_idx < len(dev_sentences):
            original_sentence = dev_sentences[line_idx].strip()
        else:
            continue
            
        # Get unique cluster words
        if predicted_cluster in clusters:
            cluster_words = list(set(token_info['token'] for token_info in clusters[predicted_cluster]))
            
            # Gather context sentences from the predicted cluster
            context_sentences = []
            for token_info in clusters[predicted_cluster]:
                context_line_num = token_info['line_num']
                if context_line_num < len(input_sentences):
                    context_sentences.append(input_sentences[context_line_num].strip())
        else:
            continue
        
        # Rate limiting: check if we've reached the batch limit
        call_count += 1
        if call_count >= batch_size:
            elapsed = time.time() - start_time
            # If we've made batch_size calls in less than 60 seconds, wait until the minute is up
            if elapsed < 60:
                wait_time = 60 - elapsed
                print(f"\nReached API limit of {batch_size} calls. Waiting for {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            # Reset counters
            call_count = 0
            start_time = time.time()
            
        # Choose the right explanation function based on token type
        try:
            if is_cls_token(token):
                # Special handling for CLS tokens with retry
                explanation = get_gemini_explanation_for_cls_with_retry(original_sentence, cluster_words, context_sentences)
            else:
                # Standard handling for other tokens with retry
                explanation = get_gemini_explanation_with_retry(original_sentence, token, cluster_words)
            
            # Store results
            result = {
                'token': token,
                'is_cls_token': is_cls_token(token),
                'line_idx': int(line_idx),
                'position_idx': int(position_idx),
                'predicted_cluster': predicted_cluster,
                'original_sentence': original_sentence,
                'cluster_words': cluster_words,
                'context_sentences': context_sentences,
                'explanation': explanation
            }
            results.append(result)
            
            # Add to processed indices
            processed_indices.add((line_idx, position_idx))
            
            # Save after each token for this small test run
            with open(interim_file, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            print(f"\nSaved results to: {interim_file}")
        
        except Exception as e:
            print(f"\nUnexpected error processing token {token}: {str(e)}")
            # Save current results before potentially exiting
            with open(interim_file, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            print(f"Emergency save to: {interim_file}")
            
            # Wait a minute before continuing
            print("Waiting 60 seconds before continuing...")
            time.sleep(60)
            
            # Reset batch counters
            call_count = 0
            start_time = time.time()

    # Save final results with a different name to indicate it's the test run
    output_file = os.path.join(output_dir, f"token_explanations_layer_{layer}_first15.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    
    print(f"Results saved to: {output_file}")

def main():
    # Configuration
    API_KEY = "AIzaSyCUCwrqcDNTSaHsn5Ln_91A0L03W864iYU"  # Replace with your API key
    TASK = "language_classification"   # Replace with your task name
    LAYER = 11                          # Replace with your layer number
    
    process_tokens(TASK, LAYER, API_KEY)

if __name__ == "__main__":
    main()