Update README.md
README.md CHANGED
@@ -26,6 +26,7 @@ from src.modeling_enc_t5 import EncT5ForSequenceClassification
 from src.tokenization_enc_t5 import EncT5Tokenizer
 import torch
 import torch.nn.functional as F
+import numpy as np
 
 # load TART full and tokenizer
 model = EncT5ForSequenceClassification.from_pretrained("facebook/tart-full-flan-t5-xl")
@@ -49,10 +50,10 @@ print([p_1, p_2][np.argmax(normalized_scores)]) # "The population of Japan's cap
 # 2. TART-full can identify the document that is more relevant AND follows instructions.
 in_sim = "You need to find duplicated questions in Wiki forum. Could you find a question that is similar to this question"
 q_1 = "How many people live in Tokyo?"
-features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [
+features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [p_1, q_1], padding=True, truncation=True, return_tensors="pt")
 with torch.no_grad():
     scores = model(**features).logits
     normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
 
-print([
+print([p_1, q_1][np.argmax(normalized_scores)]) # "How many people live in Tokyo?"
 ```
|