| |
| """ |
| Generate programming problems from function_dataset_v2.csv using Gemini API. |
| Filters by relevance score and controls API cost. |
| """ |
|
|
| import csv |
| import json |
| import os |
| import sys |
| import vertexai |
| from vertexai.generative_models import GenerativeModel |
| from datetime import datetime |
| from typing import Dict, Optional, Tuple, List |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| import threading |
|
|
|
|
| |
# Google Cloud project that Vertex AI is initialised with.
PROJECT_ID = "tangou"
# Name of the Gemini model used to generate the problems.
MODEL_NAME = "gemini-2.5-flash-lite"
# Default threshold: rows scoring below this relevance are skipped.
MIN_RELEVANCE_SCORE = 1
# Default spending cap for one run, in US dollars.
MAX_BUDGET_USD = 50.0

# Prices in USD per one million tokens, used to estimate the cost of each
# request from its reported token counts.
# NOTE(review): presumably the list prices for MODEL_NAME -- confirm against
# the current pricing page before relying on the cost figures.
INPUT_PRICE_PER_MILLION = 0.1
OUTPUT_PRICE_PER_MILLION = 0.4
|
|
| |
| |
| |
|
|
# Prompt sent to the model for every selected function. The single ``{code}``
# placeholder is filled in with ``str.format(code=...)``; the text must not
# contain any other curly braces or formatting would fail.
PROMPT_TEMPLATE = """You are an expert in scientific computing and computational chemistry/biology/physics. Please create a high-quality programming problem inspired by the following code snippet from a real scientific computing project.

The problem should focus on scientific computing concepts such as:
- Numerical algorithms and simulations
- Data analysis and visualization
- Mathematical modeling
- Scientific data processing
- Computational methods in chemistry, biology, or physics

Code snippet for inspiration:
```python
{code}
```

Present your output in two distinct sections:

[Problem Description]
Create a **completely self-contained** problem description that:
- Does NOT directly reference the code snippet above
- Provides all necessary context and background
- Clearly states what needs to be implemented
- Specifies input/output format and constraints
- Is inspired by the scientific computing concepts in the code but creates a NEW, interesting problem
- Assumes common programming knowledge but explains any domain-specific concepts

[Solution]
Provide a comprehensive, **correct** Python solution that:
- Accurately solves the problem described
- Includes clear comments explaining the approach
- Uses appropriate scientific computing libraries (numpy, scipy, etc.) when relevant
- Is complete and runnable
- Follows best practices for scientific computing

Remember: The problem should be INSPIRED by the code, not a direct copy. Create something educational and interesting for scientific computing practitioners."""
|
|
|
|
class GeminiAPIClient:
    """Client for Gemini API with cost tracking.

    Wraps a Vertex AI ``GenerativeModel`` and accumulates, across all calls,
    the number of requests, the input/output token counts reported by the
    API, and a cost estimate derived from the module-level per-million-token
    prices. The running totals are guarded by a lock so one instance can be
    shared between worker threads.
    """

    def __init__(self, project_id: str, model_name: str):
        """Initialize Gemini API client.

        Args:
            project_id: Google Cloud project ID
            model_name: Name of the Gemini model to use
        """
        vertexai.init(project=project_id)
        self.model = GenerativeModel(model_name)
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_requests = 0
        self.total_cost = 0.0
        # Protects the four running totals above.
        self._lock = threading.Lock()

    def generate_content(self, prompt: str) -> Tuple[str, Dict]:
        """Generate content using Gemini API and track usage.

        Args:
            prompt: The prompt to send to the API

        Returns:
            Tuple of (response_text, usage_info).
            usage_info contains: input_tokens, output_tokens, total_tokens,
            input_cost, output_cost, request_cost

        Raises:
            Exception: Any error raised while calling the model or reading
                the response is printed and re-raised unchanged.
        """
        try:
            response = self.model.generate_content(prompt)
            usage_metadata = response.usage_metadata

            # Treat a missing count as zero so the arithmetic below cannot
            # fail on ``None``.
            input_tokens = usage_metadata.prompt_token_count or 0
            output_tokens = usage_metadata.candidates_token_count or 0

            # Cost estimate from the per-million-token prices.
            input_cost = (input_tokens / 1_000_000) * INPUT_PRICE_PER_MILLION
            output_cost = (output_tokens / 1_000_000) * OUTPUT_PRICE_PER_MILLION
            request_cost = input_cost + output_cost

            # Update the shared totals atomically.
            with self._lock:
                self.total_input_tokens += input_tokens
                self.total_output_tokens += output_tokens
                self.total_requests += 1
                self.total_cost += request_cost

            usage_info = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': input_tokens + output_tokens,
                'input_cost': input_cost,
                'output_cost': output_cost,
                'request_cost': request_cost
            }

            # ``response.text`` is read only after the totals were updated, so
            # a request stays counted even if reading its text raises.
            return response.text, usage_info

        except Exception as e:
            print(f"Error generating content: {e}")
            raise

    def get_total_usage(self) -> Dict:
        """Get total usage statistics.

        The counters are read under the same lock that protects their
        updates, so the returned values always belong to one consistent
        state even while other threads are recording requests.

        Returns:
            Dictionary with total_requests, total_input_tokens,
            total_output_tokens, total_tokens and total_cost
        """
        with self._lock:
            requests = self.total_requests
            input_tokens = self.total_input_tokens
            output_tokens = self.total_output_tokens
            cost = self.total_cost

        return {
            'total_requests': requests,
            'total_input_tokens': input_tokens,
            'total_output_tokens': output_tokens,
            'total_tokens': input_tokens + output_tokens,
            'total_cost': cost
        }

    def print_usage_summary(self, max_budget: Optional[float] = None):
        """Print a summary of API usage and costs.

        Args:
            max_budget: Budget in USD that "Budget Remaining" is measured
                against. Defaults to the module-level ``MAX_BUDGET_USD``.
        """
        budget = MAX_BUDGET_USD if max_budget is None else max_budget
        usage = self.get_total_usage()
        print("\n" + "="*70)
        print("API USAGE SUMMARY")
        print("="*70)
        print(f"Total Requests: {usage['total_requests']}")
        print(f"Total Input Tokens: {usage['total_input_tokens']:,}")
        print(f"Total Output Tokens: {usage['total_output_tokens']:,}")
        print(f"Total Tokens: {usage['total_tokens']:,}")
        print(f"\nTotal Cost: ${usage['total_cost']:.6f}")
        print(f"Budget Remaining: ${budget - usage['total_cost']:.6f}")
        print("="*70)
