Spaces:
Sleeping
Sleeping
import json
import os
import re
import time
from typing import Dict, List, Optional

import pandas as pd
from google import genai
from google.genai import types
from tqdm import tqdm
def configure_genai(api_key: str):
    """Expose *api_key* to the Gemini helpers through the environment.

    The key is stored under ``GEMINI_API_KEY``, the variable the explanation
    functions read when they construct a ``genai.Client``.
    """
    os.environ.update({"GEMINI_API_KEY": api_key})
def load_predictions(task: str, layer: int) -> Optional[pd.DataFrame]:
    """Load the per-token cluster predictions for *task* / *layer*.

    Reads ``src/codebert/<task>/layer<layer>/predictions_layer_<layer>.csv``,
    a tab-separated file that must provide ``Token`` and ``Top 1`` columns.
    ``Token`` is normalised to ``str`` and ``Top 1`` is copied, as ``str``,
    into a new ``predicted_cluster`` column.

    Returns:
        The DataFrame, or ``None`` if the file is missing or cannot be parsed
        (the error is printed).
    """
    predictions_path = os.path.join("src", "codebert", task, f"layer{layer}", f"predictions_layer_{layer}.csv")
    if not os.path.exists(predictions_path):
        return None
    try:
        # Source-code tokens such as ``null``, ``NULL``, ``None`` or ``nan`` are on
        # pandas' default NA list; they would become NaN and then the string
        # "nan". Only genuinely empty cells are treated as missing here.
        # Reading Token / Top 1 as text also keeps cluster ids like "12" from
        # turning into "12.0" when the column happens to contain a blank cell.
        df = pd.read_csv(
            predictions_path,
            delimiter='\t',
            keep_default_na=False,
            na_values=[''],
            dtype={'Token': str, 'Top 1': str},
        )
        df['Token'] = df['Token'].astype(str)
        df['predicted_cluster'] = df['Top 1'].astype(str)
        return df
    except Exception as e:
        print(f"Error loading predictions: {str(e)}")
        return None
