Vignesh-10215 committed on
Commit 32b097c
1 Parent(s): 338a53a

added readme

Files changed (1)
  1. delete.py +0 -44
delete.py DELETED
@@ -1,44 +0,0 @@
- from transformers import BertTokenizer, BertModel, BertConfig
- import torch
- from torch import nn
-
- threshold = 0.001
- device = "cpu"
- bert = "bert-base-multilingual-cased"
- config = BertConfig.from_pretrained(bert, output_hidden_states=True)
- bert_tokenizer = BertTokenizer.from_pretrained(bert)
- bert_model = BertModel.from_pretrained(bert, config=config).to(device)
- source_text = "Hello, my dog is cute"
- translated_text = "Hello, my dog is cute"
- source_tokens = bert_tokenizer(source_text, return_tensors="pt")
- print(source_tokens)
- source_tokens_len = len(bert_tokenizer.tokenize(source_text))
- target_tokens_len = len(bert_tokenizer.tokenize(translated_text))
- target_tokens = bert_tokenizer(translated_text, return_tensors="pt")
- bpe_source_map = []
- for i in source_text.split():
-     bpe_source_map += len(bert_tokenizer.tokenize(i)) * [i]
- bpe_target_map = []
- for i in translated_text.split():
-     bpe_target_map += len(bert_tokenizer.tokenize(i)) * [i]
- source_embedding = bert_model(**source_tokens).hidden_states[8]
- target_embedding = bert_model(**target_tokens).hidden_states[8]
- target_embedding = target_embedding.transpose(-1, -2)
- source_target_mapping = nn.Softmax(dim=-1)(
-     torch.matmul(source_embedding, target_embedding)
- )
- print(source_target_mapping.shape)
- target_source_mapping = nn.Softmax(dim=-2)(
-     torch.matmul(source_embedding, target_embedding)
- )
- print(target_source_mapping.shape)
-
- align_matrix = (source_target_mapping > threshold) * (target_source_mapping > threshold)
- align_prob = (2 * source_target_mapping * target_source_mapping) / (
-     source_target_mapping + target_source_mapping + 1e-9
- )
- non_zeros = torch.nonzero(align_matrix)
- print(non_zeros)
- for i, j, k in non_zeros:
-     if j + 1 < source_tokens_len - 1 and k + 1 < target_tokens_len - 1:
-         print(bpe_source_map[j + 1], bpe_target_map[k + 1])
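
For context, the removed delete.py computes a soft word alignment between a source sentence and its translation from layer-8 hidden states of multilingual BERT: it takes softmax over the source-to-target and target-to-source similarity matrices and keeps token pairs that exceed the threshold in both directions. Below is a minimal sketch of that idea condensed into one function; the function name align_words, the German example sentence, and the torch.no_grad() wrapper are illustrative additions and not part of the commit, and the unused align_prob score from the deleted script is omitted.

# Illustrative sketch, not part of the commit: core alignment step of the removed delete.py.
import torch
from torch import nn
from transformers import BertConfig, BertModel, BertTokenizer

def align_words(src: str, tgt: str, layer: int = 8, threshold: float = 1e-3):
    # Same model, layer index, and threshold as the deleted script.
    name = "bert-base-multilingual-cased"
    tok = BertTokenizer.from_pretrained(name)
    model = BertModel.from_pretrained(
        name, config=BertConfig.from_pretrained(name, output_hidden_states=True)
    )
    with torch.no_grad():
        src_emb = model(**tok(src, return_tensors="pt")).hidden_states[layer]  # [1, S, H]
        tgt_emb = model(**tok(tgt, return_tensors="pt")).hidden_states[layer]  # [1, T, H]
    sim = torch.matmul(src_emb, tgt_emb.transpose(-1, -2))  # [1, S, T] similarity matrix
    s2t = nn.Softmax(dim=-1)(sim)  # source-to-target distribution
    t2s = nn.Softmax(dim=-2)(sim)  # target-to-source distribution
    # Keep (batch, source, target) index triples that pass the threshold in both directions.
    return torch.nonzero((s2t > threshold) & (t2s > threshold))

print(align_words("Hello, my dog is cute", "Hallo, mein Hund ist süß"))

Intersecting both softmax directions is similar in spirit to the bidirectional heuristics used by similarity-based aligners such as SimAlign, except that the removed script thresholds the distributions rather than taking argmax.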