|
|
|
|
def _load_processed_rows(output_file: str) -> set:
    """Collect the ``row_number`` values already stored in an output JSONL file.

    Reading is best-effort: lines that are not valid JSON are ignored, and a
    failure to read the file only produces a warning.

    Args:
        output_file: Path to the JSONL file written by a previous run

    Returns:
        Set of row numbers found in the file (empty if the file is missing)
    """
    processed_rows = set()
    if not os.path.exists(output_file):
        print("No existing output file found. Will create new file.")
        return processed_rows

    print("Checking existing output file for already processed rows...")
    try:
        with open(output_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                # Only JSON objects can carry a row number; skip anything else.
                if isinstance(data, dict) and 'row_number' in data:
                    processed_rows.add(data['row_number'])
        print(f"Found {len(processed_rows)} already processed rows. These will be skipped.")
    except Exception as e:
        print(f"Warning: Could not read existing output file: {e}")
    return processed_rows


def process_function_dataset(
    input_file: str,
    output_file: str,
    min_score: int = MIN_RELEVANCE_SCORE,
    max_budget: float = MAX_BUDGET_USD,
    max_samples: Optional[int] = None,
    start_from: int = 0,
    max_workers: int = 5
) -> int:
    """Process function dataset and generate programming problems.

    Rows of the input CSV are filtered (resume position, already processed
    rows, relevance score, minimum code length), turned into prompts and sent
    to the model by a pool of worker threads. Every successful response is
    appended to the output file as one JSON line and flushed immediately.

    The output file is always opened in append mode: rows recorded in it are
    skipped on later runs, so truncating the file would lose them.

    Once the accumulated cost reaches ``max_budget`` no new request is
    started, but responses of requests already in flight are still written.
    Because workers check the budget independently, the final cost can
    exceed the limit by up to roughly ``max_workers`` requests.

    Args:
        input_file: Path to function_dataset_v2.csv
        output_file: Path to output JSONL file
        min_score: Minimum relevance score to process
        max_budget: Maximum budget in USD
        max_samples: Maximum number of samples to process (None for all)
        start_from: Skip first N rows (for resuming)
        max_workers: Maximum number of concurrent workers (default: 5)

    Returns:
        Number of problems successfully generated and written in this run
    """
    print("Starting programming problem generation...")
    print(f"Input: {input_file}")
    print(f"Output: {output_file}")
    print(f"Min Relevance Score: {min_score}")
    print(f"Max Budget: ${max_budget:.2f}")
    print(f"Max Workers: {max_workers}")
    if max_samples:
        print(f"Max Samples: {max_samples}")
    print(f"Starting from row: {start_from}")
    print()

    # Rows already present in the output file are never sent again.
    processed_rows = _load_processed_rows(output_file)
    print()

    client = GeminiAPIClient(PROJECT_ID, MODEL_NAME)

    total_rows = 0
    processed = 0
    skipped_low_score = 0
    skipped_no_code = 0
    skipped_already_processed = 0
    skipped_budget = 0
    errors = 0

    # Build the full task list first; row numbers are 1-based CSV data rows.
    tasks = []

    with open(input_file, 'r', encoding='utf-8') as infile:
        for row_number, row in enumerate(csv.DictReader(infile), start=1):
            total_rows = row_number

            if row_number <= start_from:
                continue

            if row_number in processed_rows:
                skipped_already_processed += 1
                continue

            if max_samples and len(tasks) >= max_samples:
                break

            try:
                relevance_score = int(row.get('relevance_score', 0))
            except (ValueError, TypeError):
                relevance_score = 0

            if relevance_score < min_score:
                skipped_low_score += 1
                continue

            # DictReader fills missing trailing columns with None, so the
            # value must be normalised before calling str methods on it.
            function_content = (row.get('function_content') or '').strip()
            if len(function_content) < 50:
                skipped_no_code += 1
                continue

            metadata = {
                'original_index': row.get('original_index'),
                'function_name': row.get('function_name'),
                'repo_name': row.get('repo_name'),
                'path': row.get('path'),
                'language': row.get('language'),
                'relevance_score': relevance_score,
                'function_start_line': row.get('function_start_line'),
                'function_end_line': row.get('function_end_line'),
            }

            tasks.append({
                'row_number': row_number,
                'metadata': metadata,
                'prompt': PROMPT_TEMPLATE.format(code=function_content),
            })

    print(f"Total rows read: {total_rows}")
    print(f"Tasks to process: {len(tasks)}")
    print(f"Skipped (already processed): {skipped_already_processed}")
    print(f"Skipped (low score): {skipped_low_score}")
    print(f"Skipped (no/short code): {skipped_no_code}")
    print(f"\nStarting concurrent processing with {max_workers} workers...\n")

    def process_task(task):
        """Send one prompt to the model and package the outcome as a dict."""
        row_number = task['row_number']
        metadata = task['metadata']
        prompt = task['prompt']
        label = (
            f"Processing row {row_number} "
            f"(score={metadata['relevance_score']}, func={metadata['function_name']})..."
        )

        # Do not start a new paid request once the budget is used up.
        if client.total_cost >= max_budget:
            return {'success': False, 'skipped': True, 'row_number': row_number}

        try:
            response_text, usage_info = client.generate_content(prompt)
        except Exception as e:
            print(f"{label} ✗ Error: {e}")
            return {
                'success': False,
                'error': str(e),
                'row_number': row_number
            }

        # A single print call per task keeps lines from different worker
        # threads from being interleaved.
        print(f"{label} ✓ (${usage_info['request_cost']:.6f}, {usage_info['total_tokens']} tokens)")

        return {
            'success': True,
            'data': {
                'metadata': metadata,
                'prompt': prompt,
                'response': response_text,
                'usage': usage_info,
                'timestamp': datetime.now().isoformat(),
                'row_number': row_number
            }
        }

    budget_exhausted = False

    # Always append: earlier results must survive, because their rows are
    # skipped above and would otherwise be lost for good.
    with open(output_file, 'a', encoding='utf-8') as outfile:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_task, task) for task in tasks]

            try:
                for future in as_completed(futures):
                    if future.cancelled():
                        skipped_budget += 1
                        continue

                    result = future.result()

                    if result['success']:
                        # Persist first so a paid response is never discarded.
                        outfile.write(json.dumps(result['data'], ensure_ascii=False) + '\n')
                        outfile.flush()

                        processed += 1

                        if processed % 10 == 0:
                            print(f"\n--- Progress: {processed} problems generated, ${client.total_cost:.6f} spent ---\n")
                    elif result.get('skipped'):
                        skipped_budget += 1
                    else:
                        errors += 1

                    if not budget_exhausted and client.total_cost >= max_budget:
                        budget_exhausted = True
                        print(f"\n⚠️ Budget limit reached (${client.total_cost:.6f} >= ${max_budget:.2f})")
                        print("Cancelling remaining tasks...")
                        # Only queued futures can be cancelled. Keep draining
                        # the loop so in-flight results are still written.
                        for pending in futures:
                            pending.cancel()
            finally:
                # On an exception or Ctrl+C, stop queued requests from
                # running while the executor shuts down. This is a no-op
                # for futures that already finished.
                for pending in futures:
                    pending.cancel()

    print("\n" + "="*70)
    print("PROCESSING COMPLETE")
    print("="*70)
    print(f"Total rows read: {total_rows}")
    print(f"Successfully processed: {processed}")
    print(f"Skipped (already processed): {skipped_already_processed}")
    print(f"Skipped (low score): {skipped_low_score}")
    print(f"Skipped (no/short code): {skipped_no_code}")
    print(f"Skipped (budget limit): {skipped_budget}")
    print(f"Errors: {errors}")
    print(f"Budget used: ${client.total_cost:.6f} of ${max_budget:.2f}")

    client.print_usage_summary()

    print(f"\nResults saved to: {output_file}")

    return processed
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in this order.
    cli_options = [
        ('--input', dict(
            default='function_dataset_v2.csv',
            help='Input CSV file (default: function_dataset_v2.csv)',
        )),
        ('--output', dict(
            default='programming_problems.jsonl',
            help='Output JSONL file (default: programming_problems.jsonl)',
        )),
        ('--min-score', dict(
            type=int,
            default=MIN_RELEVANCE_SCORE,
            help=f'Minimum relevance score (default: {MIN_RELEVANCE_SCORE})',
        )),
        ('--max-budget', dict(
            type=float,
            default=MAX_BUDGET_USD,
            help=f'Maximum budget in USD (default: {MAX_BUDGET_USD})',
        )),
        ('--max-samples', dict(
            type=int,
            default=None,
            help='Maximum number of samples to process (default: no limit)',
        )),
        ('--start-from', dict(
            type=int,
            default=0,
            help='Start from row N (for resuming, default: 0)',
        )),
        ('--max-workers', dict(
            type=int,
            default=10,
            help='Maximum number of concurrent workers (default: 10)',
        )),
    ]

    parser = argparse.ArgumentParser(
        description='Generate programming problems from function dataset using Gemini API'
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Fail early with a readable message when the dataset is missing.
    if not os.path.exists(args.input):
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    try:
        process_function_dataset(
            input_file=args.input,
            output_file=args.output,
            min_score=args.min_score,
            max_budget=args.max_budget,
            max_samples=args.max_samples,
            start_from=args.start_from,
            max_workers=args.max_workers
        )
        print("\n✅ Success!")
    except KeyboardInterrupt:
        # Ctrl+C is reported as a clean stop (exit status 0).
        print("\n\n⚠️ Interrupted by user. Progress has been saved to output file.")
        print(" You can resume by using --start-from <row_number>")
        sys.exit(0)
    except Exception as e:
        # Any other failure: show the message and traceback, exit status 1.
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
|