from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from accelerate import Accelerator
accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("under-tree/transformer-en-ru")
model = AutoModelForSeq2SeqLM.from_pretrained("under-tree/transformer-en-ru")
device = accelerator.device
model = accelerator.prepare(model)
@dataclass
class TranslationResult:
    input_text: str
    n_input: int
    input_tokens: List[str]
    n_output: int
    output_text: str
    output_tokens: List[str]
    output_scores: List[List[Tuple[str, float]]]  # per output token: top-k (token, probability) pairs
    cross_attention: np.ndarray  # shape (n_output, n_input), averaged over layers and heads
def translator_fn(input_text: str, k: int = 10) -> TranslationResult:
    # Preprocess input
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens]).to(device)
    # Guard against inputs longer than the model's maximum sequence length
    if len(input_tokens) > tokenizer.model_max_length:
        raise ValueError("Input text is too long")
    # Generate output, keeping per-step scores and cross-attentions
    outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
    output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    output_tokens = tokenizer.batch_decode(outputs.sequences[0])
    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens]).to(device)
    # Get cross attention matrix: stack the per-step, per-layer attentions, then average over
    # layers, batch/beams, heads and the query position to get (generated steps, input tokens)
    cross_attention = torch.stack([torch.stack(t) for t in outputs.cross_attentions])
    attention_matrix = cross_attention.mean(dim=4).mean(dim=3).mean(dim=2).mean(dim=1).detach().cpu().numpy()
    # Get top-k candidate tokens for each generated position
    top_scores = []
    len_input = len(input_tokens)
    len_output = len(output_tokens)
    for i in range(len_output - 1):
        # outputs.scores[i] holds the logits used to generate output token i + 1
        if i + 1 < len_output and output_special_mask[i + 1] == 1:
            # Skip special tokens (e.g. </s>, <pad>, etc.)
            continue
        top_elements, top_indices = outputs.scores[i].mean(dim=0).topk(k)
        # Normalise the top-k logits into a probability distribution over the k candidates
        top_elements = top_elements.exp()
        top_elements /= top_elements.sum()
        top_indices = tokenizer.batch_decode(top_indices)
        # Filter out special tokens
        top_pairs = [(m, t.item()) for t, m in zip(top_elements, top_indices) if m not in tokenizer.all_special_tokens]
        top_scores.append(top_pairs)
    # Filter out special tokens from all elements
    clean_output_tokens = [t for t, m in zip(output_tokens, output_special_mask) if m == 0]
    clean_input_tokens = [t for t, m in zip(input_tokens, input_special_mask) if m == 0]
    clean_attention_matrix = attention_matrix[:len_output, :len_input]  # trim any padding
    # Rows of the matrix correspond to output tokens 1..len_output - 1 (the decoder start token
    # has no generation step), so the output mask is shifted by one before dropping rows
    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_special_mask[1:].detach().cpu().numpy() == 1), axis=0)
    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_mask.detach().cpu().numpy() == 1), axis=1)
    n_input = len(clean_input_tokens)
    n_output = len(clean_output_tokens)
    assert clean_attention_matrix.shape == (n_output, n_input)
    assert len(top_scores) == n_output
    return TranslationResult(
        input_text=input_text,
        n_input=n_input,
        input_tokens=clean_input_tokens,
        output_text=output_text,
        n_output=n_output,
        output_tokens=clean_output_tokens,
        output_scores=top_scores,
        cross_attention=clean_attention_matrix
    )
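

# Minimal usage sketch (an assumption, not part of the original Space code): the sample
# sentence and the __main__ guard below are illustrative only.
if __name__ == "__main__":
    result = translator_fn("The weather is nice today.", k=5)
    print(result.output_text)
    # Top-k (token, probability) candidates for the first generated token
    print(result.output_scores[0])
    # Cross-attention matrix aligns output tokens (rows) with input tokens (columns)
    print(result.cross_attention.shape, (result.n_output, result.n_input))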