File size: 2,637 Bytes
75b9522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from pathlib import Path
from functools import partial

from joeynmt.prediction import predict
from joeynmt.helpers import (
    check_version,
    load_checkpoint,
    load_config,
    parse_train_args,
    resolve_ckpt_path,

)
from joeynmt.model import build_model
from joeynmt.tokenizers import build_tokenizer
from joeynmt.vocabulary import build_vocab
from joeynmt.datasets import build_dataset

import gradio as gr

# Example input kept for manual smoke-testing of the pipeline:
# INPUT = "سلاو لە ناو گلی کرد"

# Hard-coded serving configuration: YAML config file plus the checkpoint
# whose weights this demo loads.
cfg_file = 'config.yaml'
ckpt = './models/Sorani-Arabic/best.ckpt'

cfg = load_config(Path(cfg_file))
# parse and validate cfg; pulls out the pieces needed for inference
# (model dir, device, gpu/worker counts, fp16 flag)
model_dir, load_model, device, n_gpu, num_workers, _, fp16 = parse_train_args(
    cfg["training"], mode="prediction")
test_cfg = cfg["testing"]
src_cfg = cfg["data"]["src"]
trg_cfg = cfg["data"]["trg"]

# `ckpt` is hard-coded above, so the explicit checkpoint path always wins
# over the `load_model` entry from the training config here.
load_model = load_model if ckpt is None else Path(ckpt)
ckpt = resolve_ckpt_path(load_model, model_dir)

# Rebuild vocabularies and model architecture from config ...
src_vocab, trg_vocab = build_vocab(cfg["data"], model_dir=model_dir)

model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)

# load model state from disk and restore the trained weights
model_checkpoint = load_checkpoint(ckpt, device=device)
model.load_state_dict(model_checkpoint["model_state"])

# Move the model to GPU only when the config asked for CUDA.
if device.type == "cuda":
    model.to(device)

# Source sentences are encoded without BOS and with EOS (bos=False,
# eos=True); the target side needs no encoder at inference time.
tokenizer = build_tokenizer(cfg["data"])
sequence_encoder = {
    src_cfg["lang"]: partial(src_vocab.sentences_to_ids, bos=False, eos=True),
    trg_cfg["lang"]: None,
}

# The Gradio callback translates one sentence per request, so force
# single-sentence batches.
test_cfg["batch_size"] = 1  # CAUTION: this will raise an error if n_gpus > 1
test_cfg["batch_type"] = "sentence"

# "stream" dataset with path=None: items are supplied at request time via
# `set_item` (see `normalize` below) rather than read from a file.
test_data = build_dataset(
    dataset_type="stream",
    path=None,
    src_lang=src_cfg["lang"],
    trg_lang=trg_cfg["lang"],
    split="test",
    tokenizer=tokenizer,
    sequence_encoder=sequence_encoder,
)
# test_data.set_item(INPUT.rstrip())


def _translate_data(test_data, cfg=test_cfg):
    """Translate *test_data* with the globally loaded model.

    Only the dataset and the prediction config are parameters; the model,
    device and decoding settings all come from module scope.  Returns the
    single best hypothesis string for the first (and only) item.
    """
    predict_kwargs = dict(
        model=model,
        data=test_data,
        compute_loss=False,
        device=device,
        n_gpu=n_gpu,
        normalization="none",
        num_workers=num_workers,
        cfg=cfg,
        fp16=fp16,
    )
    # predict() returns (scores, loss, hypotheses, trg_tokens,
    # trg_scores, attention); only the hypotheses are needed here.
    outputs = predict(**predict_kwargs)
    hypotheses = outputs[2]
    return hypotheses[0]



def normalize(text):
    """Gradio callback: push *text* into the stream dataset and translate it."""
    test_data.set_item(text)
    return _translate_data(test_data)

# Example inputs shown beneath the textbox in the web UI.
examples = [
    ["ياخوا تةمةن دريژبيت بوئةم ميللةتة"],
    ["سلاو برا جونی؟"],
]

# Build and launch the web demo.  NOTE: the original used
# `gr.inputs.Textbox` / `gr.outputs.Textbox`, which were deprecated in
# Gradio 3.0 and removed in later releases; the component classes are
# exposed directly on the top-level module.
demo = gr.Interface(
    fn=normalize,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Textbox(label="Output Text"),
    examples=examples,
)

demo.launch()