Update README.md
Browse files
README.md
CHANGED
@@ -48,19 +48,22 @@ if __name__ == '__main__':
|
|
48 |
from transformers import AutoTokenizer, AutoModel, DataCollatorWithPadding
|
49 |
import torch
|
50 |
from torch.utils.data import DataLoader
|
|
|
|
|
51 |
|
52 |
-
device = torch.device('cuda')
|
|
|
53 |
|
54 |
# Sentences we want sentence embeddings for
|
55 |
sentences = ['This is an example sentence', 'Each sentence is converted']
|
56 |
|
57 |
# Load model from HuggingFace Hub
|
58 |
-
tokenizer = AutoTokenizer.from_pretrained('
|
59 |
-
collator = DataCollatorWithPadding(tokenizer)
|
60 |
-
model = AutoModel.from_pretrained('sorryhyun/sentence-embedding-klue-large')
|
61 |
|
62 |
tokenized_data = tokenizer(sentences, padding=True, truncation=True)
|
63 |
-
tokenized_data =
|
64 |
dataloader = DataLoader(tokenized_data, batch_size=batch_size, pin_memory=True, collate_fn=collator)
|
65 |
all_outputs = torch.zeros((len(tokenized_data), 1024)).to(device)
|
66 |
start_idx = 0
|
@@ -69,7 +72,7 @@ start_idx = 0
|
|
69 |
with torch.no_grad():
|
70 |
for inputs in tqdm(dataloader):
|
71 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
72 |
-
representations, _ = model(**inputs)
|
73 |
attention_mask = inputs["attention_mask"]
|
74 |
input_mask_expanded = (attention_mask.unsqueeze(-1).expand(representations.size()).to(representations.dtype))
|
75 |
summed = torch.sum(representations * input_mask_expanded, 1)
|
|
|
48 |
from transformers import AutoTokenizer, AutoModel, DataCollatorWithPadding
|
49 |
import torch
|
50 |
from torch.utils.data import DataLoader
|
51 |
+
from tqdm import tqdm
|
52 |
+
from datasets import Dataset
|
53 |
|
54 |
+
device = torch.device('cpu')
|
55 |
+
batch_size=1
|
56 |
|
57 |
# Sentences we want sentence embeddings for
|
58 |
sentences = ['This is an example sentence', 'Each sentence is converted']
|
59 |
|
60 |
# Load model from HuggingFace Hub
|
61 |
+
tokenizer = AutoTokenizer.from_pretrained('sorryhyun/sentence-embedding-klue-large')
|
62 |
+
collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
63 |
+
model = AutoModel.from_pretrained('sorryhyun/sentence-embedding-klue-large').to(device)
|
64 |
|
65 |
tokenized_data = tokenizer(sentences, padding=True, truncation=True)
|
66 |
+
tokenized_data = Dataset.from_dict(tokenized_data)
|
67 |
dataloader = DataLoader(tokenized_data, batch_size=batch_size, pin_memory=True, collate_fn=collator)
|
68 |
all_outputs = torch.zeros((len(tokenized_data), 1024)).to(device)
|
69 |
start_idx = 0
|
|
|
72 |
with torch.no_grad():
|
73 |
for inputs in tqdm(dataloader):
|
74 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
75 |
+
representations, _ = model(**inputs, return_dict=False)
|
76 |
attention_mask = inputs["attention_mask"]
|
77 |
input_mask_expanded = (attention_mask.unsqueeze(-1).expand(representations.size()).to(representations.dtype))
|
78 |
summed = torch.sum(representations * input_mask_expanded, 1)
|