"""
This module contains helper functions for computing the word alignment mapping between two sentences.
"""

import torch
import itertools
import transformers
from transformers import logging

# Set the verbosity to error so that warning messages are not printed
logging.set_verbosity_error()


def select_model(model_name):
    """
    Select Model
    """
    if model_name == "Google-mBERT (Base-Multilingual)":
        model_name = "bert-base-multilingual-cased"
    elif model_name == "Neulab-AwesomeAlign (Bn-En-0.5M)":
        model_name = "musfiqdehan/bn-en-word-aligner"
    elif model_name == "BUET-BanglaBERT (Large)":
        model_name = "csebuetnlp/banglabert_large"
    elif model_name == "SagorSarker-BanglaBERT (Base)":
        model_name = "sagorsarker/bangla-bert-base"
    elif model_name == "SentenceTransformers-LaBSE (Multilingual)":
        model_name = "sentence-transformers/LaBSE"

    return model_name
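
# Example (illustrative only; not exercised anywhere in this module): a UI
# label is resolved to its Hugging Face checkpoint name, while anything not
# in the mapping above is returned unchanged.
#
#   select_model("Google-mBERT (Base-Multilingual)")
#   # -> "bert-base-multilingual-cased"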


def get_alignment_mapping(source="", target="", model_name=""):
    """
    Get Aligned Words
    """
    model_name = select_model(model_name)

    model = transformers.BertModel.from_pretrained(model_name)
    tokenizer = transformers.BertTokenizer.from_pretrained(model_name)

    # pre-processing
    sent_src, sent_tgt = source.strip().split(), target.strip().split()

    # Tokenize each word separately so sub-tokens can be mapped back to words
    token_src = [tokenizer.tokenize(word) for word in sent_src]
    token_tgt = [tokenizer.tokenize(word) for word in sent_tgt]

    wid_src = [tokenizer.convert_tokens_to_ids(x) for x in token_src]
    wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]

    ids_src = tokenizer.prepare_for_model(
        list(itertools.chain(*wid_src)), return_tensors='pt',
        model_max_length=tokenizer.model_max_length, truncation=True)['input_ids']
    ids_tgt = tokenizer.prepare_for_model(
        list(itertools.chain(*wid_tgt)), return_tensors='pt',
        model_max_length=tokenizer.model_max_length, truncation=True)['input_ids']

    # Map every sub-token position back to the index of the word it came from
    sub2word_map_src = []
    for i, word_list in enumerate(token_src):
        sub2word_map_src += [i] * len(word_list)

    sub2word_map_tgt = []
    for i, word_list in enumerate(token_tgt):
        sub2word_map_tgt += [i] * len(word_list)

    # Alignment settings: hidden states are taken from layer 8, and a pair of
    # sub-tokens is kept only if both directional softmax scores exceed the threshold
    align_layer = 8
    threshold = 1e-3

    model.eval()

    with torch.no_grad():
        # Output index 2 is the tuple of hidden states (present because
        # output_hidden_states=True); drop the [CLS] and [SEP] positions
        out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[
            2][align_layer][0, 1:-1]
        out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[
            2][align_layer][0, 1:-1]

        # Similarity matrix between source and target sub-token embeddings
        dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))

        softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
        softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)

        # Keep sub-token pairs that clear the threshold in both directions
        softmax_inter = (softmax_srctgt > threshold) * \
            (softmax_tgtsrc > threshold)

    align_subwords = torch.nonzero(softmax_inter, as_tuple=False)

    # Collapse sub-token alignments to word-level (source_index, target_index) pairs
    align_words = set()
    for i, j in align_subwords:
        align_words.add((sub2word_map_src[i], sub2word_map_tgt[j]))

    return sent_src, sent_tgt, align_words
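
# Illustrative usage (the sentences below are placeholders): the function
# returns the whitespace-tokenized sentences plus a set of
# (src_word_idx, tgt_word_idx) pairs, e.g.
#
#   sent_src, sent_tgt, align_words = get_alignment_mapping(
#       source="<bangla sentence>", target="<english sentence>",
#       model_name="Google-mBERT (Base-Multilingual)")
#   # align_words -> {(0, 0), (1, 2), (2, 1), ...}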



def get_word_mapping(source="", target="", model_name=""):
    """
    Get Word Aligned Mapping Words
    """
    sent_src, sent_tgt, align_words = get_alignment_mapping(
        source=source, target=target, model_name=model_name)

    result = []

    for i, j in sorted(align_words):
        result.append(f'bn:({sent_src[i]}) -> en:({sent_tgt[j]})')

    return result
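
# Each returned entry follows the f-string pattern above, for example
# (the pair shown is illustrative): 'bn:(আমি) -> en:(I)'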



def get_word_index_mapping(source="", target="", model_name=""):
    """
    Get Word Aligned Mapping Index
    """
    sent_src, sent_tgt, align_words = get_alignment_mapping(
        source=source, target=target, model_name=model_name)

    result = []

    for i, j in sorted(align_words):
        result.append(f'bn:({i}) -> en:({j})')

    return result
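

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the sentence pair below is a
    # placeholder example, and the selected checkpoint is downloaded from the
    # Hugging Face Hub on first use.
    demo_source = "আমি ভাত খাই"
    demo_target = "I eat rice"
    demo_model = "Google-mBERT (Base-Multilingual)"

    for pair in get_word_mapping(source=demo_source, target=demo_target,
                                 model_name=demo_model):
        print(pair)

    for pair in get_word_index_mapping(source=demo_source, target=demo_target,
                                       model_name=demo_model):
        print(pair)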