adenhaus commited on
Commit
62660f1
1 Parent(s): 6bf4376

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +53 -1
README.md CHANGED
@@ -17,4 +17,56 @@ multilinguality:
17
  - 'yes'
18
  license: cc-by-sa-4.0
19
  inference: false
20
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  - 'yes'
18
  license: cc-by-sa-4.0
19
  inference: false
20
+ ---
21
+
22
+ # Example use
23
+
24
+ ```python
25
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer
26
+ import torch
27
+
28
+ model_path = 'adenhaus/mt5-large-stata'
29
+ tokenizer = MT5Tokenizer.from_pretrained(model_path)
30
+ model = MT5ForConditionalGeneration.from_pretrained(model_path)
31
+
32
+ class RegressionLogitsProcessor(torch.nn.Module):
33
+ def __init__(self, extra_token_id):
34
+ super().__init__()
35
+ self.extra_token_id = extra_token_id
36
+
37
+ def __call__(self, input_ids, scores):
38
+ extra_token_logit = scores[:, :, self.extra_token_id]
39
+ return extra_token_logit
40
+
41
+ def preprocess_inference_input(input_text):
42
+ input_encoded = tokenizer(input_text, return_tensors='pt')
43
+ return input_encoded
44
+
45
+ unused_token = "<extra_id_1>"
46
+
47
+ def sigmoid(x):
48
+ return 1 / (1 + torch.exp(-x))
49
+
50
+ def do_regression(input_str):
51
+ input_data = preprocess_inference_input(input_str)
52
+
53
+ logits_processor = RegressionLogitsProcessor(tokenizer.get_vocab()[unused_token])
54
+
55
+ output_sequences = model.generate(
56
+ **input_data,
57
+ max_length=2, # Generate just the regression token
58
+ do_sample=False, # Important: Disable sampling for deterministic output
59
+ return_dict_in_generate=True, # Get the scores directly
60
+ output_scores=True
61
+ )
62
+
63
+ # Extract the logit
64
+ unused_token_id = tokenizer.get_vocab()[unused_token]
65
+ regression_logit = output_sequences.scores[0][0][unused_token_id]
66
+
67
+ regression_score = sigmoid(regression_logit).item()
68
+
69
+ return regression_score
70
+
71
+ print(do_regression("Vaccination Coverage by Province | Percent of children age 12-23 months who received all basic vaccinations | (Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) (Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) (Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) (Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) (Cuando Cubango, 8) [output] Three in ten children age 12-23 months received all basic vaccinations—one dose each of BCG and measles and three doses each of DPT-containing vaccine and polio."))
72
+ ```