adenhaus committed on
Commit
c23d17e
1 Parent(s): f06bc7f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -5
README.md CHANGED
@@ -41,6 +41,7 @@ import torch
41
  model_path = 'adenhaus/mt5-small-stata'
42
  tokenizer = MT5Tokenizer.from_pretrained(model_path)
43
  model = MT5ForConditionalGeneration.from_pretrained(model_path)
 
44
 
45
  class RegressionLogitsProcessor(torch.nn.Module):
46
  def __init__(self, extra_token_id):
@@ -55,8 +56,6 @@ def preprocess_inference_input(input_text):
55
  input_encoded = tokenizer(input_text, return_tensors='pt')
56
  return input_encoded
57
 
58
- unused_token = "<extra_id_1>"
59
-
60
  def sigmoid(x):
61
  return 1 / (1 + torch.exp(-x))
62
 
@@ -76,10 +75,11 @@ def do_regression(input_str):
76
  # Extract the logit
77
  unused_token_id = tokenizer.get_vocab()[unused_token]
78
  regression_logit = output_sequences.scores[0][0][unused_token_id]
79
-
80
  regression_score = sigmoid(regression_logit).item()
81
-
82
  return regression_score
83
 
84
- print(do_regression("Vaccination Coverage by Province | Percent of children age 12-23 months who received all basic vaccinations | (Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) (Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) (Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) (Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) (Cuando Cubango, 8) [output] Three in ten children age 12-23 months received all basic vaccinations—one dose each of BCG and measles and three doses each of DPT-containing vaccine and polio."))
 
 
 
85
  ```
 
41
  model_path = 'adenhaus/mt5-small-stata'
42
  tokenizer = MT5Tokenizer.from_pretrained(model_path)
43
  model = MT5ForConditionalGeneration.from_pretrained(model_path)
44
+ unused_token = "<extra_id_1>"
45
 
46
  class RegressionLogitsProcessor(torch.nn.Module):
47
  def __init__(self, extra_token_id):
 
56
  input_encoded = tokenizer(input_text, return_tensors='pt')
57
  return input_encoded
58
 
 
 
59
  def sigmoid(x):
60
  return 1 / (1 + torch.exp(-x))
61
 
 
75
  # Extract the logit
76
  unused_token_id = tokenizer.get_vocab()[unused_token]
77
  regression_logit = output_sequences.scores[0][0][unused_token_id]
 
78
  regression_score = sigmoid(regression_logit).item()
 
79
  return regression_score
80
 
81
+ source_table = "Vaccination Coverage by Province | Percent of children age 12-23 months who received all basic vaccinations | (Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) (Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) (Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) (Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) (Cuando Cubango, 8)"
82
+ output = "Three in ten children age 12-23 months received all basic vaccinations—one dose each of BCG and measles and three doses each of DPT-containing vaccine and polio."
83
+
84
+ print(do_regression(source_table + " [output] " + output))
85
  ```