---
language: "ja"
license: "cc-by-sa-3.0"
tags:
- sentence-transformers
- sentence-bert
- feature-extraction
- sentence-similarity
---

This is a Japanese Sentence-BERT model.

# Usage
```python
from transformers import BertJapaneseTokenizer, BertModel
import torch


class SentenceBertJapanese:
    def __init__(self, model_name_or_path, device=None):
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()

        # Pick a device automatically unless one is given explicitly.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        # The first element of model_output contains the token-level embeddings.
        token_embeddings = model_output[0]
        # Expand the attention mask so padding tokens are excluded from the average.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    @torch.no_grad()
    def encode(self, sentences, batch_size=8):
        all_embeddings = []
        iterator = range(0, len(sentences), batch_size)
        for batch_idx in iterator:
            batch = sentences[batch_idx:batch_idx + batch_size]

            encoded_input = self.tokenizer.batch_encode_plus(batch, padding="longest",
                                                             truncation=True, return_tensors="pt").to(self.device)
            model_output = self.model(**encoded_input)
            # Mean-pool the token embeddings into one fixed-size vector per sentence.
            sentence_embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"]).to('cpu')

            all_embeddings.extend(sentence_embeddings)

        return torch.stack(all_embeddings)


MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens"
model = SentenceBertJapanese(MODEL_NAME)

# Two near-paraphrases: "a runaway AI" and "a runaway artificial intelligence".
sentences = ["暴走したAI", "暴走した人工知能"]
sentence_embeddings = model.encode(sentences, batch_size=8)

print("Sentence embeddings:", sentence_embeddings)
```
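
The two example sentences are near-paraphrases, so their embeddings should lie close together. As a minimal sketch (not part of the original model card), one way to check this is with PyTorch's built-in cosine similarity:

```python
import torch.nn.functional as F

# Cosine similarity between the two example embeddings.
# Values close to 1.0 indicate semantically similar sentences.
similarity = F.cosine_similarity(sentence_embeddings[0].unsqueeze(0),
                                 sentence_embeddings[1].unsqueeze(0))
print("Cosine similarity:", similarity.item())
```

Since the embeddings come from unnormalized mean pooling, cosine similarity is the natural comparison metric here rather than a raw dot product.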