---
language: ja
license: cc-by-sa-4.0
tags:
- sentence-transformers
- sentence-bert
- feature-extraction
- sentence-similarity
---

This is a Japanese sentence-BERT model.

It is an improved version of [version 1](https://huggingface.co/sonoisa/sentence-bert-base-ja-mean-tokens), trained with [MultipleNegativesRankingLoss](https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss), a better loss function.

On a private in-house dataset, it was about 1.5 points more accurate than version 1.
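
For reference, here is a minimal sketch of what training with MultipleNegativesRankingLoss typically looks like in the [sentence-transformers](https://www.sbert.net) library. The base checkpoint, the pair data, and all hyperparameters below are illustrative assumptions; the actual training setup for this model is not published.

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# Hypothetical starting checkpoint; the real base model is an assumption here.
model = SentenceTransformer("sonoisa/sentence-bert-base-ja-mean-tokens")

# Each example is an (anchor, positive) pair; within a batch, the positives
# of the other examples serve as in-batch negatives for the ranking loss.
train_examples = [
    InputExample(texts=["今日は晴れです", "本日は良い天気です"]),
    InputExample(texts=["猫が好きです", "私は猫が大好きだ"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.MultipleNegativesRankingLoss(model)

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)
```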

# Explanation of the previous version

https://qiita.com/sonoisa/items/1df94d0a98cd4f209051

If you change the model name in that article to "sonoisa/sentence-bert-base-ja-mean-tokens-v2", the code there will use this model instead.

# Usage

```python
from transformers import BertJapaneseTokenizer, BertModel
import torch


class SentenceBertJapanese:
    def __init__(self, model_name_or_path, device=None):
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(device)

    def _mean_pooling(self, model_output, attention_mask):
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Average the token embeddings, ignoring padding positions
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    @torch.no_grad()
    def encode(self, sentences, batch_size=8):
        all_embeddings = []
        iterator = range(0, len(sentences), batch_size)
        for batch_idx in iterator:
            batch = sentences[batch_idx:batch_idx + batch_size]

            encoded_input = self.tokenizer.batch_encode_plus(batch, padding="longest",
                                                             truncation=True, return_tensors="pt").to(self.device)
            model_output = self.model(**encoded_input)
            sentence_embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"]).to('cpu')

            all_embeddings.extend(sentence_embeddings)

        # return torch.stack(all_embeddings).numpy()
        return torch.stack(all_embeddings)


MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"  # <- this is v2
model = SentenceBertJapanese(MODEL_NAME)

sentences = ["暴走したAI", "暴走した人工知能"]  # "runaway AI", "runaway artificial intelligence"
sentence_embeddings = model.encode(sentences, batch_size=8)

print("Sentence embeddings:", sentence_embeddings)
```
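
To use the embeddings for sentence similarity, you can compare them with cosine similarity. The snippet below is a sketch that assumes the `sentence_embeddings` tensor produced by the example above.

```python
import torch.nn.functional as F

# The two example sentences are near-paraphrases, so the score should be high.
score = F.cosine_similarity(sentence_embeddings[0], sentence_embeddings[1], dim=0)
print("Cosine similarity:", score.item())
```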