technicolor commited on
Commit
7dfc25a
1 Parent(s): 086ec01

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -6
README.md CHANGED
@@ -9,7 +9,8 @@ To do:
9
  4. Loss function.
10
 
11
  To run TE_Embedding model:
12
- `import os
 
13
  from transformers import (AutoConfig,
14
  AutoTokenizer,AutoModelForCausalLM
15
  )
@@ -30,7 +31,7 @@ class TEmbeddingModel(torch.nn.Module):
30
  [torch.nn.Linear(self.hidden_size, self.hidden_size//len(self.prompt_suffixes))
31
  for _ in range(len(self.prompt_suffixes))])
32
  self.tokenizer, self.llama = self.load_llama()
33
- self.device = torch.device('cuda')
34
  self.tanh = torch.nn.Tanh()
35
  self.suffixes_ids = []
36
  self.suffixes_ids_len = []
@@ -79,12 +80,12 @@ class TEmbeddingModel(torch.nn.Module):
79
  suffixes_ones = self.suffixes_ones.unsqueeze(0)
80
  suffixes_ones = suffixes_ones.repeat(batch_size, 1)
81
  device = next(self.parameters()).device
82
- attention_mask = torch.cat([attention_mask, suffixes_ones], dim=-1).to('cuda')
83
 
84
  suffixes_ids = self.suffixes_ids.unsqueeze(0)
85
  suffixes_ids = suffixes_ids.repeat(batch_size, 1)
86
- input_ids = torch.cat([input_ids, suffixes_ids], dim=-1).to('cuda')
87
- last_hidden_state = self.llama.base_model(attention_mask=attention_mask, input_ids=input_ids).last_hidden_state.to('cuda')
88
  index = -1
89
  for i in range(len(self.suffixes_ids_len)):
90
  embedding = last_hidden_state[:, index, :]
@@ -119,4 +120,5 @@ if __name__ == "__main__":
119
  output = TE_model(["Hello", "Nice to meet you"])
120
  cos_sim = F.cosine_similarity(output[0],output[1],dim=0)
121
  print(cos_sim)
122
- `
 
 
9
  4. Loss function.
10
 
11
  To run TE_Embedding model:
12
+ ```python
13
+ import os
14
  from transformers import (AutoConfig,
15
  AutoTokenizer,AutoModelForCausalLM
16
  )
 
31
  [torch.nn.Linear(self.hidden_size, self.hidden_size//len(self.prompt_suffixes))
32
  for _ in range(len(self.prompt_suffixes))])
33
  self.tokenizer, self.llama = self.load_llama()
34
+ # self.device = torch.device('cuda')
35
  self.tanh = torch.nn.Tanh()
36
  self.suffixes_ids = []
37
  self.suffixes_ids_len = []
 
80
  suffixes_ones = self.suffixes_ones.unsqueeze(0)
81
  suffixes_ones = suffixes_ones.repeat(batch_size, 1)
82
  device = next(self.parameters()).device
83
+ attention_mask = torch.cat([attention_mask, suffixes_ones], dim=-1).to(device)
84
 
85
  suffixes_ids = self.suffixes_ids.unsqueeze(0)
86
  suffixes_ids = suffixes_ids.repeat(batch_size, 1)
87
+ input_ids = torch.cat([input_ids, suffixes_ids], dim=-1) #to("cuda")
88
+ last_hidden_state = self.llama.base_model(attention_mask=attention_mask, input_ids=input_ids).last_hidden_state.to(device)
89
  index = -1
90
  for i in range(len(self.suffixes_ids_len)):
91
  embedding = last_hidden_state[:, index, :]
 
120
  output = TE_model(["Hello", "Nice to meet you"])
121
  cos_sim = F.cosine_similarity(output[0],output[1],dim=0)
122
  print(cos_sim)
123
+
124
+ ```