h4duan commited on
Commit
3439ca7
1 Parent(s): f588852

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -19,7 +19,7 @@ def extract_feature(protein):
19
  input_ids = torch.tensor(ids['input_ids']).to(self.device)
20
  attention_mask = torch.tensor(ids['attention_mask']).to(self.device)
21
  with torch.no_grad():
22
- embedding_repr = self.model(output_hidden_states=True, input_ids=input_ids,attention_mask=attention_mask).last_hidden_state
23
  return torch.mean(embedding_repr)
24
  ```
25
 
@@ -32,7 +32,7 @@ def extract_features_batch(proteins):
32
  input_ids = torch.tensor(ids['input_ids']).to(self.device)
33
  attention_mask = torch.tensor(ids['attention_mask']).to(self.device)
34
  with torch.no_grad():
35
- embedding_repr = self.model(output_hidden_states=True, input_ids=input_ids,attention_mask=attention_mask).last_hidden_state
36
  attention_mask = attention_mask.unsqueeze(-1)
37
  attention_mask = attention_mask.expand(-1, -1, embedding_repr.size(-1))
38
  masked_embedding_repr = embedding_repr * attention_mask
 
19
  input_ids = torch.tensor(ids['input_ids']).to(self.device)
20
  attention_mask = torch.tensor(ids['attention_mask']).to(self.device)
21
  with torch.no_grad():
22
+ embedding_repr = self.model(input_ids=input_ids,attention_mask=attention_mask).last_hidden_state
23
  return torch.mean(embedding_repr)
24
  ```
25
 
 
32
  input_ids = torch.tensor(ids['input_ids']).to(self.device)
33
  attention_mask = torch.tensor(ids['attention_mask']).to(self.device)
34
  with torch.no_grad():
35
+ embedding_repr = self.model(input_ids=input_ids,attention_mask=attention_mask).last_hidden_state
36
  attention_mask = attention_mask.unsqueeze(-1)
37
  attention_mask = attention_mask.expand(-1, -1, embedding_repr.size(-1))
38
  masked_embedding_repr = embedding_repr * attention_mask