"""Gradio demo: translate a German sentence to English and plot the attention weights."""

import random
from typing import *

import gradio as gr
import numpy as np
import seaborn as sns
import sentencepiece as sp
import torch
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure
from torchtext.datasets import Multi30k

from models import Seq2Seq

# Load model
model_path = hf_hub_download("msarmi9/multi30k", "models/de-en/version_1/model.bin")
model = Seq2Seq(vocab_size=8000, hidden_dim=512, bos_idx=1, eos_idx=2, pad_idx=3, temperature=2)
# The model is never moved off the CPU (`run` calls `.numpy()` on its output), so map the
# checkpoint to CPU explicitly: this keeps loading working on hosts without CUDA even if
# the checkpoint was saved from GPU tensors.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Load sentencepiece tokenizers
source_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/de8000.model")
target_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/en8000.model")
source_spm = sp.SentencePieceProcessor(model_file=source_spm_path, add_eos=True)
target_spm = sp.SentencePieceProcessor(model_file=target_spm_path, add_eos=True)


# Load test set for example inputs
def normalize(sample):
    """Lowercase and strip surrounding whitespace from both elements of a (source, target) pair."""
    return (sample[0].lower().strip(), sample[1].lower().strip())


# Only the source-side sentences are kept; they are used as example inputs for the widget.
test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))


def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> Figure:
    """Draw `weights` as a grayscale heatmap labelled with the given tokens.

    Args:
        input_tokens: Tick labels for the x-axis (one per column of `weights`).
        output_tokens: Tick labels for the y-axis (one per row of `weights`).
        weights: 2-D array of values to plot.
            NOTE(review): assumed to be (len(output_tokens), len(input_tokens)) -- confirm
            against `Seq2Seq.decode`.

    Returns:
        The matplotlib `Figure` containing the heatmap.
    """
    # Build the Figure directly rather than through pyplot, so no global pyplot state is touched.
    figure = Figure(dpi=800, tight_layout=True)
    axes = figure.add_subplot()
    axes = sns.heatmap(weights, ax=axes, xticklabels=input_tokens, yticklabels=output_tokens, cmap="gray", cbar=False)
    # Vertical x labels, horizontal y labels, and no tick marks.
    axes.tick_params(axis="x", rotation=90, length=0)
    axes.tick_params(axis="y", rotation=0, length=0)
    # Put the input-token labels along the top edge of the plot.
    axes.xaxis.tick_top()
    return figure


@torch.inference_mode()
def run(input: str) -> Tuple[str, Figure]:
    """Run inference on a single sentence. Returns prediction and attention heatmap."""
    # NOTE: the parameter name `input` shadows the builtin; it is kept so that existing
    # callers (including keyword callers) continue to work.
    # Normalise the sentence: lowercase, trim whitespace, and end with exactly one period.
    input = input.lower().strip().rstrip(".") + "."
    input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
    # The length passed to the decoder is never below 80 and grows with longer inputs.
    output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
    output = target_spm.decode(output.detach().tolist())
    # Token strings (rather than ids) are used as the heatmap's axis labels.
    input_tokens = source_spm.encode(input, out_type=str)
    # NOTE(review): the output labels come from re-encoding the decoded text; presumably this
    # reproduces the decoder's own pieces so the label count matches `weights` -- verify.
    output_tokens = target_spm.encode(output, out_type=str)
    return output, attention_heatmap(input_tokens, output_tokens, weights.detach().numpy())


if __name__ == "__main__":
    interface = gr.Interface(
        run,
        inputs=gr.inputs.Textbox(lines=4, label="German"),
        outputs=[
            gr.outputs.Textbox(label="English"),
            gr.outputs.Image(type="plot", label="Attention Heatmap"),
        ],
        title="Multi30k Translation Widget",
        # A fresh random selection of 30 test sentences is offered on every start-up.
        examples=random.sample(test_source, k=30),
        examples_per_page=10,
        allow_flagging="never",
        theme="huggingface",
        live=True,
    )
    interface.launch(
        enable_queue=True,
        cache_examples=True,
    )