Add PEFT compatibility code
Browse files: mosaic_gpt.py (+5 -0)
mosaic_gpt.py
CHANGED
@@ -236,6 +236,7 @@ class MosaicGPT(PreTrainedModel):
|
|
236 |
def forward(
|
237 |
self,
|
238 |
input_ids: torch.LongTensor,
|
|
|
239 |
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
240 |
attention_mask: Optional[torch.ByteTensor] = None,
|
241 |
prefix_mask: Optional[torch.ByteTensor] = None,
|
@@ -243,7 +244,11 @@ class MosaicGPT(PreTrainedModel):
|
|
243 |
return_dict: Optional[bool] = None,
|
244 |
output_attentions: Optional[bool] = None,
|
245 |
output_hidden_states: Optional[bool] = None,
|
|
|
246 |
use_cache: Optional[bool] = None):
|
|
|
|
|
|
|
247 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
248 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
249 |
|
|
|
236 |
def forward(
|
237 |
self,
|
238 |
input_ids: torch.LongTensor,
|
239 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
240 |
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
241 |
attention_mask: Optional[torch.ByteTensor] = None,
|
242 |
prefix_mask: Optional[torch.ByteTensor] = None,
|
|
|
244 |
return_dict: Optional[bool] = None,
|
245 |
output_attentions: Optional[bool] = None,
|
246 |
output_hidden_states: Optional[bool] = None,
|
247 |
+
labels: Optional[torch.LongTensor] = None,
|
248 |
use_cache: Optional[bool] = None):
|
249 |
+
|
250 |
+
assert inputs_embeds is None # for compatibility for PEFT LoRA
|
251 |
+
assert labels is None # for compatibility for PEFT LoRA
|
252 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
253 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
254 |
|