Update README.md
Browse files
README.md
CHANGED
@@ -19,7 +19,7 @@ def extract_feature(protein):
|
|
19 |
input_ids = torch.tensor(ids['input_ids']).to(self.device)
|
20 |
attention_mask = torch.tensor(ids['attention_mask']).to(self.device)
|
21 |
with torch.no_grad():
|
22 |
-
embedding_repr = self.model(
|
23 |
return torch.mean(embedding_repr)
|
24 |
```
|
25 |
|
@@ -32,7 +32,7 @@ def extract_features_batch(proteins):
|
|
32 |
input_ids = torch.tensor(ids['input_ids']).to(self.device)
|
33 |
attention_mask = torch.tensor(ids['attention_mask']).to(self.device)
|
34 |
with torch.no_grad():
|
35 |
-
embedding_repr = self.model(
|
36 |
attention_mask = attention_mask.unsqueeze(-1)
|
37 |
attention_mask = attention_mask.expand(-1, -1, embedding_repr.size(-1))
|
38 |
masked_embedding_repr = embedding_repr * attention_mask
|
|
|
19 |
input_ids = torch.tensor(ids['input_ids']).to(self.device)
|
20 |
attention_mask = torch.tensor(ids['attention_mask']).to(self.device)
|
21 |
with torch.no_grad():
|
22 |
+
embedding_repr = self.model(input_ids=input_ids,attention_mask=attention_mask).last_hidden_state
|
23 |
return torch.mean(embedding_repr)
|
24 |
```
|
25 |
|
|
|
32 |
input_ids = torch.tensor(ids['input_ids']).to(self.device)
|
33 |
attention_mask = torch.tensor(ids['attention_mask']).to(self.device)
|
34 |
with torch.no_grad():
|
35 |
+
embedding_repr = self.model(input_ids=input_ids,attention_mask=attention_mask).last_hidden_state
|
36 |
attention_mask = attention_mask.unsqueeze(-1)
|
37 |
attention_mask = attention_mask.expand(-1, -1, embedding_repr.size(-1))
|
38 |
masked_embedding_repr = embedding_repr * attention_mask
|