Spaces:
Runtime error
Runtime error
from transformers import BertTokenizer, BertModel, BertConfig | |
import torch | |
from torch import nn | |
threshold = 0.001 | |
device = "cpu" | |
bert = "bert-base-multilingual-cased" | |
config = BertConfig.from_pretrained(bert, output_hidden_states=True) | |
bert_tokenizer = BertTokenizer.from_pretrained(bert) | |
bert_model = BertModel.from_pretrained(bert, config=config).to(device) | |
source_text = "Hello, my dog is cute" | |
translated_text = "Hello, my dog is cute" | |
source_tokens = bert_tokenizer(source_text, return_tensors="pt") | |
print(source_tokens) | |
source_tokens_len = len(bert_tokenizer.tokenize(source_text)) | |
target_tokens_len = len(bert_tokenizer.tokenize(translated_text)) | |
target_tokens = bert_tokenizer(translated_text, return_tensors="pt") | |
bpe_source_map = [] | |
for i in source_text.split(): | |
bpe_source_map += len(bert_tokenizer.tokenize(i)) * [i] | |
bpe_target_map = [] | |
for i in translated_text.split(): | |
bpe_target_map += len(bert_tokenizer.tokenize(i)) * [i] | |
source_embedding = bert_model(**source_tokens).hidden_states[8] | |
target_embedding = bert_model(**target_tokens).hidden_states[8] | |
target_embedding = target_embedding.transpose(-1, -2) | |
source_target_mapping = nn.Softmax(dim=-1)( | |
torch.matmul(source_embedding, target_embedding) | |
) | |
print(source_target_mapping.shape) | |
target_source_mapping = nn.Softmax(dim=-2)( | |
torch.matmul(source_embedding, target_embedding) | |
) | |
print(target_source_mapping.shape) | |
align_matrix = (source_target_mapping > threshold) * (target_source_mapping > threshold) | |
align_prob = (2 * source_target_mapping * target_source_mapping) / ( | |
source_target_mapping + target_source_mapping + 1e-9 | |
) | |
non_zeros = torch.nonzero(align_matrix) | |
print(non_zeros) | |
for i, j, k in non_zeros: | |
if j + 1 < source_tokens_len - 1 and k + 1 < target_tokens_len - 1: | |
print(bpe_source_map[j + 1], bpe_target_map[k + 1]) | |