Spaces:
Runtime error
Runtime error
Vignesh-10215
committed on
Commit
•
32b097c
1
Parent(s):
338a53a
added readme
Browse files
delete.py
DELETED
# Word-alignment demo: align words between a source sentence and its
# translation using multilingual BERT hidden states (layer 8), keeping a
# link only when it passes a softmax threshold in BOTH directions
# (SimAlign-style intersection).
from transformers import BertTokenizer, BertModel, BertConfig
import torch
from torch import nn

threshold = 0.001  # minimum directional softmax probability for a link
device = "cpu"
bert = "bert-base-multilingual-cased"

config = BertConfig.from_pretrained(bert, output_hidden_states=True)
bert_tokenizer = BertTokenizer.from_pretrained(bert)
bert_model = BertModel.from_pretrained(bert, config=config).to(device)
bert_model.eval()  # inference only; from_pretrained usually does this, but be explicit

source_text = "Hello, my dog is cute"
translated_text = "Hello, my dog is cute"

source_tokens = bert_tokenizer(source_text, return_tensors="pt")
target_tokens = bert_tokenizer(translated_text, return_tensors="pt")
print(source_tokens)

# Word-piece counts WITHOUT the special [CLS]/[SEP] markers that the
# __call__ encoding above adds.
source_tokens_len = len(bert_tokenizer.tokenize(source_text))
target_tokens_len = len(bert_tokenizer.tokenize(translated_text))

# Map each word-piece position back to the whitespace word it came from,
# so sub-word alignments can be reported at word granularity.
bpe_source_map = []
for word in source_text.split():
    bpe_source_map += len(bert_tokenizer.tokenize(word)) * [word]
bpe_target_map = []
for word in translated_text.split():
    bpe_target_map += len(bert_tokenizer.tokenize(word)) * [word]

# Contextual embeddings from hidden layer 8 (shape: 1 x seq_len x hidden).
# no_grad(): inference only — skip building the autograd graph.
with torch.no_grad():
    source_embedding = bert_model(**source_tokens).hidden_states[8]
    target_embedding = bert_model(**target_tokens).hidden_states[8]

# Similarity matrix (1 x src_len x tgt_len).  Computed ONCE — the original
# recomputed the matmul for each softmax direction.
similarity = torch.matmul(source_embedding, target_embedding.transpose(-1, -2))
source_target_mapping = nn.Softmax(dim=-1)(similarity)  # src -> tgt direction
target_source_mapping = nn.Softmax(dim=-2)(similarity)  # tgt -> src direction
print(source_target_mapping.shape)
print(target_source_mapping.shape)

# Keep a link only if it clears the threshold in BOTH directions.
align_matrix = (source_target_mapping > threshold) * (target_source_mapping > threshold)
# Harmonic mean of the two directional probabilities (eps avoids 0/0).
align_prob = (2 * source_target_mapping * target_source_mapping) / (
    source_target_mapping + target_source_mapping + 1e-9
)

non_zeros = torch.nonzero(align_matrix)
print(non_zeros)
# Each row is (batch, src_pos, tgt_pos); positions include [CLS] at 0 and
# [SEP] at the end, so real tokens occupy 1..len.  BUG FIX: the original
# indexed bpe_*_map[pos + 1], two places off from the correct word — token
# position p corresponds to word-piece index p - 1.
for _, j, k in non_zeros:
    if 1 <= j <= source_tokens_len and 1 <= k <= target_tokens_len:
        print(bpe_source_map[j - 1], bpe_target_map[k - 1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|