|
import warnings |
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
|
|
from transformers.utils import ModelOutput |
|
|
|
@dataclass
class BaseModelOutputWithPastAndSeer(ModelOutput):
    """Base model output with additional Seer mask-gate fields.

    Every field defaults to ``None``, so any subset may be populated by the
    producer. The declaration order is part of the interface: ``ModelOutput``
    is documented as supporting tuple-style (positional) access, so fields
    must not be reordered.
    """

    # Presumably the final-layer hidden states -- confirm against the model
    # that builds this output.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Presumably the cached key/value tensors for incremental decoding
    # (nested tuple, per the annotation) -- TODO confirm layout.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional tuples of tensors; presumably one entry per layer -- confirm.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Seer-specific extras. NOTE(review): semantics are not visible in this
    # file; by name these look like gate predictions, their targets, and the
    # loss between them -- verify against the producing model.
    mask_gate_predictions: Optional[Tuple[torch.FloatTensor, ...]] = None
    mask_ground_truths: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Annotated Optional because the default is None (was a bare FloatTensor).
    mask_loss: Optional[torch.FloatTensor] = None
|
|
|
|
|
@dataclass
class CausalLMOutputWithPastAndSeer(ModelOutput):
    """Causal language-model output with additional Seer mask-gate fields.

    Every field defaults to ``None``, so any subset may be populated by the
    producer. The declaration order is part of the interface: ``ModelOutput``
    is documented as supporting tuple-style (positional) access, so fields
    must not be reordered.
    """

    # Presumably the language-modeling loss, set only when it is computed --
    # confirm against the model that builds this output.
    loss: Optional[torch.FloatTensor] = None
    # Presumably the LM head scores -- TODO confirm shape with the producer.
    # Annotated Optional because the default is None (was a bare FloatTensor).
    logits: Optional[torch.FloatTensor] = None
    # Presumably the cached key/value tensors for incremental decoding
    # (nested tuple, per the annotation) -- TODO confirm layout.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional tuples of tensors; presumably one entry per layer -- confirm.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Seer-specific extras. NOTE(review): semantics are not visible in this
    # file; by name these look like gate predictions, their targets, and the
    # loss between them -- verify against the producing model.
    mask_gate_predictions: Optional[Tuple[torch.FloatTensor, ...]] = None
    mask_ground_truths: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Annotated Optional because the default is None (was a bare FloatTensor).
    mask_loss: Optional[torch.FloatTensor] = None
|
|
|
|