Update README.md
Browse files
README.md
CHANGED
@@ -59,21 +59,30 @@ pip install -U sentence-transformers
|
|
59 |
Then you can load this model and run inference.
|
60 |
```python
|
61 |
from sentence_transformers import SentenceTransformer
|
|
|
62 |
|
63 |
# Download from the 🤗 Hub
|
64 |
-
|
|
|
65 |
|
66 |
# Don't forget to add the prefix "query: " for query-side or "passage: " for passage-side texts.
|
67 |
sentences = [
|
68 |
-
'query: PKSHAはどんな会社ですか?'
|
69 |
-
'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。'
|
|
|
|
|
70 |
]
|
71 |
-
embeddings = model.encode(sentences)
|
72 |
print(embeddings.shape)
|
73 |
-
# [
|
74 |
|
75 |
# Get the similarity scores for the embeddings
|
76 |
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
|
|
|
|
|
|
|
|
|
|
|
77 |
```
|
78 |
|
79 |
<!--
|
|
|
59 |
Then you can load this model and run inference.
|
60 |
```python
|
61 |
from sentence_transformers import SentenceTransformer
|
62 |
+
import torch.nn.functional as F
|
63 |
|
64 |
# Download from the 🤗 Hub
|
65 |
+
# The argument "trust_remote_code=True" is required to load the model
|
66 |
+
model = SentenceTransformer("pkshatech/RoSEtta-base-ja",trust_remote_code=True)
|
67 |
|
68 |
# Don't forget to add the prefix "query: " for query-side or "passage: " for passage-side texts.
|
69 |
sentences = [
|
70 |
+
'query: PKSHAはどんな会社ですか?',
|
71 |
+
'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
|
72 |
+
'query: 日本で一番高い山は?',
|
73 |
+
'passage: 富士山(ふじさん)は、標高3776.12 m、日本最高峰(剣ヶ峰)の独立峰で、その優美な風貌は日本国外でも日本の象徴として広く知られている。',
|
74 |
]
|
75 |
+
embeddings = model.encode(sentences,convert_to_tensor=True)
|
76 |
print(embeddings.shape)
|
77 |
+
# [4, 768]
|
78 |
|
79 |
# Get the similarity scores for the embeddings
|
80 |
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
|
81 |
+
print(similarities)
|
82 |
+
# tensor([[1.0000, 0.5910, 0.4332, 0.5421],
|
83 |
+
# [0.5910, 1.0000, 0.4977, 0.6969],
|
84 |
+
# [0.4332, 0.4977, 1.0000, 0.7475],
|
85 |
+
# [0.5421, 0.6969, 0.7475, 1.0000]])
|
86 |
```
|
87 |
|
88 |
<!--
|