akariasai commited on
Commit
9fedb1a
1 Parent(s): 3e42dfc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -2
README.md CHANGED
@@ -26,6 +26,7 @@ from src.modeling_enc_t5 import EncT5ForSequenceClassification
26
  from src.tokenization_enc_t5 import EncT5Tokenizer
27
  import torch
28
  import torch.nn.functional as F
 
29
 
30
  # load TART full and tokenizer
31
  model = EncT5ForSequenceClassification.from_pretrained("facebook/tart-full-flan-t5-xl")
@@ -49,10 +50,10 @@ print([p_1, p_2][np.argmax(normalized_scores)]) # "The population of Japan's cap
49
  # 2. TART-full can identify the document that is more relevant AND follows instructions.
50
  in_sim = "You need to find duplicated questions in Wiki forum. Could you find a question that is similar to this question"
51
  q_1 = "How many people live in Tokyo?"
52
- features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [p, q_1], padding=True, truncation=True, return_tensors="pt")
53
  with torch.no_grad():
54
  scores = model(**features).logits
55
  normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
56
 
57
- print([p, q_1][np.argmax(normalized_scores)]) # "How many people live in Tokyo?"
58
  ```
 
26
  from src.tokenization_enc_t5 import EncT5Tokenizer
27
  import torch
28
  import torch.nn.functional as F
29
+ import numpy as np
30
 
31
  # load TART full and tokenizer
32
  model = EncT5ForSequenceClassification.from_pretrained("facebook/tart-full-flan-t5-xl")
 
50
  # 2. TART-full can identify the document that is more relevant AND follows instructions.
51
  in_sim = "You need to find duplicated questions in Wiki forum. Could you find a question that is similar to this question"
52
  q_1 = "How many people live in Tokyo?"
53
+ features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [p_1, q_1], padding=True, truncation=True, return_tensors="pt")
54
  with torch.no_grad():
55
  scores = model(**features).logits
56
  normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
57
 
58
+ print([p_1, q_1][np.argmax(normalized_scores)]) # "How many people live in Tokyo?"
59
  ```