adenhaus commited on
Commit
d6feb95
1 Parent(s): e721459

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -36,9 +36,10 @@ It achieves an RMSE loss of 0.32 on the dev split, and a Pearson correlation of
36
  from transformers import MT5ForConditionalGeneration, MT5Tokenizer
37
  import torch
38
 
39
- model_path = 'adenhaus/mt5-large-stata'
40
  tokenizer = MT5Tokenizer.from_pretrained(model_path)
41
  model = MT5ForConditionalGeneration.from_pretrained(model_path)
 
42
 
43
  class RegressionLogitsProcessor(torch.nn.Module):
44
  def __init__(self, extra_token_id):
@@ -53,8 +54,6 @@ def preprocess_inference_input(input_text):
53
  input_encoded = tokenizer(input_text, return_tensors='pt')
54
  return input_encoded
55
 
56
- unused_token = "<extra_id_1>"
57
-
58
  def sigmoid(x):
59
  return 1 / (1 + torch.exp(-x))
60
 
@@ -74,10 +73,11 @@ def do_regression(input_str):
74
  # Extract the logit
75
  unused_token_id = tokenizer.get_vocab()[unused_token]
76
  regression_logit = output_sequences.scores[0][0][unused_token_id]
77
-
78
  regression_score = sigmoid(regression_logit).item()
79
-
80
  return regression_score
81
 
82
- print(do_regression("Vaccination Coverage by Province | Percent of children age 12-23 months who received all basic vaccinations | (Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) (Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) (Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) (Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) (Cuando Cubango, 8) [output] Three in ten children age 12-23 months received all basic vaccinations—one dose each of BCG and measles and three doses each of DPT-containing vaccine and polio."))
 
 
 
83
  ```
 
36
  from transformers import MT5ForConditionalGeneration, MT5Tokenizer
37
  import torch
38
 
39
+ model_path = 'adenhaus/mt5-small-stata'
40
  tokenizer = MT5Tokenizer.from_pretrained(model_path)
41
  model = MT5ForConditionalGeneration.from_pretrained(model_path)
42
+ unused_token = "<extra_id_1>"
43
 
44
  class RegressionLogitsProcessor(torch.nn.Module):
45
  def __init__(self, extra_token_id):
 
54
  input_encoded = tokenizer(input_text, return_tensors='pt')
55
  return input_encoded
56
 
 
 
57
  def sigmoid(x):
58
  return 1 / (1 + torch.exp(-x))
59
 
 
73
  # Extract the logit
74
  unused_token_id = tokenizer.get_vocab()[unused_token]
75
  regression_logit = output_sequences.scores[0][0][unused_token_id]
 
76
  regression_score = sigmoid(regression_logit).item()
 
77
  return regression_score
78
 
79
+ source_table = "Vaccination Coverage by Province | Percent of children age 12-23 months who received all basic vaccinations | (Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) (Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) (Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) (Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) (Cuando Cubango, 8"
80
+ output = "Three in ten children age 12-23 months received all basic vaccinations—one dose each of BCG and measles and three doses each of DPT-containing vaccine and polio."
81
+
82
+ print(do_regression(source_table + " [output] " + output))
83
  ```