BlueDice committed on
Commit 4798617
1 Parent(s): d5cf094

Upload 2 files

Files changed (2)
  1. split-torch-model-v1.py +59 -0
  2. split-torch-model-v2.py +54 -0
split-torch-model-v1.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ from transformers import AutoTokenizer
+ from torch import nn
+
+ device1 = torch.device("cuda:0")
+ device2 = torch.device("cuda:1")
+
+ class SplitModel(nn.Module):
+     def __init__(self, embedding_layer, dropout_layer, gptj_blocks, layer_norm, lm_head):
+         super(SplitModel, self).__init__()
+         self.embedding_layer = embedding_layer
+         self.dropout_layer = dropout_layer
+         self.gptj_blocks = gptj_blocks
+         self.layer_norm = layer_norm
+         self.lm_head = lm_head
+
+     def forward(self, input_ids, attention_mask):
+         tensor_ids = self.dropout_layer(self.embedding_layer(input_ids))
+         # GPTJBlock is missing the embedding positions that are necessary for self-attention.
+         # To fix this issue, you need to ensure that the position_ids are passed to each GPTJBlock during the forward pass.
+         position_ids = torch.arange(tensor_ids.shape[1], dtype=torch.long, device=tensor_ids.device)
+         for block in self.gptj_blocks:
+             tensor_ids = block(tensor_ids, attention_mask=attention_mask, position_ids=position_ids)[0]
+         tensor_ids = tensor_ids.to(device2)
+         tensor_ids = self.layer_norm(tensor_ids)
+         logits = self.lm_head(tensor_ids)
+         logits = logits.to(device1)
+         return logits
+
+ model_dir = "pt_fp32"
+ model_path = f"{model_dir}/torch_model.pt"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ full_model = torch.load(model_path)
+
+ embedding_layer = full_model.transformer.wte.to(device1)
+ dropout_layer = full_model.transformer.drop.to(device1)
+ gptj_blocks = full_model.transformer.h.to(device1)
+ layer_norm = full_model.transformer.ln_f.to(device2)
+ lm_head = full_model.lm_head.to(device2)
+
+ split_model = SplitModel(embedding_layer, dropout_layer, gptj_blocks, layer_norm, lm_head)
+
+ input_text = "Hi I am Jade and I love"
+ input("Press enter please")
+ input_tokens = tokenizer.encode_plus(input_text, return_tensors="pt").to(device1)
+ input_ids = input_tokens["input_ids"]
+ temperature = 0.8
+ max_new_tokens = 50
+ with torch.no_grad():
+     for _ in range(max_new_tokens):
+         attention_mask = torch.ones_like(input_ids).to(device1)
+         logits = split_model(input_ids, attention_mask)[:, -1] / temperature
+         probabilities = torch.softmax(logits, dim=-1)
+         sampled_token_ids = torch.multinomial(probabilities, num_samples=1)
+         input_ids = torch.cat((input_ids, sampled_token_ids), dim=-1)
+ generated_ids = input_ids.squeeze().tolist()
+ output = tokenizer.decode(generated_ids, skip_special_tokens=True)
+ print(output)
split-torch-model-v2.py ADDED
@@ -0,0 +1,54 @@
+ from transformers import AutoTokenizer
+ import torch
+
+ device1 = torch.device("cuda:0")
+ device2 = torch.device("cuda:1")
+
+ class SplitModel(torch.nn.Module):
+     def __init__(self, base_model):
+         super(SplitModel, self).__init__()
+         self.embedding_layer = base_model.transformer.wte.to(device1)
+         # self.dropout_layer = base_model.transformer.drop.to(device1)
+         self.gptj_blocks1 = torch.nn.ModuleList(base_model.transformer.h[:14]).to(device1)
+         self.gptj_blocks2 = torch.nn.ModuleList(base_model.transformer.h[14:]).to(device2)
+         self.layer_norm = base_model.transformer.ln_f.to(device2)
+         self.lm_head = base_model.lm_head.to(device2)
+
+     def forward(self, input_ids, attention_mask):
+         # tensor_ids = self.dropout_layer(self.embedding_layer(input_ids))
+         tensor_ids = self.embedding_layer(input_ids)
+         position_ids = torch.arange(tensor_ids.shape[1], dtype=torch.long, device=tensor_ids.device)
+         for block in self.gptj_blocks1:
+             tensor_ids = block(tensor_ids, attention_mask=attention_mask, position_ids=position_ids)[0]
+         tensor_ids = tensor_ids.to(device2)
+         position_ids = position_ids.to(device2)
+         attention_mask = attention_mask.to(device2)
+         for block in self.gptj_blocks2:
+             tensor_ids = block(tensor_ids, attention_mask=attention_mask, position_ids=position_ids)[0]
+         tensor_ids = self.layer_norm(tensor_ids)
+         logits = self.lm_head(tensor_ids)
+         return logits.to(device1)
+
+ model_dir = "pt_fp32"
+ model_path = f"{model_dir}/torch_model.pt"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ split_model = SplitModel(torch.load(model_path))
+
+ input_text = "Hi I am Jade and I love"
+ input_tokens = tokenizer.encode_plus(input_text, return_tensors="pt").to(device1)
+ input_ids = input_tokens["input_ids"]
+ temperature = 0.5
+ max_new_tokens = 50
+ with torch.no_grad():
+     # split_model.eval()
+     for _ in range(max_new_tokens):
+         attention_mask = torch.ones_like(input_ids).to(device1)
+         logits = split_model(input_ids, attention_mask)[:, -1] / temperature
+         probabilities = torch.softmax(logits, dim=-1)
+         sampled_token_ids = torch.multinomial(probabilities, num_samples=1)
+         input_ids = torch.cat((input_ids, sampled_token_ids), dim=-1)
+         del logits, probabilities, sampled_token_ids
+ generated_ids = input_ids.squeeze().tolist()
+ output = tokenizer.decode(generated_ids, skip_special_tokens=True)
+ print(output)
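
Both scripts rely on the same pattern: each GPT-J submodule is pinned to a GPU with a per-layer .to() call, and the hidden states (plus position_ids and attention_mask in v2) are moved between devices inside forward(). The sketch below is a minimal, illustrative version of that handoff using two Linear layers; the class name, sizes, and device indices are placeholders, and it assumes at least two CUDA devices are visible.

import torch
from torch import nn

dev0 = torch.device("cuda:0")  # placeholder devices; assumes two GPUs are available
dev1 = torch.device("cuda:1")

class TwoDeviceToy(nn.Module):
    """Toy split: the first stage lives on dev0, the second on dev1."""
    def __init__(self, hidden=64):
        super().__init__()
        self.stage1 = nn.Linear(hidden, hidden).to(dev0)
        self.stage2 = nn.Linear(hidden, hidden).to(dev1)

    def forward(self, x):
        x = self.stage1(x.to(dev0))
        x = x.to(dev1)                   # hand the activations to the second GPU
        return self.stage2(x).to(dev0)   # bring the result back to the first GPU

with torch.no_grad():
    toy = TwoDeviceToy()
    out = toy(torch.randn(2, 8, 64))
    print(out.shape, out.device)         # torch.Size([2, 8, 64]) cuda:0

The same idea scales to the real model: v1 keeps every GPTJBlock on cuda:0 and only moves the final layer norm and LM head to cuda:1, while v2 splits the block stack itself (h[:14] on cuda:0, h[14:] on cuda:1), spreading the transformer blocks over both GPUs instead of loading them all on the first one.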