Upload 2 files

- split-torch-model-v1.py  +59 -0
- split-torch-model-v2.py  +54 -0
split-torch-model-v1.py
ADDED
@@ -0,0 +1,59 @@
import torch
from transformers import AutoTokenizer
from torch import nn

device1 = torch.device("cuda:0")
device2 = torch.device("cuda:1")

class SplitModel(nn.Module):
    def __init__(self, embedding_layer, dropout_layer, gptj_blocks, layer_norm, lm_head):
        super(SplitModel, self).__init__()
        self.embedding_layer = embedding_layer
        self.dropout_layer = dropout_layer
        self.gptj_blocks = gptj_blocks
        self.layer_norm = layer_norm
        self.lm_head = lm_head

    def forward(self, input_ids, attention_mask):
        tensor_ids = self.dropout_layer(self.embedding_layer(input_ids))
        # GPTJBlock is missing the embedding positions that are necessary for self-attention.
        # To fix this issue, you need to ensure that the position_ids are passed to each GPTJBlock during the forward pass.
        position_ids = torch.arange(tensor_ids.shape[1], dtype=torch.long, device=tensor_ids.device)
        for block in self.gptj_blocks:
            tensor_ids = block(tensor_ids, attention_mask=attention_mask, position_ids=position_ids)[0]
        tensor_ids = tensor_ids.to(device2)
        tensor_ids = self.layer_norm(tensor_ids)
        logits = self.lm_head(tensor_ids)
        logits = logits.to(device1)
        return logits

model_dir = "pt_fp32"
model_path = f"{model_dir}/torch_model.pt"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
full_model = torch.load(model_path)

embedding_layer = full_model.transformer.wte.to(device1)
dropout_layer = full_model.transformer.drop.to(device1)
gptj_blocks = full_model.transformer.h.to(device1)
layer_norm = full_model.transformer.ln_f.to(device2)
lm_head = full_model.lm_head.to(device2)

split_model = SplitModel(embedding_layer, dropout_layer, gptj_blocks, layer_norm, lm_head)

input_text = "Hi I am Jade and I love"
input("Press enter please")
input_tokens = tokenizer.encode_plus(input_text, return_tensors="pt").to(device1)
input_ids = input_tokens["input_ids"]
temperature = 0.8
max_new_tokens = 50
with torch.no_grad():
    for _ in range(max_new_tokens):
        attention_mask = torch.ones_like(input_ids).to(device1)
        logits = split_model(input_ids, attention_mask)[:, -1] / temperature
        probabilities = torch.softmax(logits, dim=-1)
        sampled_token_ids = torch.multinomial(probabilities, num_samples=1)
        input_ids = torch.cat((input_ids, sampled_token_ids), dim=-1)
generated_ids = input_ids.squeeze().tolist()
output = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(output)
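Note on v1: this split is very uneven — the embeddings, dropout and every transformer block stay on cuda:0, with only the final layer norm and LM head on cuda:1, so nearly all of the weights land on the first GPU. A quick way to see where the parameters actually ended up is to total their sizes per device. The helper below is a minimal sketch that assumes split_model from split-torch-model-v1.py has already been built; params_per_device is an illustrative name, not part of the uploaded script.

from collections import defaultdict

import torch

def params_per_device(model: torch.nn.Module) -> dict:
    """Sum parameter sizes (in bytes) grouped by the device they live on."""
    totals = defaultdict(int)
    for param in model.parameters():
        totals[str(param.device)] += param.numel() * param.element_size()
    return dict(totals)

# Example usage after constructing split_model:
# for device, nbytes in params_per_device(split_model).items():
#     print(f"{device}: {nbytes / 1e9:.2f} GB of parameters")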
split-torch-model-v2.py
ADDED
@@ -0,0 +1,54 @@
from transformers import AutoTokenizer
import torch

device1 = torch.device("cuda:0")
device2 = torch.device("cuda:1")

class SplitModel(torch.nn.Module):
    def __init__(self, base_model):
        super(SplitModel, self).__init__()
        self.embedding_layer = base_model.transformer.wte.to(device1)
        # self.dropout_layer = base_model.transformer.drop.to(device1)
        self.gptj_blocks1 = torch.nn.ModuleList(base_model.transformer.h[:14]).to(device1)
        self.gptj_blocks2 = torch.nn.ModuleList(base_model.transformer.h[14:]).to(device2)
        self.layer_norm = base_model.transformer.ln_f.to(device2)
        self.lm_head = base_model.lm_head.to(device2)

    def forward(self, input_ids, attention_mask):
        # tensor_ids = self.dropout_layer(self.embedding_layer(input_ids))
        tensor_ids = self.embedding_layer(input_ids)
        position_ids = torch.arange(tensor_ids.shape[1], dtype=torch.long, device=tensor_ids.device)
        for block in self.gptj_blocks1:
            tensor_ids = block(tensor_ids, attention_mask=attention_mask, position_ids=position_ids)[0]
        tensor_ids = tensor_ids.to(device2)
        position_ids = position_ids.to(device2)
        attention_mask = attention_mask.to(device2)
        for block in self.gptj_blocks2:
            tensor_ids = block(tensor_ids, attention_mask=attention_mask, position_ids=position_ids)[0]
        tensor_ids = self.layer_norm(tensor_ids)
        logits = self.lm_head(tensor_ids)
        return logits.to(device1)

model_dir = "pt_fp32"
model_path = f"{model_dir}/torch_model.pt"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
split_model = SplitModel(torch.load(model_path))

input_text = "Hi I am Jade and I love"
input_tokens = tokenizer.encode_plus(input_text, return_tensors="pt").to(device1)
input_ids = input_tokens["input_ids"]
temperature = 0.5
max_new_tokens = 50
with torch.no_grad():
    # split_model.eval()
    for _ in range(max_new_tokens):
        attention_mask = torch.ones_like(input_ids).to(device1)
        logits = split_model(input_ids, attention_mask)[:, -1] / temperature
        probabilities = torch.softmax(logits, dim=-1)
        sampled_token_ids = torch.multinomial(probabilities, num_samples=1)
        input_ids = torch.cat((input_ids, sampled_token_ids), dim=-1)
        del logits, probabilities, sampled_token_ids
generated_ids = input_ids.squeeze().tolist()
output = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(output)
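Note on v2: the load is balanced by splitting the transformer blocks at index 14, so roughly half of them run on each GPU, and the activations, position_ids and attention_mask are all moved to cuda:1 at the boundary. If the two GPUs have unequal memory, the hard-coded 14 could be replaced by a computed split point. The snippet below is a hypothetical sketch of that idea; choose_split_index is an illustrative name and not part of either uploaded script.

import torch

def choose_split_index(blocks: torch.nn.ModuleList, fraction: float = 0.5) -> int:
    """Pick how many blocks the first GPU should hold (at least 1, at most len - 1)."""
    return max(1, min(len(blocks) - 1, round(len(blocks) * fraction)))

# Usage sketch, reusing the names from split-torch-model-v2.py:
# base_model = torch.load(model_path)
# idx = choose_split_index(base_model.transformer.h, fraction=0.5)
# gptj_blocks1 = torch.nn.ModuleList(base_model.transformer.h[:idx]).to(device1)
# gptj_blocks2 = torch.nn.ModuleList(base_model.transformer.h[idx:]).to(device2)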