abhi-mosaic
commited on
Commit
•
8b737ec
1
Parent(s):
6ec8c48
Update modeling_mpt.py
Browse files- modeling_mpt.py +12 -3
modeling_mpt.py
CHANGED
@@ -56,7 +56,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
56 |
for module in self.modules():
|
57 |
if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
|
58 |
if config.verbose:
|
59 |
-
|
60 |
module.register_parameter('bias', None)
|
61 |
if config.verbose and config.verbose > 2:
|
62 |
print(self)
|
@@ -131,6 +131,10 @@ class MPTModel(MPTPreTrainedModel):
|
|
131 |
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
|
132 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
133 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
|
|
|
|
|
|
134 |
if not return_dict:
|
135 |
raise NotImplementedError('return_dict False is not implemented yet for MPT')
|
136 |
if output_attentions:
|
@@ -228,7 +232,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
228 |
def get_decoder(self):
|
229 |
return self.transformer
|
230 |
|
231 |
-
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
|
232 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
233 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
234 |
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
|
@@ -237,7 +241,12 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
237 |
if self.logit_scale == 0:
|
238 |
warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
|
239 |
logits *= self.logit_scale
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
def param_init_fn(self, module):
|
243 |
init_fn_name = self.config.init_config['name']
|
|
|
56 |
for module in self.modules():
|
57 |
if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
|
58 |
if config.verbose:
|
59 |
+
warnings.warn(f'Removing bias ({module.bias}) from {module}.')
|
60 |
module.register_parameter('bias', None)
|
61 |
if config.verbose and config.verbose > 2:
|
62 |
print(self)
|
|
|
131 |
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
|
132 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
133 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
134 |
+
if attention_mask is not None:
|
135 |
+
attention_mask = attention_mask.bool()
|
136 |
+
if prefix_mask is not None:
|
137 |
+
prefix_mask = prefix_mask.bool()
|
138 |
if not return_dict:
|
139 |
raise NotImplementedError('return_dict False is not implemented yet for MPT')
|
140 |
if output_attentions:
|
|
|
232 |
def get_decoder(self):
|
233 |
return self.transformer
|
234 |
|
235 |
+
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
|
236 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
237 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
238 |
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
|
|
|
241 |
if self.logit_scale == 0:
|
242 |
warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
|
243 |
logits *= self.logit_scale
|
244 |
+
loss = None
|
245 |
+
if labels is not None:
|
246 |
+
labels = torch.roll(labels, shifts=-1)
|
247 |
+
labels[:, -1] = -100
|
248 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
|
249 |
+
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
|
250 |
|
251 |
def param_init_fn(self, module):
|
252 |
init_fn_name = self.config.init_config['name']
|