File size: 2,799 Bytes
8c7a320
 
 
 
 
 
 
 
 
 
e40e595
8c7a320
 
 
 
 
 
2afffb1
8c7a320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e40e595
 
 
 
 
 
8c7a320
 
 
 
 
e40e595
8c7a320
3815353
8c7a320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import random
from typing import *

import gradio as gr
import numpy as np
import seaborn as sns
import sentencepiece as sp
import torch

from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure
from torchtext.datasets import Multi30k

from models import Seq2Seq


# Load model.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a CPU-only
# host; load_state_dict then copies the tensors into the (CPU) model parameters.
# NOTE(review): torch.load unpickles the downloaded file -- only point this at a
# trusted repository.
model_path = hf_hub_download("msarmi9/multi30k", "models/de-en/version_1/model.bin")
model = Seq2Seq(vocab_size=8000, hidden_dim=512, bos_idx=1, eos_idx=2, pad_idx=3, temperature=2)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # switch layers such as dropout (if any) to inference behaviour

# Load sentencepiece tokenizers; add_eos=True appends the end-of-sentence id to every encoding.
source_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/de8000.model")
target_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/en8000.model")
source_spm = sp.SentencePieceProcessor(model_file=source_spm_path, add_eos=True)
target_spm = sp.SentencePieceProcessor(model_file=target_spm_path, add_eos=True)


# Load test set for example inputs.
def normalize(sample: Tuple[str, str]) -> Tuple[str, str]:
    """Lower-case and strip whitespace from both sides of a (German, English) pair."""
    return sample[0].lower().strip(), sample[1].lower().strip()


# Only the German side is kept; the English references are discarded.
test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))


def attention_heatmap(
    input_tokens: List[str],
    output_tokens: List[str],
    weights: np.ndarray,
    *,
    dpi: int = 800,
    cmap: str = "gray",
) -> Figure:
    """Render an attention matrix as a labelled heatmap without a colour bar.

    Args:
        input_tokens: Column labels (source pieces), drawn along the top edge.
        output_tokens: Row labels (target pieces), drawn down the left edge.
        weights: 2-D attention matrix; presumably shaped
            ``(len(output_tokens), len(input_tokens))`` -- verify against the model.
        dpi: Figure resolution. Defaults to 800, the previously hard-coded value;
            lower values render considerably faster.
        cmap: Matplotlib colormap name. Defaults to ``"gray"``.

    Returns:
        A newly created matplotlib ``Figure`` holding the heatmap.
    """
    figure = Figure(dpi=dpi, tight_layout=True)
    # sns.heatmap draws onto the given Axes and returns that same Axes.
    axes = sns.heatmap(
        weights,
        ax=figure.add_subplot(),
        xticklabels=input_tokens,
        yticklabels=output_tokens,
        cmap=cmap,
        cbar=False,
    )
    # Hide tick marks; rotate source labels vertically so long pieces don't overlap.
    axes.tick_params(axis="x", rotation=90, length=0)
    axes.tick_params(axis="y", rotation=0, length=0)
    # Put the source tokens across the top so the matrix reads like an alignment table.
    axes.xaxis.tick_top()
    return figure


@torch.inference_mode()
def run(input: str) -> Tuple[str, Figure]:
    """Run inference on a single sentence. Returns prediction and attention heatmap.

    The input is lower-cased, stripped, and given exactly one trailing period
    before being encoded with the source tokenizer and decoded by the model.

    Args:
        input: Raw German text from the UI textbox. (The name shadows the builtin
            but is kept because it is part of this callable's signature.)

    Returns:
        ``(translation, figure)``: the detokenized English output and a heatmap of
        the attention weights, with one column per source piece and one row per
        generated target piece.
    """
    text = input.lower().strip().rstrip(".") + "."
    source_ids = source_spm.encode(text)
    input_tensor = torch.tensor(source_ids, dtype=torch.int64)
    # Decode budget: at least 80 steps, more for inputs longer than that.
    output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
    output_ids = output.detach().tolist()
    translation = target_spm.decode(output_ids)
    # Label the axes with the exact pieces the model consumed and produced.
    # Re-encoding the decoded text is not guaranteed to reproduce the generated
    # ids, which could misalign the y-axis labels with the rows of `weights`.
    input_tokens = [source_spm.id_to_piece(i) for i in source_ids]
    output_tokens = [target_spm.id_to_piece(i) for i in output_ids]
    return translation, attention_heatmap(input_tokens, output_tokens, weights.detach().numpy())


if __name__ == "__main__":
    # UI components: one German textbox in; English text plus the attention plot out.
    german_box = gr.inputs.Textbox(lines=4, label="German")
    english_box = gr.outputs.Textbox(label="English")
    heatmap_plot = gr.outputs.Image(type="plot", label="Attention Heatmap")

    # Thirty random sentences from the normalized test split, shown ten per page.
    example_sentences = random.sample(test_source, k=30)

    interface = gr.Interface(
        run,
        inputs=german_box,
        outputs=[english_box, heatmap_plot],
        title="Multi30k Translation Widget",
        examples=example_sentences,
        examples_per_page=10,
        allow_flagging="never",
        theme="huggingface",
        live=True,  # re-run translation as the user types
    )

    interface.launch(enable_queue=True, cache_examples=True)