| | import torch.nn as nn |
| | from safetensors import safe_open |
| | from transformers import GenerationConfig |
| |
|
| | from dataclasses import dataclass, field |
| | from typing import Optional, Callable, Dict, List |
| | import os |
| | import logging |
| | import json |
| |
|
| | |
| |
|
| | |
# Jinja2 chat template that linearizes (possibly multimodal) conversations into
# model input text. Behavior visible in the template itself:
#   * injects a default "You are a helpful assistant." system turn when the
#     first message is not a system message;
#   * string contents are emitted verbatim between <|im_start|>/<|im_end|>;
#   * list contents replace image/video entries with
#     <|vision_start|><|image_pad|>/<|video_pad|><|vision_end|> placeholders,
#     numbered ("Picture N:"/"Video N:") when `add_vision_id` is set;
#   * appends an "<|im_start|>assistant" header when `add_generation_prompt`.
# NOTE(review): the <|im_start|>/<|vision_*|> special tokens look like the
# Qwen-VL convention — confirm against the target tokenizer's vocabulary.
CONVERSATION_TEMPLATE = r"""{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endif %}<|im_start|>{{ message['role'] }}
{% if message['content'] is string %}{{ message['content'] }}<|im_end|>
{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}"""
| |
|
| | |
def load_state_dict_from_safetensor(model_path) -> Dict:
    """Read every tensor stored in a safetensor file into a state dict.

    Args:
        model_path (str): Path to the safetensor file.

    Returns:
        Dict[str, torch.Tensor]: A dictionary of model parameters,
        where keys are parameter names and values are corresponding tensors.
    """
    with safe_open(model_path, framework="pt") as handle:
        return {name: handle.get_tensor(name) for name in handle.keys()}
| |
|
def fix_model_parameters(model: nn.Module):
    """Freeze the model: turn off gradient tracking for every parameter.

    Args:
        model (nn.Module): The PyTorch model whose parameters will be frozen.
    """
    for p in model.parameters():
        p.requires_grad = False
| |
|
def open_model_parameters(model: nn.Module):
    """Unfreeze the model: turn gradient tracking back on for every parameter.

    Args:
        model (nn.Module): The PyTorch model whose parameters will be unfrozen.
    """
    for p in model.parameters():
        p.requires_grad = True
