DanielHesslow
commited on
Commit
•
bccb539
1
Parent(s):
8ca1e4f
add model
Browse files- config.json +1 -1
- rita_modeling.py +42 -9
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
3 |
"architectures": [
|
4 |
"RITAModel"
|
5 |
],
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "Seledorn/RITA_s",
|
3 |
"architectures": [
|
4 |
"RITAModel"
|
5 |
],
|
rita_modeling.py
CHANGED
@@ -13,6 +13,7 @@ from transformers.modeling_outputs import (
|
|
13 |
BaseModelOutputWithPastAndCrossAttentions,
|
14 |
CausalLMOutputWithCrossAttentions,
|
15 |
CausalLMOutputWithPast,
|
|
|
16 |
)
|
17 |
|
18 |
from transformers.modeling_utils import PreTrainedModel
|
@@ -222,18 +223,50 @@ class RITAModel(PreTrainedModel):
|
|
222 |
self.final_norm = nn.LayerNorm(config.d_model)
|
223 |
self.projector = nn.Linear(config.d_model, config.vocab_size, bias = False)
|
224 |
|
225 |
-
def forward(
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
for layer in self.layers:
|
230 |
-
x = layer(x, attn_mask=
|
231 |
x = self.final_norm(x) # N x L x D
|
232 |
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
#Some common HF functions.
|
239 |
def get_input_embeddings(self):
|
|
|
13 |
BaseModelOutputWithPastAndCrossAttentions,
|
14 |
CausalLMOutputWithCrossAttentions,
|
15 |
CausalLMOutputWithPast,
|
16 |
+
CausalLMOutput,
|
17 |
)
|
18 |
|
19 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
223 |
self.final_norm = nn.LayerNorm(config.d_model)
|
224 |
self.projector = nn.Linear(config.d_model, config.vocab_size, bias = False)
|
225 |
|
226 |
+
def forward(
|
227 |
+
self,
|
228 |
+
input_ids=None,
|
229 |
+
past_key_values=None,
|
230 |
+
attention_mask=None,
|
231 |
+
token_type_ids=None,
|
232 |
+
position_ids=None,
|
233 |
+
head_mask=None,
|
234 |
+
inputs_embeds=None,
|
235 |
+
encoder_hidden_states=None,
|
236 |
+
encoder_attention_mask=None,
|
237 |
+
labels=None,
|
238 |
+
use_cache=None,
|
239 |
+
output_attentions=None,
|
240 |
+
output_hidden_states=None,
|
241 |
+
return_dict=None) -> torch.FloatTensor:
|
242 |
+
|
243 |
+
if inputs_embeds == None:
|
244 |
+
x = self.embedding(input_ids) # N x L x D
|
245 |
+
else:
|
246 |
+
x = inputs_embeds
|
247 |
+
|
248 |
+
if attention_mask == None:
|
249 |
+
attention_mask = (torch.triu(torch.ones(input_ids.size(1), input_ids.size(1))) == 0).transpose(0, 1).contiguous().to(input_ids.device)
|
250 |
for layer in self.layers:
|
251 |
+
x = layer(x, attn_mask=attention_mask)
|
252 |
x = self.final_norm(x) # N x L x D
|
253 |
|
254 |
+
logits = self.projector(x)
|
255 |
+
loss = None
|
256 |
+
if labels is not None:
|
257 |
+
# Shift so that tokens < n predict n
|
258 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
259 |
+
shift_labels = labels[..., 1:].contiguous()
|
260 |
+
# Flatten the tokens
|
261 |
+
loss_fct = CrossEntropyLoss()
|
262 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
263 |
+
|
264 |
+
return CausalLMOutput(
|
265 |
+
loss=loss,
|
266 |
+
logits=logits,
|
267 |
+
hidden_states=x,
|
268 |
+
)
|
269 |
+
|
270 |
|
271 |
#Some common HF functions.
|
272 |
def get_input_embeddings(self):
|