# Author: LongMountain
# Commit: 3ef8780 ("first model commit")
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers.utils import ModelOutput
@dataclass
class BaseModelOutputWithPastAndSeer(ModelOutput):
    """Base-model output: the usual past/hidden-state fields plus mask-gate fields.

    The first four fields follow the naming of the transformers
    ``BaseModelOutputWithPast`` output (``last_hidden_state``,
    ``past_key_values``, ``hidden_states``, ``attentions``). Three ``mask_*``
    fields are appended after them. Every field defaults to ``None``. The model
    code that builds this object (not visible here) decides which fields are
    actually populated.

    NOTE(review): the comments on the ``mask_*`` fields are inferred from their
    names only. Confirm them against the model that produces this output.
    """

    # Final hidden states of the model; presumably (batch, seq_len, hidden) -- TODO confirm.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Key/value cache for reuse across forward calls; exact nesting is set by the producer.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Intermediate hidden states; presumably one tensor per layer -- confirm.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights; presumably one tensor per layer -- confirm.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Presumably the predictions of a learned attention-mask gate -- confirm with producer.
    mask_gate_predictions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Presumably the target masks the gate predictions are compared to -- confirm.
    mask_ground_truths: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Presumably a loss between gate predictions and ground truths -- confirm.
    # Annotated Optional because it defaults to None (implicit Optional is disallowed by PEP 484).
    mask_loss: Optional[torch.FloatTensor] = None
@dataclass
class CausalLMOutputWithPastAndSeer(ModelOutput):
    """Causal-LM output: loss/logits, past/hidden-state fields, plus mask-gate fields.

    The first five fields follow the naming of the transformers
    ``CausalLMOutputWithPast`` output (``loss``, ``logits``,
    ``past_key_values``, ``hidden_states``, ``attentions``). Three ``mask_*``
    fields are appended after them. Every field defaults to ``None``. The model
    code that builds this object (not visible here) decides which fields are
    actually populated.

    NOTE(review): the comments on the ``mask_*`` fields are inferred from their
    names only. Confirm them against the model that produces this output.
    """

    # Training loss, when one is computed by the producer.
    loss: Optional[torch.FloatTensor] = None
    # Prediction scores; presumably (batch, seq_len, vocab_size) -- TODO confirm.
    # Annotated Optional because it defaults to None (implicit Optional is disallowed by PEP 484).
    logits: Optional[torch.FloatTensor] = None
    # Key/value cache for reuse across forward calls; exact nesting is set by the producer.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Intermediate hidden states; presumably one tensor per layer -- confirm.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights; presumably one tensor per layer -- confirm.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Presumably the predictions of a learned attention-mask gate -- confirm with producer.
    mask_gate_predictions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Presumably the target masks the gate predictions are compared to -- confirm.
    mask_ground_truths: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Presumably a loss between gate predictions and ground truths -- confirm.
    mask_loss: Optional[torch.FloatTensor] = None