| |
|
def log_trainable_params(model: nn.Module):
    """Log every trainable (requires_grad) parameter of the model at INFO level.

    Args:
        model (nn.Module): The PyTorch model to inspect.
    """
    logging.info("Trainable parameters in the model:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Lazy %-style args instead of an eager f-string: the message is
            # only formatted when INFO logging is actually enabled.
            logging.info(" %s: %d params, shape=%s", name, param.numel(), param.shape)
| |
|
| |
|
| |
|
| | |
@dataclass
class EvalConfig:
    """Configuration for an evaluation run.

    Attributes:
        output_dir: Directory where evaluation artifacts are written, or None.
        batch_size: Number of samples evaluated per step.
        generation_config: Decoding settings (presumably forwarded to a
            ``generate`` call elsewhere — confirm against the caller); None
            leaves the model's defaults in effect.
    """

    # Optional[...] matches the None defaults; the bare `str` /
    # `GenerationConfig` annotations were inconsistent with them.
    output_dir: Optional[str] = None
    batch_size: int = 1
    generation_config: Optional[GenerationConfig] = None
| |
|
@dataclass
class StaticEvalRecorder:
    """Accumulate per-sample metric scores during a static (fixed-dataset) eval.

    Each callable in ``compute_metrics`` is invoked with keyword arguments
    built from the batch — one list per example key, plus ``completions`` —
    and must return one score per sample (indexable by sample position).

    Attributes:
        compute_metrics: Batched metric functions, keyed internally by their
            ``__name__``.
        log_file: Optional JSONL path; one record per sample plus a final
            summary line. Truncated on construction.
        writer: Optional TensorBoard-style object exposing ``add_scalar``.
    """

    compute_metrics: List[Callable[[str, str, str], float]] = field(default_factory=list)
    log_file: Optional[str] = None
    writer: Optional[object] = None

    # Running totals per metric name; populated in __post_init__.
    metric_sums: Dict[str, float] = field(init=False)
    metric_counts: Dict[str, int] = field(init=False)

    def __post_init__(self):
        """Initialize running totals and truncate the log file if one is set."""
        self.metric_sums = {metric.__name__: 0.0 for metric in self.compute_metrics}
        self.metric_counts = {metric.__name__: 0 for metric in self.compute_metrics}
        if self.log_file:
            # Guard: dirname is '' for a bare filename and makedirs('') raises
            # FileNotFoundError.
            log_dir = os.path.dirname(self.log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            # utf-8 explicitly, for consistency with the ensure_ascii=False
            # JSON appended later (default encoding is locale-dependent).
            with open(self.log_file, 'w', encoding='utf-8') as f:
                f.write('')

    def record_batch(self, completions: List[str], examples: List[Dict]):
        """Record results for a batch of model outputs.

        Args:
            completions (List[str]): The model's answers (outputs).
            examples (List[Dict]): Each completion's corresponding question and
                related attributes. Each example is expected to contain the
                keys: "prompt" and "solution".
        """
        # Empty batch: nothing to score, and examples[0] below would raise.
        if not completions or not examples:
            return

        keys = list(examples[0])

        # Transpose examples into one list per key; metrics see the whole
        # batch at once.
        reward_kwargs = {key: [example[key] for example in examples] for key in keys}
        reward_kwargs['completions'] = completions

        batched_results = {}
        for metric in self.compute_metrics:
            batched_results[metric.__name__] = metric(**reward_kwargs)

        for i, (completion, example) in enumerate(zip(completions, examples)):
            # Per-sample slice of every metric's batched output.
            metrics_result = {
                metric_name: batched_results[metric_name][i]
                for metric_name in batched_results
            }

            for metric_name, score in metrics_result.items():
                self.metric_sums[metric_name] += score
                self.metric_counts[metric_name] += 1

            record = {
                'prompt': example.get("prompt", ""),
                'solution': example.get("solution", ""),
                'completion': completion,
                'metrics': metrics_result
            }

            if self.log_file:
                with open(self.log_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(record, ensure_ascii=False) + '\n')

            if self.writer:
                # Running means, logged with the sample count as the step.
                mean_metrics = self.get_mean_metrics()
                for name, value in mean_metrics.items():
                    self.writer.add_scalar(name, value, global_step=self.metric_counts[name])

    def get_mean_metrics(self) -> Dict[str, float]:
        """Return the running mean of every metric (0.0 before any samples)."""
        return {
            name: (self.metric_sums[name] / self.metric_counts[name]) if self.metric_counts[name] > 0 else 0.0
            for name in self.metric_sums
        }

    def finalize(self):
        """Append a summary record to the log and emit final scalars."""
        mean_metrics = self.get_mean_metrics()
        final_record = {
            'summary_metrics': mean_metrics
        }

        if self.log_file:
            with open(self.log_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(final_record, ensure_ascii=False) + '\n')

        if self.writer:
            for name, value in mean_metrics.items():
                self.writer.add_scalar(name + "_final", value, global_step=self.metric_counts[name])
| |
|
| |
|
@dataclass
class DynamicEvalRecorder:
    """Log conversations and rewards produced during a dynamic (rollout) eval.

    Attributes:
        log_file: Required plain-text log path; truncated on construction.
        writer: Optional TensorBoard-style object exposing ``add_scalar``.
    """

    log_file: Optional[str] = None
    writer: object = field(default=None)

    def __post_init__(self):
        """Validate the log path, reset counters, and write the log header.

        Raises:
            ValueError: If ``log_file`` was not provided.
        """
        if self.log_file is None:
            raise ValueError("log_file path must be provided")

        # Guard: dirname is '' for a bare filename and makedirs('') raises
        # FileNotFoundError.
        log_dir = os.path.dirname(self.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self.logger = logging.getLogger("DynamicEvalRecorder")

        # Running totals for the average reward.
        self._total_reward = 0.0
        self._count = 0

        with open(self.log_file, "w", encoding="utf-8") as f:
            f.write("DynamicEvalRecorder Log\n\n")

    def record_batch(self, conversations: List[str], rewards: List[float]):
        """Record a batch of conversations and their associated rewards.

        Args:
            conversations (List[str]): List of conversation texts.
            rewards (List[float]): List of reward values corresponding to conversations.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(conversations) != len(rewards):
            raise ValueError("conversations and rewards must have the same length")

        with open(self.log_file, "a", encoding="utf-8") as f:
            for conv, rew in zip(conversations, rewards):
                f.write(f"Conversation:\n{conv}\n")
                f.write(f"Reward: {rew:.4f}\n")
                f.write("-" * 40 + "\n")

                # Accumulate inside the loop: the original updated the totals
                # with the leaked loop variable after the loop, so only the
                # last reward of each batch was counted — inconsistent with
                # the "Recorded {n} items" log line below.
                self._total_reward += rew
                self._count += 1

        avg_reward = self._total_reward / self._count if self._count > 0 else 0.0

        if self.writer is not None:
            self.writer.add_scalar("reward/avg", avg_reward, self._count)

        self.logger.info(f"Recorded {len(conversations)} items, avg_reward={avg_reward:.4f}")

    def finalize(self):
        """Finalize evaluation: write final average reward to both log file and TensorBoard."""
        avg_reward = self._total_reward / self._count if self._count > 0 else 0.0

        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write("\nFinal Results\n")
            f.write("=" * 40 + "\n")
            f.write(f"Average Reward: {avg_reward:.4f}\n")

        if self.writer:
            self.writer.add_scalar("ave_reward_final", avg_reward, global_step=self._count)
| |
|
| |
|