mt5-large-stata / README.md
adenhaus's picture
Update README.md
62660f1 verified
|
raw
history blame
2.22 kB
---
datasets:
  - adenhaus/stata
language:
  - en
  - yo
  - sw
  - ig
  - ar
  - fr
  - pt
  - ha
  - ru
tags:
  - data-to-text
multilinguality:
  - 'yes'
license: cc-by-sa-4.0
inference: false
---

Example usage

from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch

# Checkpoint identifier handed to from_pretrained; both the tokenizer and the
# seq2seq model are loaded from the same location.
model_path = 'adenhaus/mt5-large-stata'
tokenizer, model = (
    MT5Tokenizer.from_pretrained(model_path),
    MT5ForConditionalGeneration.from_pretrained(model_path),
)

class RegressionLogitsProcessor(torch.nn.Module):
    """Select the score column of a single vocabulary entry.

    Calling an instance returns ``scores[..., extra_token_id]``: for a
    ``(batch, vocab)`` score matrix (what Hugging Face logits processors
    receive) this is a ``(batch,)`` tensor, and for a 3-D
    ``(batch, steps, vocab)`` tensor it is ``(batch, steps)``.

    Note: it is not passed to ``model.generate`` in this example. Because it
    returns only the selected column rather than a full score matrix, it is
    not a drop-in ``generate()`` logits processor.
    """

    def __init__(self, extra_token_id):
        """Store the vocabulary id whose scores are extracted."""
        super().__init__()
        self.extra_token_id = extra_token_id

    def __call__(self, input_ids, scores):
        """Return the scores of ``self.extra_token_id`` along the last axis.

        Args:
            input_ids: Accepted for logits-processor signature compatibility;
                unused.
            scores: Tensor whose last dimension indexes the vocabulary.
        """
        # Ellipsis indexing selects along the vocab (last) axis regardless of
        # rank; the previous ``scores[:, :, id]`` failed on 2-D scores.
        return scores[..., self.extra_token_id]

def preprocess_inference_input(input_text):
    """Tokenize *input_text* with the module-level tokenizer.

    Returns the tokenizer's encoding as PyTorch tensors, suitable for being
    unpacked into ``model.generate(**...)``.
    """
    return tokenizer(input_text, return_tensors='pt')

# Sentinel token whose score is read as the regression output in do_regression.
unused_token: str = "<extra_id_1>"

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x* as a tensor.

    Args:
        x: A tensor, or anything ``torch.as_tensor`` accepts (e.g. a Python
            float).
    """
    # torch.sigmoid is the built-in implementation of the same formula;
    # as_tensor lets plain Python numbers through (torch.exp rejected them)
    # and returns existing tensors unchanged.
    return torch.sigmoid(torch.as_tensor(x))

def do_regression(input_str):
    """Score *input_str* using the model's regression token.

    Runs a single greedy decoding step and passes the score assigned to
    ``unused_token`` at that step through a sigmoid.

    Args:
        input_str: Model input text (in the example below, a linearised table
            followed by an ``[output]`` segment).

    Returns:
        A Python float between 0 and 1, computed for the first item of the
        batch.

    Raises:
        KeyError: If ``unused_token`` is not in the tokenizer's vocabulary.
    """
    input_data = preprocess_inference_input(input_str)

    # Resolve the token id once. get_vocab() materialises the whole vocabulary
    # dict on every call; convert_tokens_to_ids is a direct lookup.
    unused_token_id = tokenizer.convert_tokens_to_ids(unused_token)
    if unused_token_id == tokenizer.unk_token_id:
        # convert_tokens_to_ids falls back to <unk> silently; keep failing loudly.
        raise KeyError(unused_token)

    output_sequences = model.generate(
        **input_data,
        max_new_tokens=1,  # Generate just the regression token
        do_sample=False,  # Disable sampling for deterministic output
        return_dict_in_generate=True,  # Return a structured output ...
        output_scores=True,  # ... that includes the per-step scores
    )

    # `scores` holds one (batch, vocab) tensor per generated step; take step 0,
    # batch item 0, and the regression token's column.
    # NOTE(review): these are post-logits-processor scores, so any processors
    # enabled by the generation config are applied before this read — confirm
    # none alter this token's score.
    regression_logit = output_sequences.scores[0][0, unused_token_id]

    return sigmoid(regression_logit).item()

# Example input: a linearised table (title | caption | (label, value) pairs),
# then a candidate text after the "[output]" marker.
example_input = (
    "Vaccination Coverage by Province | "
    "Percent of children age 12-23 months who received all basic vaccinations | "
    "(Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) "
    "(Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) "
    "(Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) "
    "(Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) "
    "(Cuando Cubango, 8) [output] "
    "Three in ten children age 12-23 months received all basic vaccinations—"
    "one dose each of BCG and measles and three doses each of DPT-containing "
    "vaccine and polio."
)
print(do_regression(example_input))