def load_clusters(task: str, layer: int) -> Optional[Dict]:
    """Load token-to-cluster assignments from ``clusters-350.txt``.

    Each non-blank line is expected to look like
    ``token|||occurrence|||line_num|||col_num|||cluster_id``. Anything after a
    single ``|`` in the cluster field is ignored, and malformed lines are skipped.

    Returns:
        A dict mapping the cluster id (normalised decimal string, e.g. ``"7"``)
        to a list of ``{'token', 'line_num', 'col_num'}`` dicts, or ``None`` if
        the file is missing or cannot be read (the error is printed).
    """
    clusters_path = os.path.join("src", "codebert", task, f"layer{layer}", "clusters-350.txt")
    if not os.path.exists(clusters_path):
        return None
    clusters = {}
    try:
        with open(clusters_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Split from the right: the token itself may contain '|||'
                # (e.g. the operators '|' or '||'), whereas the trailing numeric
                # fields are not expected to. A left split would shift every field.
                parts = [p.strip() for p in line.rsplit('|||', 4)]
                if len(parts) != 5:
                    continue
                token, _occurrence, line_num, col_num, cluster_id = parts
                cluster_id = cluster_id.split('|')[0].strip()
                if not cluster_id.isdigit():
                    continue
                try:
                    entry = {
                        'token': token,
                        'line_num': int(line_num),
                        'col_num': int(col_num),
                    }
                    # isdigit() accepts characters like '²' that int() rejects.
                    cluster_key = str(int(cluster_id))
                except ValueError:
                    continue
                # Only create the cluster list once the entry parsed cleanly.
                clusters.setdefault(cluster_key, []).append(entry)
    except Exception as e:
        print(f"Error loading clusters: {str(e)}")
        return None
    return clusters
def load_sentences(task: str, layer: int, file_name: str) -> List[str]:
    """Read *file_name* for *task* and return its raw lines.

    The layer directory ``src/codebert/<task>/layer<layer>/`` is tried first.
    If the file is not there, the task directory ``src/codebert/<task>/`` is
    used instead. Lines keep their trailing newlines. Any read error is
    printed and an empty list is returned.
    """
    task_dir = os.path.join("src", "codebert", task)
    layer_candidate = os.path.join(task_dir, f"layer{layer}", file_name)
    file_path = layer_candidate if os.path.exists(layer_candidate) else os.path.join(task_dir, file_name)
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            lines = handle.readlines()
    except Exception as err:
        print(f"Error loading sentences from {file_path}: {str(err)}")
        return []
    return lines
def get_gemini_explanation(sentence: str, highlighted_token: str, cluster_words: List[str]) -> str: | |
"""Get explanation from Gemini about the relationship between the token and cluster words.""" | |
highlighted_sentence = sentence.replace(highlighted_token, f"[[{highlighted_token}]]") | |
prompt = f"""Do you find any common semantic, structural, lexical and topical relation between the word highlighted in the sentence (enclosed in [[ ]]) and the following list of words? Give a more specific and concise summary about the most prominent relation among these words. | |
Sentence: {highlighted_sentence} | |
List of words: {', '.join(cluster_words)} | |
Answer concisely and to the point.""" | |
# Create the client | |
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY")) # Ensure this is correct | |
model = "gemini-2.0-flash" | |
contents = [ | |
types.Content( | |
role="user", | |
parts=[ | |
types.Part.from_text(text=prompt), # Ensure this is the correct usage | |
], | |
), | |
] | |
generate_content_config = types.GenerateContentConfig( | |
temperature=1.0, | |
response_mime_type="text/plain", | |
) | |
explanation = "" | |
for chunk in client.models.generate_content_stream( | |
model=model, | |
contents=contents, | |
config=generate_content_config, | |
): | |
explanation += chunk.text | |
return explanation.strip() | |
def is_cls_token(token: str) -> bool:
    """Return True when *token* begins with the ``[CLS]`` marker."""
    cls_marker = '[CLS]'
    return token[:len(cls_marker)] == cls_marker
def get_gemini_explanation_for_cls(sentence: str, cluster_words: List[str], context_sentences: List[str]) -> str:
    """Ask Gemini what a sentence, represented by its ``[CLS]`` token, shares with a cluster.

    Args:
        sentence: The sentence whose ``[CLS]`` token was clustered.
        cluster_words: Words from the predicted cluster.
        context_sentences: Sentences the cluster words come from. A placeholder
            line is used in the prompt when this is empty.

    Returns:
        The model's streamed answer, concatenated and stripped.

    Raises:
        Exceptions from the ``google-genai`` client are not caught here.
    """
    context_text = "\n".join(context_sentences) if context_sentences else "No context sentences available."
    prompt = f"""[CLS] tokens represent the entire sentence. For this sentence, explain the semantic, structural, lexical, or topical meaning in relation to the list of words from similar contexts. What cohesive meaning does this sentence share with the contextual themes?
Original Sentence: {sentence}
List of cluster words: {', '.join(cluster_words)}
Context Sentences of the list of cluster words:
{context_text}
Answer concisely and to the point about the semantic or topical meaning this sentence shares with the contexts."""
    client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=prompt)],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1.0,
        response_mime_type="text/plain",
    )
    stream = client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents=contents,
        config=generate_content_config,
    )
    # ``chunk.text`` is None for chunks without text parts (e.g. a final chunk
    # carrying only finish/usage metadata); concatenating None would raise.
    return "".join(chunk.text or "" for chunk in stream).strip()
def get_gemini_explanation_with_retry(sentence: str, highlighted_token: str, cluster_words: List[str], max_retries: int = 3, wait_seconds: float = 60) -> str:
    """Call ``get_gemini_explanation``, retrying on any exception.

    Args:
        sentence, highlighted_token, cluster_words: Forwarded unchanged.
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so at least one call is always made and a ``str`` is always
            returned. Previously ``None`` was returned for ``max_retries <= 0``.
        wait_seconds: Pause between failed attempts (default 60 seconds).

    Returns:
        The explanation, or an ``"Error generating explanation ..."`` message
        once every attempt has failed. Exceptions are not propagated.
    """
    attempts = max(1, max_retries)
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            return get_gemini_explanation(sentence, highlighted_token, cluster_words)
        except Exception as e:
            # Keep a reference: ``e`` is unbound once the except block ends.
            last_error = e
            print(f"\nEncountered {type(e).__name__}: {str(e)}")
            if attempt < attempts:
                print(f"Waiting {wait_seconds} seconds before retry {attempt}/{attempts}...")
                time.sleep(wait_seconds)
    print(f"Max retries ({attempts}) reached. Returning error message.")
    return f"Error generating explanation after {attempts} attempts: {str(last_error)}"
