Upload CodeGenMeasurementPredictor
Browse files- README.md +199 -0
- config.json +53 -0
- configuration_code_gen_measuremet_pred.py +11 -0
- configuration_measurement_pred.py +27 -0
- model.safetensors +3 -0
- modeling_code_gen_measurement_pred.py +13 -0
- modeling_measurement_pred.py +102 -0
README.md
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
tags: []
|
4 |
+
---
|
5 |
+
|
6 |
+
# Model Card for Model ID
|
7 |
+
|
8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
## Model Details
|
13 |
+
|
14 |
+
### Model Description
|
15 |
+
|
16 |
+
<!-- Provide a longer summary of what this model is. -->
|
17 |
+
|
18 |
+
This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
|
19 |
+
|
20 |
+
- **Developed by:** [More Information Needed]
|
21 |
+
- **Funded by [optional]:** [More Information Needed]
|
22 |
+
- **Shared by [optional]:** [More Information Needed]
|
23 |
+
- **Model type:** [More Information Needed]
|
24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
25 |
+
- **License:** [More Information Needed]
|
26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
27 |
+
|
28 |
+
### Model Sources [optional]
|
29 |
+
|
30 |
+
<!-- Provide the basic links for the model. -->
|
31 |
+
|
32 |
+
- **Repository:** [More Information Needed]
|
33 |
+
- **Paper [optional]:** [More Information Needed]
|
34 |
+
- **Demo [optional]:** [More Information Needed]
|
35 |
+
|
36 |
+
## Uses
|
37 |
+
|
38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
39 |
+
|
40 |
+
### Direct Use
|
41 |
+
|
42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
43 |
+
|
44 |
+
[More Information Needed]
|
45 |
+
|
46 |
+
### Downstream Use [optional]
|
47 |
+
|
48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
49 |
+
|
50 |
+
[More Information Needed]
|
51 |
+
|
52 |
+
### Out-of-Scope Use
|
53 |
+
|
54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
55 |
+
|
56 |
+
[More Information Needed]
|
57 |
+
|
58 |
+
## Bias, Risks, and Limitations
|
59 |
+
|
60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
61 |
+
|
62 |
+
[More Information Needed]
|
63 |
+
|
64 |
+
### Recommendations
|
65 |
+
|
66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
67 |
+
|
68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
69 |
+
|
70 |
+
## How to Get Started with the Model
|
71 |
+
|
72 |
+
Use the code below to get started with the model.
|
73 |
+
|
74 |
+
[More Information Needed]
|
75 |
+
|
76 |
+
## Training Details
|
77 |
+
|
78 |
+
### Training Data
|
79 |
+
|
80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
81 |
+
|
82 |
+
[More Information Needed]
|
83 |
+
|
84 |
+
### Training Procedure
|
85 |
+
|
86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
87 |
+
|
88 |
+
#### Preprocessing [optional]
|
89 |
+
|
90 |
+
[More Information Needed]
|
91 |
+
|
92 |
+
|
93 |
+
#### Training Hyperparameters
|
94 |
+
|
95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
96 |
+
|
97 |
+
#### Speeds, Sizes, Times [optional]
|
98 |
+
|
99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
100 |
+
|
101 |
+
[More Information Needed]
|
102 |
+
|
103 |
+
## Evaluation
|
104 |
+
|
105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
106 |
+
|
107 |
+
### Testing Data, Factors & Metrics
|
108 |
+
|
109 |
+
#### Testing Data
|
110 |
+
|
111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
112 |
+
|
113 |
+
[More Information Needed]
|
114 |
+
|
115 |
+
#### Factors
|
116 |
+
|
117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
118 |
+
|
119 |
+
[More Information Needed]
|
120 |
+
|
121 |
+
#### Metrics
|
122 |
+
|
123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
124 |
+
|
125 |
+
[More Information Needed]
|
126 |
+
|
127 |
+
### Results
|
128 |
+
|
129 |
+
[More Information Needed]
|
130 |
+
|
131 |
+
#### Summary
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
## Model Examination [optional]
|
136 |
+
|
137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
138 |
+
|
139 |
+
[More Information Needed]
|
140 |
+
|
141 |
+
## Environmental Impact
|
142 |
+
|
143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
144 |
+
|
145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
146 |
+
|
147 |
+
- **Hardware Type:** [More Information Needed]
|
148 |
+
- **Hours used:** [More Information Needed]
|
149 |
+
- **Cloud Provider:** [More Information Needed]
|
150 |
+
- **Compute Region:** [More Information Needed]
|
151 |
+
- **Carbon Emitted:** [More Information Needed]
|
152 |
+
|
153 |
+
## Technical Specifications [optional]
|
154 |
+
|
155 |
+
### Model Architecture and Objective
|
156 |
+
|
157 |
+
[More Information Needed]
|
158 |
+
|
159 |
+
### Compute Infrastructure
|
160 |
+
|
161 |
+
[More Information Needed]
|
162 |
+
|
163 |
+
#### Hardware
|
164 |
+
|
165 |
+
[More Information Needed]
|
166 |
+
|
167 |
+
#### Software
|
168 |
+
|
169 |
+
[More Information Needed]
|
170 |
+
|
171 |
+
## Citation [optional]
|
172 |
+
|
173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
174 |
+
|
175 |
+
**BibTeX:**
|
176 |
+
|
177 |
+
[More Information Needed]
|
178 |
+
|
179 |
+
**APA:**
|
180 |
+
|
181 |
+
[More Information Needed]
|
182 |
+
|
183 |
+
## Glossary [optional]
|
184 |
+
|
185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
186 |
+
|
187 |
+
[More Information Needed]
|
188 |
+
|
189 |
+
## More Information [optional]
|
190 |
+
|
191 |
+
[More Information Needed]
|
192 |
+
|
193 |
+
## Model Card Authors [optional]
|
194 |
+
|
195 |
+
[More Information Needed]
|
196 |
+
|
197 |
+
## Model Card Contact
|
198 |
+
|
199 |
+
[More Information Needed]
|
config.json
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "multirun/2024-12-16/09-44-11/0/checkpoint-3905",
|
3 |
+
"activation_function": "gelu_new",
|
4 |
+
"aggregate_weight": 0.3,
|
5 |
+
"architectures": [
|
6 |
+
"CodeGenMeasurementPredictor"
|
7 |
+
],
|
8 |
+
"attn_pdrop": 0.0,
|
9 |
+
"auto_map": {
|
10 |
+
"AutoConfig": "configuration_code_gen_measuremet_pred.CodeGenMeasurementPredictorConfig",
|
11 |
+
"AutoModelForSequenceClassification": "modeling_code_gen_measurement_pred.CodeGenMeasurementPredictor"
|
12 |
+
},
|
13 |
+
"bos_token_id": 1,
|
14 |
+
"emb_dim": 1024,
|
15 |
+
"embd_pdrop": 0.0,
|
16 |
+
"eos_token_id": 50256,
|
17 |
+
"gradient_checkpointing": false,
|
18 |
+
"initializer_range": 0.02,
|
19 |
+
"layer_norm_epsilon": 1e-05,
|
20 |
+
"model_type": "codegen_mp",
|
21 |
+
"n_ctx": 2048,
|
22 |
+
"n_embd": 1024,
|
23 |
+
"n_head": 16,
|
24 |
+
"n_inner": null,
|
25 |
+
"n_layer": 20,
|
26 |
+
"n_positions": 2048,
|
27 |
+
"n_sensors": 3,
|
28 |
+
"resid_pdrop": 0.0,
|
29 |
+
"rotary_dim": 32,
|
30 |
+
"scale_attn_weights": true,
|
31 |
+
"sensor_token": " omit",
|
32 |
+
"sensor_token_id": 42848,
|
33 |
+
"sensors_weight": 0.7,
|
34 |
+
"summary_activation": null,
|
35 |
+
"summary_first_dropout": 0.1,
|
36 |
+
"summary_proj_to_labels": true,
|
37 |
+
"summary_type": "cls_index",
|
38 |
+
"summary_use_proj": true,
|
39 |
+
"task_specific_params": {
|
40 |
+
"text-generation": {
|
41 |
+
"do_sample": true,
|
42 |
+
"max_length": 50,
|
43 |
+
"temperature": 1.0
|
44 |
+
}
|
45 |
+
},
|
46 |
+
"tie_word_embeddings": false,
|
47 |
+
"tokenizer_class": "GPT2Tokenizer",
|
48 |
+
"torch_dtype": "float32",
|
49 |
+
"transformers_version": "4.41.0",
|
50 |
+
"use_aggregated": true,
|
51 |
+
"use_cache": false,
|
52 |
+
"vocab_size": 51200
|
53 |
+
}
|
configuration_code_gen_measuremet_pred.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.models.codegen import CodeGenConfig
|
2 |
+
from .configuration_measurement_pred import MeasurementPredictorConfig
|
3 |
+
|
4 |
+
class CodeGenMeasurementPredictorConfig(MeasurementPredictorConfig, CodeGenConfig):
|
5 |
+
model_type = "codegen_mp"
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
kwargs["sensor_token_id"] = 42848
|
8 |
+
super().__init__(**kwargs)
|
9 |
+
|
10 |
+
def get_emb_dim(self):
|
11 |
+
return self.n_embd
|
configuration_measurement_pred.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from transformers import PretrainedConfig
|
3 |
+
|
4 |
+
class MeasurementPredictorConfig(PretrainedConfig):
|
5 |
+
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
sensor_token=" omit",
|
9 |
+
sensor_token_id=None, # 35991
|
10 |
+
n_sensors=3,
|
11 |
+
use_aggregated=True,
|
12 |
+
sensors_weight = 0.7,
|
13 |
+
aggregate_weight=0.3,
|
14 |
+
**kwargs
|
15 |
+
):
|
16 |
+
self.sensor_token = sensor_token
|
17 |
+
self.sensor_token_id = sensor_token_id
|
18 |
+
self.n_sensors = n_sensors
|
19 |
+
self.use_aggregated = use_aggregated
|
20 |
+
self.sensors_weight = sensors_weight
|
21 |
+
self.aggregate_weight = aggregate_weight
|
22 |
+
super().__init__(**kwargs)
|
23 |
+
self.emb_dim = self.get_emb_dim()
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def get_emb_dim(self):
|
27 |
+
raise NotImplementedError
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ab5328afaddfbf32103575e7db654ebfc0661323ecb891dee5caa9a8d1da284
|
3 |
+
size 1216963976
|
modeling_code_gen_measurement_pred.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.models.codegen import CodeGenPreTrainedModel, CodeGenModel
|
2 |
+
|
3 |
+
from .modeling_measurement_pred import MeasurementPredictorMixin
|
4 |
+
from .configuration_code_gen_measuremet_pred import CodeGenMeasurementPredictorConfig
|
5 |
+
|
6 |
+
|
7 |
+
class CodeGenMeasurementPredictor(CodeGenPreTrainedModel, MeasurementPredictorMixin):
|
8 |
+
config_class = CodeGenMeasurementPredictorConfig
|
9 |
+
|
10 |
+
def __init__(self, config):
|
11 |
+
super().__init__(config)
|
12 |
+
self.transformer = CodeGenModel(config)
|
13 |
+
self.post_init()
|
modeling_measurement_pred.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import BCEWithLogitsLoss
|
5 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
6 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
|
7 |
+
|
8 |
+
class MeasurementPredictorMixin(PreTrainedModel):
|
9 |
+
|
10 |
+
def __init__(self, config):
|
11 |
+
super().__init__(config)
|
12 |
+
self.sensor_token = config.sensor_token
|
13 |
+
self.sensor_token_id = config.sensor_token_id
|
14 |
+
self.n_sensors = config.n_sensors
|
15 |
+
self.sensor_probes = torch.nn.ModuleList([
|
16 |
+
torch.nn.Linear(config.emb_dim, 1) for _ in range(config.n_sensors)
|
17 |
+
])
|
18 |
+
self.use_aggregated = config.use_aggregated
|
19 |
+
if config.use_aggregated:
|
20 |
+
self.aggregate_probe = torch.nn.Linear(config.emb_dim, 1)
|
21 |
+
self.sensors_weight = config.sensors_weight
|
22 |
+
self.aggregate_weight = config.aggregate_weight
|
23 |
+
|
24 |
+
def check_tokenizer(self, tokenizer: PreTrainedTokenizer):
|
25 |
+
sensor_token_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(self.sensor_token))[0]
|
26 |
+
assert sensor_token_id == self.sensor_token_id
|
27 |
+
|
28 |
+
def set_sensor_token(self, sensor_token: str, tokenizer: PreTrainedTokenizer):
|
29 |
+
sensor_token_id = tokenizer.tokenize(sensor_token)[0]
|
30 |
+
self.sensor_token = sensor_token
|
31 |
+
self.sensor_token_id = sensor_token_id
|
32 |
+
|
33 |
+
def forward(
|
34 |
+
self,
|
35 |
+
input_ids: Optional[torch.LongTensor] = None,
|
36 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
37 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
38 |
+
position_ids: Optional[torch.LongTensor] = None,
|
39 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
40 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
41 |
+
labels: Optional[torch.LongTensor] = None,
|
42 |
+
use_cache: Optional[bool] = None,
|
43 |
+
output_attentions: Optional[bool] = None,
|
44 |
+
output_hidden_states: Optional[bool] = None,
|
45 |
+
return_dict: Optional[bool] = None,
|
46 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
47 |
+
r"""
|
48 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
49 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
50 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
51 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
52 |
+
"""
|
53 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
54 |
+
|
55 |
+
base_model_output: BaseModelOutputWithPast = self.base_model(
|
56 |
+
input_ids,
|
57 |
+
past_key_values=past_key_values,
|
58 |
+
attention_mask=attention_mask,
|
59 |
+
position_ids=position_ids,
|
60 |
+
head_mask=head_mask,
|
61 |
+
inputs_embeds=inputs_embeds,
|
62 |
+
use_cache=use_cache,
|
63 |
+
output_attentions=output_attentions,
|
64 |
+
output_hidden_states=output_hidden_states,
|
65 |
+
return_dict=return_dict,
|
66 |
+
)
|
67 |
+
flat_tensor_token_idxs = (input_ids == self.sensor_token_id).nonzero(as_tuple=True)[1]
|
68 |
+
tensor_token_idxs = flat_tensor_token_idxs.view(-1, self.n_sensors)
|
69 |
+
sensor_embs = base_model_output.last_hidden_state.gather(
|
70 |
+
1, tensor_token_idxs.unsqueeze(-1).expand(-1, -1, self.config.emb_dim)
|
71 |
+
)
|
72 |
+
assert sensor_embs.shape == (input_ids.shape[0], self.n_sensors, self.config.emb_dim), f"{sensor_embs.shape} != {(input_ids.shape[0], self.n_sensors, self.config.emb_dim)}"
|
73 |
+
sensor_logits = torch.concat([self.sensor_probes[i](sensor_embs[:, i, :])
|
74 |
+
for i in range(self.n_sensors)], dim=-1)
|
75 |
+
logits = sensor_logits
|
76 |
+
|
77 |
+
if self.use_aggregated:
|
78 |
+
last_emb = base_model_output.last_hidden_state[:, -1, :]
|
79 |
+
aggregate_logits = self.aggregate_probe(last_emb)
|
80 |
+
logits = torch.concat([logits, aggregate_logits], dim=-1)
|
81 |
+
|
82 |
+
loss = None
|
83 |
+
if labels is not None:
|
84 |
+
loss_fct = BCEWithLogitsLoss()
|
85 |
+
sensor_loss = loss_fct(sensor_logits, labels[:, :self.n_sensors]) * self.sensors_weight
|
86 |
+
loss = sensor_loss
|
87 |
+
if self.use_aggregated: #TOOD: should be use aggregate
|
88 |
+
aggregate_loss = loss_fct(aggregate_logits, labels[:, -1:]) * self.aggregate_weight
|
89 |
+
loss += aggregate_loss
|
90 |
+
|
91 |
+
if not return_dict:
|
92 |
+
output = (logits, ) + base_model_output[1:]
|
93 |
+
return ((loss,) + output) if loss is not None else output
|
94 |
+
|
95 |
+
return SequenceClassifierOutputWithPast(
|
96 |
+
loss=loss,
|
97 |
+
logits=logits,
|
98 |
+
past_key_values=base_model_output.past_key_values,
|
99 |
+
hidden_states=base_model_output.hidden_states,
|
100 |
+
attentions=base_model_output.attentions,
|
101 |
+
)
|
102 |
+
|