File size: 2,314 Bytes
e1188cc
 
 
 
 
896ca2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1188cc
896ca2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4134ae7
896ca2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import streamlit as st

from transformers import pipeline
from nagisa_bert import NagisaBertTokenizer


@st.cache(allow_output_mutation=True)
def load_tasks():
    """Build the tokenizer and the two pipelines used by the demo.

    The result is cached by Streamlit so the objects are not rebuilt on
    every script rerun.

    Returns:
        A dict keyed by the task name shown in the UI:
        ``"Tokenize"`` maps to the NagisaBertTokenizer instance,
        ``"Fill-mask"`` and ``"Feature-extraction"`` map to the
        corresponding transformers pipelines.
    """
    # NOTE(review): st.cache is deprecated in newer Streamlit releases in
    # favour of st.cache_resource -- confirm against the pinned version
    # before switching.
    checkpoint = "taishi-i/nagisa_bert"
    nagisa_tokenizer = NagisaBertTokenizer.from_pretrained(checkpoint)

    def build(task):
        # Both pipelines share the same checkpoint and tokenizer.
        return pipeline(task, model=checkpoint, tokenizer=nagisa_tokenizer)

    return {
        "Tokenize": nagisa_tokenizer,
        "Fill-mask": build("fill-mask"),
        "Feature-extraction": build("feature-extraction"),
    }


# Per-task text-area configuration: "label" is the prompt shown above the
# input box and "value" is the sample text it is pre-filled with.
task2samples = dict(
    [
        (
            "Fill-mask",
            dict(
                label="[MASK]を含むテキストを入力してください。",
                value="nagisaで[MASK]できるモデルです",
            ),
        ),
        (
            "Feature-extraction",
            dict(
                label="[CLS]トークンのベクトルを取得します。ベクトル化するテキストを入力してください。",
                value="nagisaで利用できるモデルです",
            ),
        ),
        (
            "Tokenize",
            dict(
                label="トークナイズするテキストを入力してください。",
                value="nagisaで利用できるモデルです",
            ),
        ),
    ]
)


def main():
    """Render the demo page.

    Shows a task selector and a text form; on submit, runs the selected
    task on the entered text and displays the result inside the form.
    """
    runners = load_tasks()

    task_selection = st.selectbox(
        "Select a task (Fill-mask, Feature-extraction, Tokenize)",
        ("Fill-mask", "Feature-extraction", "Tokenize"),
    )
    sample = task2samples[task_selection]

    with st.form("Fill-mask"):
        text = st.text_area(
            label=sample["label"],
            value=sample["value"],
            max_chars=512,
        )

        # Nothing more to render until the user presses the button.
        if not st.form_submit_button("Submit"):
            return

        runner = runners[task_selection]

        if task_selection == "Tokenize":
            st.json(runner.tokenize(text))
        elif task_selection == "Feature-extraction":
            # First element of the first sequence; per the UI label this
            # is the [CLS] token vector.
            st.code(runner(text)[0][0])
        elif task_selection == "Fill-mask":
            if "[MASK]" not in text:
                # Ask for a mask token instead of calling the pipeline.
                st.write("[MASK] を含むテキストを入力してください。")
            else:
                st.json(runner(text))


# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()