Update README.md
Browse files
README.md
CHANGED
@@ -48,15 +48,10 @@ SentenceTransformer(
|
|
48 |
|
49 |
## Usage
|
50 |
|
51 |
-
###
|
52 |
|
53 |
-
|
54 |
|
55 |
-
```bash
|
56 |
-
pip install -U sentence-transformers
|
57 |
-
```
|
58 |
-
|
59 |
-
Then you can load this model and run inference.
|
60 |
```python
|
61 |
from sentence_transformers import SentenceTransformer
|
62 |
import torch.nn.functional as F
|
@@ -65,7 +60,8 @@ import torch.nn.functional as F
|
|
65 |
# The argument "trust_remote_code=True" is required to load the model
|
66 |
model = SentenceTransformer("pkshatech/RoSEtta-base-ja",trust_remote_code=True)
|
67 |
|
68 |
-
#
|
|
|
69 |
sentences = [
|
70 |
'query: PKSHAはどんな会社ですか?',
|
71 |
'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
|
@@ -79,19 +75,56 @@ print(embeddings.shape)
|
|
79 |
# Get the similarity scores for the embeddings
|
80 |
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
|
81 |
print(similarities)
|
82 |
-
#
|
83 |
-
#
|
84 |
-
#
|
85 |
-
#
|
86 |
```
|
87 |
|
88 |
-
<!--
|
89 |
### Direct Usage (Transformers)
|
90 |
|
91 |
-
|
92 |
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
<!--
|
97 |
### Downstream Usage (Sentence Transformers)
|
|
|
48 |
|
49 |
## Usage
|
50 |
|
51 |
+
### Usage (Sentence Transformers)
|
52 |
|
53 |
+
You can perform inference using SentenceTransformers with the following code:
|
54 |
|
|
|
|
|
|
|
|
|
|
|
55 |
```python
|
56 |
from sentence_transformers import SentenceTransformer
|
57 |
import torch.nn.functional as F
|
|
|
60 |
# The argument "trust_remote_code=True" is required to load the model
|
61 |
model = SentenceTransformer("pkshatech/RoSEtta-base-ja",trust_remote_code=True)
|
62 |
|
63 |
+
# Each input text should start with "query: " or "passage: ".
|
64 |
+
# For tasks other than retrieval, you can simply use the "query: " prefix.
|
65 |
sentences = [
|
66 |
'query: PKSHAはどんな会社ですか?',
|
67 |
'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
|
|
|
75 |
# Get the similarity scores for the embeddings
|
76 |
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
|
77 |
print(similarities)
|
78 |
+
# [[1.0000, 0.5910, 0.4332, 0.5421],
|
79 |
+
# [0.5910, 1.0000, 0.4977, 0.6969],
|
80 |
+
# [0.4332, 0.4977, 1.0000, 0.7475],
|
81 |
+
# [0.5421, 0.6969, 0.7475, 1.0000]]
|
82 |
```
|
83 |
|
|
|
84 |
### Direct Usage (Transformers)
|
85 |
|
86 |
+
You can perform inference using Transformers with the following code:
|
87 |
|
88 |
+
```python
|
89 |
+
import torch.nn.functional as F
|
90 |
+
from torch import Tensor
|
91 |
+
from transformers import AutoTokenizer, AutoModel
|
92 |
+
|
93 |
+
def mean_pooling(last_hidden_states: Tensor,attention_mask: Tensor) -> Tensor:
|
94 |
+
emb = last_hidden_states * attention_mask.unsqueeze(-1)
|
95 |
+
emb = emb.sum(dim=1) / attention_mask.sum(dim=1).unsqueeze(-1)
|
96 |
+
return emb
|
97 |
+
|
98 |
+
# Download from the 🤗 Hub
|
99 |
+
tokenizer = AutoTokenizer.from_pretrained("pkshatech/RoSEtta-base-ja")
|
100 |
+
# The argument "trust_remote_code=True" is required to load the model
|
101 |
+
model = AutoModel.from_pretrained("pkshatech/RoSEtta-base-ja",trust_remote_code=True)
|
102 |
+
|
103 |
+
# Each input text should start with "query: " or "passage: ".
|
104 |
+
# For tasks other than retrieval, you can simply use the "query: " prefix.
|
105 |
+
sentences = [
|
106 |
+
'query: PKSHAはどんな会社ですか?',
|
107 |
+
'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
|
108 |
+
'query: 日本で一番高い山は?',
|
109 |
+
'passage: 富士山(ふじさん)は、標高3776.12 m、日本最高峰(剣ヶ峰)の独立峰で、その優美な風貌は日本国外でも日本の象徴として広く知られている。',
|
110 |
+
]
|
111 |
+
|
112 |
+
# Tokenize the input texts
|
113 |
+
batch_dict = tokenizer(sentences, max_length=1024, padding=True, truncation=True, return_tensors='pt')
|
114 |
+
|
115 |
+
outputs = model(**batch_dict)
|
116 |
+
embeddings = mean_pooling(outputs.last_hidden_state, batch_dict['attention_mask'])
|
117 |
+
print(embeddings.shape)
|
118 |
+
# [4, 768]
|
119 |
+
|
120 |
+
# Get the similarity scores for the embeddings
|
121 |
+
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
|
122 |
+
print(similarities)
|
123 |
+
# [[1.0000, 0.5910, 0.4332, 0.5421],
|
124 |
+
# [0.5910, 1.0000, 0.4977, 0.6969],
|
125 |
+
# [0.4332, 0.4977, 1.0000, 0.7475],
|
126 |
+
# [0.5421, 0.6969, 0.7475, 1.0000]]
|
127 |
+
```
|
128 |
|
129 |
<!--
|
130 |
### Downstream Usage (Sentence Transformers)
|