"""Translation utilities for the under-tree/transformer-en-ru model.

Exposes translator_fn, which returns the translated text together with token-level
top-k candidate scores and a cross-attention matrix averaged over layers and heads.
"""
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

from accelerate import Accelerator
accelerator = Accelerator()

tokenizer = AutoTokenizer.from_pretrained("under-tree/transformer-en-ru")
model = AutoModelForSeq2SeqLM.from_pretrained("under-tree/transformer-en-ru")

device = accelerator.device

model = accelerator.prepare(model)

@dataclass
class TranslationResult:
    """A single translation plus token-level diagnostics."""
    input_text: str
    n_input: int
    input_tokens: List[str]
    n_output: int
    output_text: str
    output_tokens: List[str]
    output_scores: List[List[Tuple[str, float]]]  # per output token: top-k (token, probability) pairs
    cross_attention: np.ndarray  # shape (n_output, n_input), averaged over layers and heads


def translator_fn(input_text: str, k: int = 10) -> TranslationResult:
    """Translate `input_text`, returning top-k candidate scores and cross-attention per output token."""
    # Preprocess input
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens]).to(device)

    # Guard against overly long inputs; d_model is used here as a rough cap on the sequence length
    if len(input_tokens) > model.config.d_model:
        raise ValueError("Input text is too long")

    # Generate output
    outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
    output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    output_tokens = tokenizer.batch_decode(outputs.sequences[0])
    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens]).to(device)

    # Get cross-attention matrix.
    # outputs.cross_attentions is a tuple (one entry per generation step) of tuples
    # (one per decoder layer) of tensors shaped (batch, heads, 1, src_len), so stacking
    # yields (steps, layers, batch, heads, 1, src_len).
    cross_attention = torch.stack([torch.stack(t) for t in outputs.cross_attentions])
    # Average over the per-step target position, heads, batch and layers -> (steps, src_len)
    attention_matrix = cross_attention.mean(dim=4).mean(dim=3).mean(dim=2).mean(dim=1).detach().cpu().numpy()

    # Get top tokens
    top_scores = []
    len_input = len(input_tokens)
    len_output = len(output_tokens)

    # outputs.scores[i] holds the logits for the token generated at step i,
    # i.e. output token i + 1 (position 0 is the decoder start token).
    for i in range(len_output - 1):
        if output_special_mask[i + 1] == 1:
            # Skip special tokens (e.g. </s>, <pad>, etc.)
            continue
        top_elements, top_indices = outputs.scores[i].mean(dim=0).topk(k)  # collapse the singleton batch dim, take top-k logits
        # Renormalise the top-k logits into probabilities (softmax restricted to the top k)
        top_elements = top_elements.exp()
        top_elements /= top_elements.sum()

        # Decode each candidate id into its token string
        top_tokens = tokenizer.batch_decode(top_indices)

        # Filter out special tokens
        top_pairs = [
            (token, prob.item())
            for prob, token in zip(top_elements, top_tokens)
            if token not in tokenizer.all_special_tokens
        ]
        top_scores.append(top_pairs)

    # Filter out special tokens from all elements
    clean_output_tokens = [t for t, m in zip(output_tokens, output_special_mask) if m == 0]
    clean_input_tokens = [t for t, m in zip(input_tokens, input_special_mask) if m == 0]
    output_mask_np = output_special_mask.detach().cpu().numpy()
    input_mask_np = input_special_mask.detach().cpu().numpy()
    clean_attention_matrix = attention_matrix[:len_output, :len_input]  # guard against padded positions
    # The matrix has one row per generation step (decoder tokens 0 .. len_output - 2); the final
    # </s> never acts as a query, so clip the mask to the actual row count before deleting.
    n_rows = clean_attention_matrix.shape[0]
    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_mask_np[:n_rows] == 1), axis=0)
    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_mask_np == 1), axis=1)

    n_input = len(clean_input_tokens)
    n_output = len(clean_output_tokens)

    assert clean_attention_matrix.shape == (n_output, n_input)
    assert len(top_scores) == n_output
    return TranslationResult(
        input_text=input_text,
        n_input=n_input,
        input_tokens=clean_input_tokens,
        output_text=output_text,
        n_output=n_output,
        output_tokens=clean_output_tokens,
        output_scores=top_scores,
        cross_attention=clean_attention_matrix
    )
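

# Minimal usage sketch. This __main__ guard and the example sentence are illustrative
# assumptions, not part of the original interface; translator_fn is the function defined above.
if __name__ == "__main__":
    result = translator_fn("The cat sat on the mat.", k=5)
    print("Input tokens:", result.input_tokens)
    print("Translation: ", result.output_text)
    # Top-k (token, probability) candidates considered for the first output token
    print("First-token candidates:", result.output_scores[0])
    # One row per (non-special) output token, one column per (non-special) input token
    print("Cross-attention shape:", result.cross_attention.shape)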