# -*- coding: utf-8 -*-
# @Time    : 2023/5/6 4:12 p.m.
# @Author  : JianingWang
# @File    : critic.py
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Mean over `dim`, counting only the positions where `mask` is 1."""
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)  # epsilon guards against an all-zero mask
    return mean


"""
Critic model.
"""
class Critic(nn.Module):
    """
    Critic model base class.

    Args:
        model (nn.Module): Backbone transformer of the critic.
        value_head (nn.Module): Value head that maps hidden states to scalar values.
        use_action_mask (bool): If True, average values over the prompt tokens only.
    """

    def __init__(
        self,
        model: nn.Module,
        value_head: nn.Module,
        use_action_mask: bool = False,
    ) -> None:
        super().__init__()  # required before assigning submodules to an nn.Module
        self.model = model
        self.value_head = value_head  # critic layer for predicting the value function
        self.use_action_mask = use_action_mask

    def forward(self,
                sequences: torch.LongTensor,
                action_mask: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
        last_hidden_states = outputs['last_hidden_state']
        # one scalar value per token: [batch_size, seq_len]
        values = self.value_head(last_hidden_states).squeeze(-1)

        if action_mask is not None and self.use_action_mask:
            # average values over the prompt (non-action) tokens only
            num_actions = action_mask.size(1)
            prompt_mask = attention_mask[:, :-num_actions]
            values = values[:, :-num_actions]
            value = masked_mean(values, prompt_mask, dim=1)
            return value

        # otherwise average over all positions except the final token
        values = values[:, :-1]
        value = values.mean(dim=1)
        return value


"""
Auto Model for Critic
"""
class AutoModelCritic(Critic):
    """
    AutoModel Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (AutoConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[AutoConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,  # accepted for interface compatibility; LoRA is not applied here
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = AutoModel.from_pretrained(pretrained)
        elif config is not None:
            # AutoModel cannot be instantiated directly; build from the config instead
            model = AutoModel.from_config(config)
        else:
            raise ValueError("Either `pretrained` or `config` must be provided.")
        if checkpoint:
            model.gradient_checkpointing_enable()
        # `word_embed_proj_dim` only exists on OPT-style configs; fall back to `hidden_size`
        value_dim = getattr(model.config, 'word_embed_proj_dim', model.config.hidden_size)
        value_head = nn.Linear(value_dim, 1)
        super().__init__(model, value_head, **kwargs)
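

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal smoke test of the critic, assuming a small HuggingFace checkpoint
# such as "gpt2" is available. The batch size, sequence length, and random
# token ids below are arbitrary placeholder values, not part of the original
# code; this is a sketch of the forward pass, not a training setup.
if __name__ == "__main__":
    critic = AutoModelCritic(pretrained="gpt2")

    batch_size, seq_len = 2, 16
    sequences = torch.randint(0, critic.model.config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        value = critic(sequences, attention_mask=attention_mask)
    # one scalar value estimate per sequence in the batch
    print(value.shape)  # torch.Size([2])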