def get_gemini_explanation_for_cls_with_retry(sentence: str, cluster_words: List[str], context_sentences: List[str], max_retries: int = 3, wait_seconds: float = 60) -> str:
    """Call ``get_gemini_explanation_for_cls``, retrying on any exception.

    Args:
        sentence, cluster_words, context_sentences: Forwarded unchanged.
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so at least one call is always made and a ``str`` is always
            returned. Previously ``None`` was returned for ``max_retries <= 0``.
        wait_seconds: Pause between failed attempts (default 60 seconds).

    Returns:
        The explanation, or an ``"Error generating explanation ..."`` message
        once every attempt has failed. Exceptions are not propagated.
    """
    attempts = max(1, max_retries)
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            return get_gemini_explanation_for_cls(sentence, cluster_words, context_sentences)
        except Exception as e:
            # Keep a reference: ``e`` is unbound once the except block ends.
            last_error = e
            print(f"\nEncountered {type(e).__name__}: {str(e)}")
            if attempt < attempts:
                print(f"Waiting {wait_seconds} seconds before retry {attempt}/{attempts}...")
                time.sleep(wait_seconds)
    print(f"Max retries ({attempts}) reached. Returning error message.")
    return f"Error generating explanation after {attempts} attempts: {str(last_error)}"
def _load_interim_results(interim_file: str) -> tuple:
    """Load previously saved results so an interrupted run can resume.

    Returns ``(results, processed_indices)``, where ``processed_indices`` holds
    the ``(line_idx, position_idx)`` pairs present in ``results``. Both are empty
    when the file is absent or unusable. A half-parsed file therefore never
    leaves ``results`` populated without matching indices, which would cause
    duplicate entries.
    """
    if not os.path.exists(interim_file):
        return [], set()
    try:
        with open(interim_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
        indices = {(result['line_idx'], result['position_idx']) for result in loaded}
    except Exception as e:
        print(f"Error loading interim file: {str(e)}")
        return [], set()
    print(f"Resuming from {len(loaded)} previously processed tokens")
    return loaded, indices


def _save_results(path: str, results: List[Dict]) -> None:
    """Write *results* to *path* as indented UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


def _cluster_words_and_contexts(entries: List[Dict], input_sentences: List[str]) -> tuple:
    """Return ``(cluster_words, context_sentences)`` for one cluster's entries.

    ``cluster_words`` is the sorted list of distinct tokens. Sorting instead of
    ``list(set(...))`` keeps the prompt identical across runs, because the
    iteration order of a set of strings changes with hash randomisation.
    ``context_sentences`` holds, per entry, the stripped line of
    *input_sentences* at its ``line_num``. Out-of-range (including negative)
    line numbers are skipped.
    """
    cluster_words = sorted({entry['token'] for entry in entries})
    context_sentences = [
        input_sentences[entry['line_num']].strip()
        for entry in entries
        if 0 <= entry['line_num'] < len(input_sentences)
    ]
    return cluster_words, context_sentences


def process_tokens(task: str, layer: int, api_key: str):
    """Generate Gemini explanations for the first 15 predicted tokens.

    Loads predictions, clusters and sentences for *task* / *layer*. For each
    token it asks Gemini how the token relates to the words of its predicted
    cluster, using a dedicated prompt for ``[CLS]`` tokens. Results are
    written as JSON in ``src/codebert/<task>/layer<layer>/``:

    * ``token_explanations_layer_<layer>_test15.json`` is rewritten after every
      token and is read back to resume an interrupted run;
    * ``token_explanations_layer_<layer>_first15.json`` holds the final results.

    Tokens are throttled to ``batch_size`` (15) per minute. A token is skipped
    if it was already processed, if its index cells are blank, if its line
    index falls outside ``dev.in``, or if its predicted cluster is unknown.
    """
    configure_genai(api_key)

    predictions_df = load_predictions(task, layer)
    clusters = load_clusters(task, layer)
    dev_sentences = load_sentences(task, layer, "dev.in")
    input_sentences = load_sentences(task, layer, "input.in")
    if predictions_df is None or clusters is None:
        print("Failed to load required data")
        return

    predictions_df = predictions_df.head(15)
    print(f"Limited processing to first {len(predictions_df)} tokens")

    batch_size = 15  # API limit of 15 calls per minute
    call_count = 0
    start_time = time.time()

    output_dir = os.path.join("src", "codebert", task, f"layer{layer}")
    os.makedirs(output_dir, exist_ok=True)
    interim_file = os.path.join(output_dir, f"token_explanations_layer_{layer}_test15.json")
    results, processed_indices = _load_interim_results(interim_file)

    for _, row in tqdm(predictions_df.iterrows(), total=len(predictions_df), desc="Processing tokens"):
        token = row['Token']
        predicted_cluster = row['predicted_cluster']
        # A blank index cell is read as NaN; such a row cannot be located.
        if pd.isna(row['line_idx']) or pd.isna(row['position_idx']):
            continue
        # Built-in ints match the keys restored from JSON and are valid list
        # indices even if pandas parsed the column as float.
        line_idx = int(row['line_idx'])
        position_idx = int(row['position_idx'])

        if (line_idx, position_idx) in processed_indices:
            continue
        # A negative index would silently pick a line from the end of the file.
        if not 0 <= line_idx < len(dev_sentences):
            continue
        original_sentence = dev_sentences[line_idx].strip()

        if predicted_cluster not in clusters:
            continue
        cluster_words, context_sentences = _cluster_words_and_contexts(
            clusters[predicted_cluster], input_sentences
        )

        # Rate limiting: after batch_size tokens, wait out the rest of the minute.
        call_count += 1
        if call_count >= batch_size:
            elapsed = time.time() - start_time
            if elapsed < 60:
                wait_time = 60 - elapsed
                print(f"\nReached API limit of {batch_size} calls. Waiting for {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            call_count = 0
            start_time = time.time()

        cls_token = is_cls_token(token)
        try:
            if cls_token:
                explanation = get_gemini_explanation_for_cls_with_retry(original_sentence, cluster_words, context_sentences)
            else:
                explanation = get_gemini_explanation_with_retry(original_sentence, token, cluster_words)
            results.append({
                'token': token,
                'is_cls_token': cls_token,
                'line_idx': line_idx,
                'position_idx': position_idx,
                'predicted_cluster': predicted_cluster,
                'original_sentence': original_sentence,
                'cluster_words': cluster_words,
                'context_sentences': context_sentences,
                'explanation': explanation,
            })
            processed_indices.add((line_idx, position_idx))
            # Save after every token so a crash loses at most one explanation.
            _save_results(interim_file, results)
            print(f"\nSaved results to: {interim_file}")
        except Exception as e:
            print(f"\nUnexpected error processing token {token}: {str(e)}")
            _save_results(interim_file, results)
            print(f"Emergency save to: {interim_file}")
            print("Waiting 60 seconds before continuing...")
            time.sleep(60)
            call_count = 0
            start_time = time.time()

    output_file = os.path.join(output_dir, f"token_explanations_layer_{layer}_first15.json")
    _save_results(output_file, results)
    print(f"Results saved to: {output_file}")
def main():
    """Run the explanation pipeline for the configured task and layer.

    The Gemini API key is read from the ``GEMINI_API_KEY`` environment variable
    instead of being hard-coded in source. If the variable is unset, a message
    is printed and nothing is processed.
    """
    # SECURITY: a real API key used to be hard-coded here. Any key that has
    # been committed or published must be treated as leaked; revoke/rotate it.
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        print("GEMINI_API_KEY is not set; export it before running this script.")
        return
    TASK = "language_classification"  # Replace with your task name
    LAYER = 11  # Replace with your layer number
    process_tokens(TASK, LAYER, api_key)


if __name__ == "__main__":
    main()