---
datasets:
- adenhaus/stata
language:
- en
- yo
- sw
- ig
- ar
- fr
- pt
- ha
- ru
tags:
- data-to-text
multilinguality:
- 'yes'
license: cc-by-sa-4.0
inference: false
---
# Background

This learned metric evaluates models trained on the TaTA dataset. It was trained following the instructions in [TaTA: A Multilingual Table-to-Text Dataset for African Languages](https://aclanthology.org/2023.findings-emnlp.118/) (StATA-QE variant).

StATA takes as input a linearized table and an output verbalisation separated by an " \[output\] " tag, and produces a score between 0 and 1. A score closer to 1 means the output is more understandable and attributable to the source table; a score closer to 0 means it is less so.
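
As a minimal sketch of the input format described above, the helper below joins a table and a candidate verbalisation with the " [output] " tag, and shows how a raw logit can be squashed into the 0 to 1 range with a sigmoid (as the example later in this card does). The names `OUTPUT_TAG`, `build_stata_input` and `logit_to_score` are hypothetical and are not part of the released model or of `transformers`.

```python
import math

# Separator between the linearized table and the candidate text,
# including the surrounding spaces.
OUTPUT_TAG = " [output] "


def build_stata_input(linearized_table: str, verbalisation: str) -> str:
    """Concatenate table and verbalisation into a single model input string."""
    return linearized_table + OUTPUT_TAG + verbalisation


def logit_to_score(logit: float) -> float:
    """Map a raw logit to the (0, 1) range with the logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-logit))


if __name__ == "__main__":
    # Toy table and sentence, for illustration only.
    text = build_stata_input("Title | Unit | (A, 1) (B, 2)", "A is 1 and B is 2.")
    print(text)
    print(logit_to_score(0.0))
```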

The original file can be found [here](https://github.com/google-research/url-nlp/tree/main/tata).

# Performance

It achieves an RMSE loss of 0.41 on the dev split, and a Pearson correlation of 0.23 with human evaluations on the test split ("attributable" column of [this dataset](https://huggingface.co/datasets/adenhaus/stata)).
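
For readers who want to reproduce this kind of evaluation, here is a rough sketch of how RMSE and Pearson correlation can be computed between metric scores and human labels. The helper names `rmse` and `pearson` and the toy numbers are illustrative assumptions; they are not the evaluation script behind the figures reported above.

```python
import math


def rmse(predictions, targets):
    """Root mean squared error between two equal-length sequences."""
    squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    covariance = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return covariance / math.sqrt(var_x * var_y)


if __name__ == "__main__":
    # Made-up metric scores and binary human judgements, for illustration only.
    metric_scores = [0.9, 0.2, 0.6, 0.4]
    human_labels = [1, 0, 1, 0]
    print(rmse(metric_scores, human_labels))
    print(pearson(metric_scores, human_labels))
```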

[Here](https://huggingface.co/adenhaus/mt5-large-stata/blob/main/README.md) is a version trained with mT5-Large instead of mT5-Small, which achieves a correlation of 0.59.

# Example use

```python
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch

model_path = 'adenhaus/mt5-small-stata'
tokenizer = MT5Tokenizer.from_pretrained(model_path)
model = MT5ForConditionalGeneration.from_pretrained(model_path)
unused_token = "<extra_id_1>"

class RegressionLogitsProcessor(torch.nn.Module):
    def __init__(self, extra_token_id):
        super().__init__()
        self.extra_token_id = extra_token_id

    def __call__(self, input_ids, scores):
        extra_token_logit = scores[:, :, self.extra_token_id] 
        return extra_token_logit

def preprocess_inference_input(input_text):
    input_encoded = tokenizer(input_text, return_tensors='pt')
    return input_encoded

def sigmoid(x):
    return 1 / (1 + torch.exp(-x))

def do_regression(input_str):
    input_data = preprocess_inference_input(input_str)

    # Vocabulary id of the token whose logit carries the regression score
    unused_token_id = tokenizer.get_vocab()[unused_token]

    output_sequences = model.generate(
        **input_data,
        max_length=2,  # Generate just the regression token
        do_sample=False,  # Important: Disable sampling for deterministic output
        return_dict_in_generate=True,  # Get the scores directly
        output_scores=True,
    )

    # scores[0] holds the logits of the first generated step, shape (batch, vocab);
    # take the first batch item and read the logit of the regression token
    regression_logit = output_sequences.scores[0][0][unused_token_id]
    regression_score = sigmoid(regression_logit).item()
    return regression_score

source_table = "Vaccination Coverage by Province | Percent of children age 12-23 months who received all basic vaccinations | (Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) (Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) (Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) (Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) (Cuando Cubango, 8"
output = "Three in ten children age 12-23 months received all basic vaccinations—one dose each of BCG and measles and three doses each of DPT-containing vaccine and polio."

print(do_regression(source_table + " [output] " + output))
```