|
""" |
|
Initial Solution Generator |
|
|
|
AZR κΈ°λ° TestTime RLVRμ μν μ΄κΈ° μ루μ
μμ±κΈ° |
|
κΈ°μ‘΄ Test-Time-RLVRμ generate_initial_solution ν¨μλ₯Ό ν΄λμ€ννμ¬ νμ₯ |
|
""" |
|
|
|
import re |
|
import torch |
|
from typing import Dict, Any, Optional, Tuple, List |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
from .config import TestTimeConfig |
|
from .logger import TestTimeLogger |
|
from .prompts import get_prompt, get_temperature, get_diversity_instruction |
|
|
|
|
|
from ..rewards.custom_evaluate import extract_code |
|
|
|
|
|
try: |
|
from vllm import LLM, SamplingParams |
|
VLLM_AVAILABLE = True |
|
except ImportError: |
|
VLLM_AVAILABLE = False |
|
|
|
|
|
class InitialSolutionGenerator: |
|
"""λ²€μΉλ§ν¬ λ¬Έμ μ λν μ΄κΈ° μ루μ
μμ±""" |
|
|
|
def __init__(self, model, tokenizer, config: TestTimeConfig, |
|
logger: Optional[TestTimeLogger] = None, use_vllm: bool = True): |
|
self.model = model |
|
self.tokenizer = tokenizer |
|
self.config = config |
|
self.logger = logger or TestTimeLogger() |
|
self.use_vllm = use_vllm and VLLM_AVAILABLE |
|
|
|
|
|
if use_vllm and not VLLM_AVAILABLE: |
|
self.logger.log_info("β οΈ VLLM requested but not available, falling back to HuggingFace") |
|
elif self.use_vllm: |
|
self.logger.log_info("π Using VLLM for optimized inference") |
|
else: |
|
self.logger.log_info("π§ Using HuggingFace Transformers for inference") |
|
|
|
def generate(self, problem: Dict[str, Any]) -> str: |
|
"""λ¬Έμ μ λν μ΄κΈ° μ루μ
μμ± (AZR μ½λ νκ° ν둬ννΈ μ¬μ©)""" |
|
|
|
problem_prompt = problem['prompt'] |
|
problem_id = problem.get('task_id', 'unknown') |
|
|
|
|
|
|
|
|
|
|
|
if 'HumanEval' in problem_id: |
|
|
|
entry_point = problem.get('entry_point', 'unknown') |
|
|
|
|
|
import re |
|
function_count = len(re.findall(r'^\s*def\s+\w+', problem_prompt, re.MULTILINE)) |
|
|
|
if function_count > 1: |
|
|
|
prompt = get_prompt("solution_humaneval_multi", |
|
problem_prompt=problem_prompt, |
|
entry_point=entry_point) |
|
else: |
|
|
|
prompt = get_prompt("solution_humaneval_basic", |
|
problem_prompt=problem_prompt) |
|
else: |
|
|
|
prompt = get_prompt("solution_mbpp_basic", |
|
problem_prompt=problem_prompt) |
|
|
|
self.logger.log_info(f"π Generating initial solution for {problem_id}") |
|
self.logger.log_info(f"π Full prompt: {prompt}") |
|
|
|
|
|
if self.use_vllm and isinstance(self.model, LLM): |
|
solution = self._generate_with_vllm(prompt) |
|
else: |
|
solution = self._generate_with_huggingface(prompt) |
|
|
|
|
|
extracted_solution = self._extract_python_code(solution) |
|
|
|
|
|
if extracted_solution and extracted_solution != solution: |
|
self.logger.log_info(f"π Extracted Python code from markdown block") |
|
solution = extracted_solution |
|
elif not extracted_solution: |
|
self.logger.log_info(f"π No markdown code block found, using original text") |
|
|
|
|
|
if 'HumanEval' in problem_id: |
|
solution = self._add_imports_from_prompt(solution, problem_prompt) |
|
|
|
|
|
solution = self._fix_function_definition(solution, prompt, problem_id) |
|
|
|
self.logger.log_info(f"β
Generated solution ({len(solution)} chars)") |
|
self.logger.log_info(f"π Solution preview: {solution[:200]}...") |
|
|
|
|
|
self.logger.log_info(f"π Full solution for debugging:") |
|
self.logger.log_info(f"--- START SOLUTION ---") |
|
self.logger.log_info(solution) |
|
self.logger.log_info(f"--- END SOLUTION ---") |
|
|
|
return solution |
|
|
|
def generate_diverse(self, problem: Dict[str, Any], temperature: float = 0.7, variation_id: int = 0) -> str: |
|
"""λ€μν μ루μ
μμ± (λμ temperature μ¬μ©)""" |
|
|
|
problem_prompt = problem['prompt'] |
|
problem_id = problem.get('task_id', 'unknown') |
|
|
|
|
|
diversity_instruction = get_diversity_instruction(variation_id) |
|
|
|
|
|
if 'HumanEval' in problem_id: |
|
entry_point = problem.get('entry_point', 'unknown') |
|
|
|
import re |
|
function_count = len(re.findall(r'^\s*def\s+\w+', problem_prompt, re.MULTILINE)) |
|
|
|
if function_count > 1: |
|
prompt = get_prompt("diverse_humaneval_multi", |
|
diversity_instruction=diversity_instruction, |
|
problem_prompt=problem_prompt, |
|
entry_point=entry_point) |
|
else: |
|
prompt = get_prompt("diverse_humaneval_basic", |
|
diversity_instruction=diversity_instruction, |
|
problem_prompt=problem_prompt) |
|
else: |
|
|
|
prompt = get_prompt("diverse_mbpp_basic", |
|
diversity_instruction=diversity_instruction, |
|
problem_prompt=problem_prompt) |
|
|
|
self.logger.log_info(f"π¨ Generating diverse solution #{variation_id+1} for {problem_id}") |
|
|
|
|
|
try: |
|
from vllm import LLM |
|
if isinstance(self.model, LLM): |
|
solution = self._generate_with_vllm_diverse(prompt, temperature) |
|
else: |
|
solution = self._generate_with_huggingface_diverse(prompt, temperature) |
|
except ImportError: |
|
solution = self._generate_with_huggingface_diverse(prompt, temperature) |
|
|
|
|
|
extracted_solution = self._extract_python_code(solution) |
|
if extracted_solution and extracted_solution != solution: |
|
self.logger.log_info(f"π Extracted Python code from markdown block") |
|
solution = extracted_solution |
|
|
|
if 'HumanEval' in problem_id: |
|
solution = self._add_imports_from_prompt(solution, problem_prompt) |
|
|
|
solution = self._fix_function_definition(solution, prompt, problem_id) |
|
|
|
self.logger.log_info(f"β
Generated diverse solution #{variation_id+1} ({len(solution)} chars)") |
|
|
|
return solution |
|
|
|
def _generate_with_vllm(self, prompt: str) -> str: |
|
"""VLLM λ°±μλλ‘ μμ± (AZR λ°©μ)""" |
|
|
|
|
|
sampling_params = SamplingParams( |
|
temperature=0.05, |
|
max_tokens=2048, |
|
top_p=1.0, |
|
stop=["\n```\n"], |
|
) |
|
|
|
|
|
outputs = self.model.generate([prompt], sampling_params, use_tqdm=False) |
|
solution = outputs[0].outputs[0].text.replace("\t", " ") |
|
|
|
return solution.strip() |
|
|
|
def _generate_with_vllm_diverse(self, prompt: str, temperature: float = 0.7) -> str: |
|
"""λ€μν μ루μ
μμ±μ© VLLM λ°±μλ (λμ temperature)""" |
|
|
|
|
|
sampling_params = SamplingParams( |
|
temperature=temperature, |
|
max_tokens=2048, |
|
top_p=0.95, |
|
stop=["\n```\n"], |
|
) |
|
|
|
|
|
outputs = self.model.generate([prompt], sampling_params, use_tqdm=False) |
|
solution = outputs[0].outputs[0].text.replace("\t", " ") |
|
|
|
return solution.strip() |
|
|
|
def generate_batch(self, prompts: List[str], temperature: float = 0.7) -> List[str]: |
|
"""λ°°μΉλ‘ μ¬λ¬ ν둬ννΈ λμ μ²λ¦¬""" |
|
|
|
|
|
if self.use_vllm and isinstance(self.model, LLM): |
|
raw_solutions = self._generate_batch_with_vllm(prompts, temperature) |
|
else: |
|
|
|
raw_solutions = [self._generate_with_huggingface(prompt) for prompt in prompts] |
|
|
|
|
|
processed_solutions = [] |
|
for i, (prompt, solution) in enumerate(zip(prompts, raw_solutions)): |
|
|
|
extracted = self._extract_python_code(solution) |
|
if extracted and extracted != solution: |
|
self.logger.log_info(f"π Extracted Python code from markdown block for batch item {i+1}") |
|
solution = extracted |
|
|
|
|
|
|
|
if 'HumanEval' in prompt: |
|
|
|
|
|
solution = self._add_imports_from_prompt(solution, prompt) |
|
|
|
|
|
|
|
solution = self._fix_function_definition(solution, prompt) |
|
|
|
processed_solutions.append(solution) |
|
|
|
return processed_solutions |
|
|
|
def _generate_batch_with_vllm(self, prompts: List[str], temperature: float = 0.7) -> List[str]: |
|
"""VLLMμΌλ‘ λ°°μΉ μ²λ¦¬""" |
|
|
|
|
|
|
|
sampling_params = SamplingParams( |
|
temperature=temperature, |
|
top_p=0.85, |
|
max_tokens=1024, |
|
stop=[] |
|
) |
|
|
|
|
|
outputs = self.model.generate(prompts, sampling_params, use_tqdm=False) |
|
|
|
|
|
solutions = [] |
|
for i, output in enumerate(outputs): |
|
solution = output.outputs[0].text.replace("\t", " ") |
|
|
|
finish_reason = output.outputs[0].finish_reason |
|
if finish_reason != "stop" and i < 3: |
|
self.logger.log_warning(f"Output {i} finish_reason: {finish_reason}, length: {len(solution)}") |
|
solutions.append(solution.strip()) |
|
|
|
return solutions |
|
|
|
def _generate_with_huggingface(self, prompt: str) -> str: |
|
"""HuggingFace λ°±μλλ‘ μμ± (attention mask μμ )""" |
|
|
|
|
|
inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096) |
|
|
|
|
|
if 'attention_mask' not in inputs: |
|
inputs['attention_mask'] = torch.ones_like(inputs['input_ids']) |
|
|
|
|
|
device = getattr(self.model, 'device', 'cuda' if torch.cuda.is_available() else 'cpu') |
|
if isinstance(device, str): |
|
inputs = {k: v.to(device) for k, v in inputs.items()} |
|
else: |
|
|
|
inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()} |
|
|
|
with torch.no_grad(): |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
|
|
outputs = self.model.generate( |
|
inputs['input_ids'], |
|
attention_mask=inputs['attention_mask'], |
|
max_new_tokens=2048, |
|
do_sample=False, |
|
pad_token_id=self.tokenizer.eos_token_id |
|
) |
|
|
|
|
|
solution = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
solution = solution[len(prompt):].strip() |
|
|
|
return solution |
|
|
|
def _generate_with_huggingface_diverse(self, prompt: str, temperature: float = 0.7) -> str: |
|
"""λ€μν μ루μ
μμ±μ© HuggingFace λ°±μλ (λμ temperature)""" |
|
|
|
|
|
inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096) |
|
|
|
|
|
if 'attention_mask' not in inputs: |
|
inputs['attention_mask'] = torch.ones_like(inputs['input_ids']) |
|
|
|
|
|
device = getattr(self.model, 'device', 'cuda' if torch.cuda.is_available() else 'cpu') |
|
if isinstance(device, str): |
|
inputs = {k: v.to(device) for k, v in inputs.items()} |
|
else: |
|
|
|
inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()} |
|
|
|
with torch.no_grad(): |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
|
|
outputs = self.model.generate( |
|
inputs['input_ids'], |
|
attention_mask=inputs['attention_mask'], |
|
max_new_tokens=2048, |
|
do_sample=True, |
|
temperature=temperature, |
|
top_p=0.95, |
|
pad_token_id=self.tokenizer.eos_token_id, |
|
eos_token_id=self.tokenizer.eos_token_id |
|
) |
|
|
|
|
|
solution = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
solution = solution[len(prompt):].strip() |
|
|
|
return solution |
|
|
|
def _extract_python_code(self, solution: str) -> str: |
|
"""κ°μ λ Python μ½λ μΆμΆ (AZR λ°©μ + μΆκ° ν¨ν΄)""" |
|
|
|
|
|
try: |
|
extracted = extract_code(solution, language="python") |
|
if extracted: |
|
return extracted |
|
except: |
|
pass |
|
|
|
|
|
patterns = [ |
|
r'```python\n(.*?)```', |
|
r'```\n(.*?)```', |
|
r'```py\n(.*?)```', |
|
r'```Python\n(.*?)```', |
|
r'Here is.*?:\n\n```python\n(.*?)```', |
|
r'Here is.*?:\n\n```\n(.*?)```', |
|
] |
|
|
|
for pattern in patterns: |
|
matches = re.findall(pattern, solution, re.DOTALL | re.IGNORECASE) |
|
if matches: |
|
return matches[-1].strip() |
|
|
|
|
|
lines = solution.split('\n') |
|
code_lines = [] |
|
in_function = False |
|
|
|
for line in lines: |
|
if line.strip().startswith('def '): |
|
in_function = True |
|
code_lines.append(line) |
|
elif in_function and (line.startswith(' ') or line.strip() == ''): |
|
code_lines.append(line) |
|
elif in_function and line.strip() and not line.startswith(' '): |
|
|
|
break |
|
|
|
if code_lines: |
|
return '\n'.join(code_lines) |
|
|
|
|
|
return solution |
|
|
|
def _add_imports_from_prompt(self, solution: str, prompt: str) -> str: |
|
"""HumanEval ν둬ννΈμμ import λ¬Έμ μΆμΆνμ¬ μ루μ
μ μΆκ° (EvalPlus λ°©μ)""" |
|
|
|
|
|
if 'from typing import' in solution or 'import typing' in solution: |
|
return solution |
|
|
|
|
|
import_lines = [] |
|
prompt_lines = prompt.split('\n') |
|
|
|
for line in prompt_lines: |
|
stripped = line.strip() |
|
|
|
if (stripped.startswith('from ') and 'import' in stripped) or stripped.startswith('import '): |
|
import_lines.append(line) |
|
|
|
elif stripped.startswith('def '): |
|
break |
|
|
|
|
|
if not import_lines: |
|
return solution |
|
|
|
|
|
self.logger.log_info(f"π§ Adding imports from prompt: {import_lines}") |
|
|
|
|
|
solution_lines = solution.split('\n') |
|
first_non_empty_line = None |
|
for i, line in enumerate(solution_lines): |
|
if line.strip(): |
|
first_non_empty_line = i |
|
break |
|
|
|
|
|
if first_non_empty_line is not None: |
|
|
|
imports_text = '\n'.join(import_lines) + '\n\n' |
|
|
|
|
|
if solution_lines[first_non_empty_line].strip().startswith(('import ', 'from ')): |
|
|
|
last_import_idx = first_non_empty_line |
|
for i in range(first_non_empty_line, len(solution_lines)): |
|
if solution_lines[i].strip() and not solution_lines[i].strip().startswith(('import ', 'from ')): |
|
break |
|
if solution_lines[i].strip().startswith(('import ', 'from ')): |
|
last_import_idx = i |
|
|
|
|
|
solution_lines.insert(last_import_idx + 1, '') |
|
solution_lines.insert(last_import_idx + 1, '\n'.join(import_lines)) |
|
return '\n'.join(solution_lines) |
|
else: |
|
|
|
return imports_text + solution |
|
|
|
return imports_text + solution |
|
|
|
def _fix_function_definition(self, solution: str, prompt: str, problem_id: str = "") -> str: |
|
"""ν¨μ μ μκ° λλ½λ κ²½μ° λ³΅κ΅¬ + lpw μ€νμΌ μ€λ³΅ μ²λ¦¬""" |
|
|
|
|
|
func_def_match = re.search(r'def\s+(\w+)\([^)]*\)(?:\s*->\s*[^:]+)?:', prompt) |
|
if not func_def_match: |
|
return solution |
|
|
|
entry_point = func_def_match.group(1) |
|
func_def_line = func_def_match.group(0) |
|
|
|
|
|
if 'HumanEval' in problem_id: |
|
|
|
return solution |
|
|
|
|
|
|
|
if (prompt in solution) or (f'def {entry_point}(' in solution): |
|
|
|
self.logger.log_info(f"β
Function definition already present for {entry_point}") |
|
return solution |
|
|
|
|
|
if solution and not solution.startswith('def '): |
|
|
|
lines = solution.split('\n') |
|
fixed_lines = [func_def_line] |
|
|
|
for line in lines: |
|
if line.strip(): |
|
|
|
if line.strip().startswith('if __name__'): |
|
|
|
fixed_lines.append('') |
|
fixed_lines.append(line.strip()) |
|
else: |
|
|
|
if not line.startswith(' ') and line.strip(): |
|
line = ' ' + line.lstrip() |
|
fixed_lines.append(line) |
|
else: |
|
fixed_lines.append(line) |
|
|
|
solution = '\n'.join(fixed_lines) |
|
self.logger.log_info(f"π§ Fixed function definition for {entry_point}") |
|
|
|
return solution |
|
|
|
def generate_fallback_solution(self, problem: Dict[str, Any]) -> str: |
|
"""λ¬Έμ μμ± μ€ν¨ μ λ체 μ루μ
μμ±""" |
|
|
|
entry_point = problem.get('entry_point', 'solution') |
|
problem_description = problem.get('prompt', '') |
|
|
|
|
|
if 'similar_elements' in problem_description: |
|
|
|
solution = f"""def {entry_point}(test_tup1, test_tup2): |
|
return tuple(set(test_tup1) & set(test_tup2))""" |
|
elif 'kth_element' in problem_description: |
|
|
|
solution = f"""def {entry_point}(arr, k): |
|
return sorted(arr)[k-1]""" |
|
else: |
|
|
|
solution = f"""def {entry_point}(*args): |
|
# TODO: Implement this function |
|
return None""" |
|
|
|
self.logger.log_info(f"π Generated fallback solution for {entry_point}") |
|
return solution |
|
|
|
def validate_syntax(self, solution: str) -> Tuple[bool, Optional[str]]: |
|
"""μ루μ
ꡬ문 κ²μ¦""" |
|
try: |
|
compile(solution, '<string>', 'exec') |
|
return True, None |
|
except SyntaxError as e: |
|
return False, str(e) |
|
except Exception as e: |
|
return False, str(e) |
|
|
|
def extract_function_signature(self, prompt: str) -> Optional[Dict[str, str]]: |
|
"""ν둬ννΈμμ ν¨μ μκ·Έλμ² μΆμΆ""" |
|
|
|
|
|
pattern = r'def\s+(\w+)\(([^)]*)\)(?:\s*->\s*([^:]+))?:' |
|
match = re.search(pattern, prompt) |
|
|
|
if match: |
|
func_name = match.group(1) |
|
args = match.group(2) |
|
return_type = match.group(3) |
|
|
|
return { |
|
'name': func_name, |
|
'args': args.strip(), |
|
'return_type': return_type.strip() if return_type else None, |
|
'full_signature': match.group(0) |
|
} |
|
|
|
return None |
|
|
|
def format_solution(self, raw_solution: str, problem: Dict[str, Any]) -> str: |
|
"""μ루μ
νμ μ 리""" |
|
|
|
|
|
solution = raw_solution.strip() |
|
|
|
|
|
if not solution.startswith('def '): |
|
signature = self.extract_function_signature(problem.get('prompt', '')) |
|
if signature: |
|
|
|
lines = solution.split('\n') |
|
indented_lines = [' ' + line if line.strip() else line for line in lines] |
|
solution = signature['full_signature'] + '\n' + '\n'.join(indented_lines) |
|
|
|
|
|
lines = solution.split('\n') |
|
code_lines = [] |
|
in_function = False |
|
|
|
for line in lines: |
|
if line.strip().startswith('def '): |
|
in_function = True |
|
code_lines.append(line) |
|
elif in_function: |
|
code_lines.append(line) |
|
elif line.strip() and not any(keyword in line.lower() for keyword in |
|
['explanation', 'here', 'this function', 'the solution']): |
|
code_lines.append(line) |
|
|
|
return '\n'.join(code_lines).strip() |
|
|
|
@staticmethod |
|
def extract_docstring_from_function(code: str) -> str: |
|
"""ν¨μ μ½λμμ docstringμ μΆμΆ""" |
|
import re |
|
|
|
|
|
|
|
docstring_patterns = [ |
|
r'def\s+\w+\([^)]*\):\s*\n\s*"""(.*?)"""', |
|
r'def\s+\w+\([^)]*\):\s*\n\s*\'\'\'(.*?)\'\'\'', |
|
] |
|
|
|
for pattern in docstring_patterns: |
|
match = re.search(pattern, code, re.DOTALL) |
|
if match: |
|
docstring = match.group(1).strip() |
|
|
|
lines = docstring.split('\n') |
|
cleaned_lines = [] |
|
for line in lines: |
|
cleaned_line = line.strip() |
|
if cleaned_line: |
|
cleaned_lines.append(cleaned_line) |
|
|
|
return ' '.join(cleaned_lines) |
|
|
|
|
|
return "Find the function that produces these outputs from these inputs." |
|
|
|
def _extract_function_code(self, code: str) -> str: |
|
"""μ½λμμ ν¨μ μ μμ νμν import μΆμΆ""" |
|
import re |
|
|
|
lines = code.strip().split('\n') |
|
import_lines = [] |
|
func_lines = [] |
|
in_function = False |
|
indent_level = 0 |
|
|
|
|
|
for line in lines: |
|
stripped = line.strip() |
|
if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'): |
|
import_lines.append(line) |
|
|
|
|
|
for line in lines: |
|
if line.strip().startswith('def '): |
|
in_function = True |
|
func_lines = [line] |
|
|
|
indent_level = len(line) - len(line.lstrip()) |
|
elif in_function: |
|
|
|
if not line.strip() or (line.strip() and len(line) - len(line.lstrip()) > indent_level): |
|
func_lines.append(line) |
|
else: |
|
|
|
break |
|
|
|
|
|
if func_lines: |
|
result_lines = import_lines + [''] + func_lines if import_lines else func_lines |
|
return '\n'.join(result_lines) |
|
else: |
|
return code |
|
|
|
def evaluate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]: |
|
"""LLM μ루μ
μ λ²€μΉλ§ν¬ ν
μ€νΈλ‘ νκ° (EvalPlus νμ)""" |
|
try: |
|
|
|
self.logger.log_info("π Attempting to import EvalPlus...") |
|
from evalplus.evaluate import check_correctness |
|
from evalplus.gen.util import trusted_exec |
|
from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS |
|
from evalplus.eval import PASS |
|
self.logger.log_info("β
Using EvalPlus for evaluation") |
|
except ImportError as e: |
|
|
|
self.logger.log_error(f"β EvalPlus is required but not available: {e}") |
|
import traceback |
|
self.logger.log_error(f"π Import traceback: {traceback.format_exc()}") |
|
return { |
|
'correct': False, |
|
'passed_tests': 0, |
|
'total_tests': 0, |
|
'error': f"EvalPlus import failed: {e}. Please install EvalPlus properly.", |
|
'execution_results': [], |
|
'base_passed': 0, |
|
'plus_passed': 0, |
|
'base_total': 0, |
|
'plus_total': 0 |
|
} |
|
except Exception as e: |
|
self.logger.log_error(f"β EvalPlus import failed with unexpected error: {e}") |
|
return { |
|
'correct': False, |
|
'passed_tests': 0, |
|
'total_tests': 0, |
|
'error': f"EvalPlus import error: {e}", |
|
'execution_results': [], |
|
'base_passed': 0, |
|
'plus_passed': 0, |
|
'base_total': 0, |
|
'plus_total': 0 |
|
} |
|
|
|
result = { |
|
'correct': False, |
|
'passed_tests': 0, |
|
'total_tests': 0, |
|
'error': None, |
|
'execution_results': [], |
|
'base_passed': 0, |
|
'plus_passed': 0, |
|
'base_total': 0, |
|
'plus_total': 0 |
|
} |
|
|
|
try: |
|
|
|
extracted_code = self._extract_function_code(solution) |
|
if not extracted_code: |
|
result['error'] = "No function definition found" |
|
return result |
|
|
|
|
|
task_id = problem.get('task_id', '') |
|
if task_id.startswith('Mbpp'): |
|
dataset = 'mbpp' |
|
elif task_id.startswith('HumanEval'): |
|
dataset = 'humaneval' |
|
else: |
|
|
|
dataset = 'mbpp' |
|
|
|
|
|
entry_point = problem.get('entry_point', '') |
|
canonical_solution = problem.get('canonical_solution', '') |
|
|
|
if not canonical_solution: |
|
result['error'] = "No canonical_solution found" |
|
return result |
|
|
|
|
|
expected_output = {} |
|
|
|
|
|
base_inputs = problem.get('base_input', []) |
|
if base_inputs: |
|
expected_output['base'], expected_output['base_time'] = trusted_exec( |
|
problem.get('prompt', '') + canonical_solution, |
|
base_inputs, |
|
entry_point, |
|
record_time=True, |
|
output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS |
|
) |
|
|
|
|
|
plus_inputs = problem.get('plus_input', []) |
|
if plus_inputs: |
|
expected_output['plus'], expected_output['plus_time'] = trusted_exec( |
|
problem.get('prompt', '') + canonical_solution, |
|
plus_inputs, |
|
entry_point, |
|
record_time=True, |
|
output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS |
|
) |
|
|
|
|
|
evalplus_result = check_correctness( |
|
dataset=dataset, |
|
completion_id=0, |
|
problem=problem, |
|
solution=extracted_code, |
|
expected_output=expected_output, |
|
base_only=False, |
|
fast_check=False, |
|
identifier=task_id |
|
) |
|
|
|
|
|
if 'base' in evalplus_result: |
|
base_stat, base_details = evalplus_result['base'] |
|
result['base_total'] = len(base_inputs) |
|
if base_stat == PASS: |
|
result['base_passed'] = result['base_total'] |
|
else: |
|
result['base_passed'] = sum(1 for d in base_details if d) if base_details else 0 |
|
|
|
result['passed_tests'] += result['base_passed'] |
|
result['total_tests'] += result['base_total'] |
|
|
|
if 'plus' in evalplus_result: |
|
plus_stat, plus_details = evalplus_result['plus'] |
|
result['plus_total'] = len(plus_inputs) |
|
if plus_stat == PASS: |
|
result['plus_passed'] = result['plus_total'] |
|
else: |
|
result['plus_passed'] = sum(1 for d in plus_details if d) if plus_details else 0 |
|
|
|
result['passed_tests'] += result['plus_passed'] |
|
result['total_tests'] += result['plus_total'] |
|
|
|
|
|
result['correct'] = (result['passed_tests'] == result['total_tests']) and result['total_tests'] > 0 |
|
|
|
|
|
if not result['correct']: |
|
if base_stat != PASS: |
|
result['error'] = f"Base tests failed: {base_stat}" |
|
elif 'plus' in evalplus_result and plus_stat != PASS: |
|
result['error'] = f"Plus tests failed: {plus_stat}" |
|
|
|
|
|
self.logger.log_info(f"EvalPlus evaluation for {task_id}:") |
|
self.logger.log_info(f" Base: {result['base_passed']}/{result['base_total']}") |
|
self.logger.log_info(f" Plus: {result['plus_passed']}/{result['plus_total']}") |
|
self.logger.log_info(f" Total: {result['passed_tests']}/{result['total_tests']}") |
|
self.logger.log_info(f" Correct: {result['correct']}") |
|
|
|
except Exception as e: |
|
result['error'] = f"Evaluation failed: {str(e)}" |
|
import traceback |
|
self.logger.log_info(f"Evaluation traceback: {traceback.format_exc()}") |
|
|
|
return result |
|
|
|
|
|
@staticmethod |
|
def load_model_with_optimizations(model_name: str, device: str, |
|
config: TestTimeConfig, use_vllm: bool = True, tensor_parallel_size: int = 1) -> Tuple[Any, Any]: |
|
"""λͺ¨λΈκ³Ό ν ν¬λμ΄μ λ‘λ (AZR μ€νμΌ μ΅μ ν, VLLM μ§μ)""" |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) |
|
if tokenizer.pad_token is None: |
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
if use_vllm and VLLM_AVAILABLE and device.startswith('cuda'): |
|
try: |
|
|
|
import os |
|
if 'CUDA_VISIBLE_DEVICES' not in os.environ: |
|
gpu_id = device.split(':')[1] if ':' in device else '0' |
|
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id |
|
else: |
|
|
|
gpu_id = os.environ['CUDA_VISIBLE_DEVICES'] |
|
print(f"π― Using existing CUDA_VISIBLE_DEVICES: {gpu_id}") |
|
|
|
|
|
model = LLM( |
|
model=model_name, |
|
dtype=str(config.torch_dtype).split('.')[-1], |
|
trust_remote_code=True, |
|
gpu_memory_utilization=config.gpu_memory_utilization, |
|
max_model_len=getattr(config, 'max_model_len', 2048), |
|
tensor_parallel_size=tensor_parallel_size, |
|
) |
|
print(f"β
VLLM model loaded successfully on GPU {gpu_id} (tensor_parallel_size={tensor_parallel_size})") |
|
return model, tokenizer |
|
except Exception as e: |
|
import traceback |
|
print(f"β οΈ VLLM loading failed: {e}") |
|
print(f"π Full traceback: {traceback.format_exc()}") |
|
print(f"π Falling back to HuggingFace") |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_name, |
|
torch_dtype=config.torch_dtype, |
|
device_map=device if device.startswith('cuda') else None, |
|
trust_remote_code=True, |
|
attn_implementation="flash_attention_2" if config.use_flash_attention and device.startswith('cuda') else None, |
|
use_cache=False, |
|
) |
|
|
|
|
|
|
|
|
|
if hasattr(model, 'gradient_checkpointing_disable'): |
|
model.gradient_checkpointing_disable() |
|
|
|
print(f"β
HuggingFace model loaded successfully") |
|
return model, tokenizer |