File size: 3,570 Bytes
c6d338c
 
 
 
be6f31c
c6d338c
be6f31c
8136881
 
 
 
 
 
 
 
 
 
 
 
 
28efd24
8136881
 
 
c6d338c
 
 
 
 
 
8136881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d338c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28efd24
c6d338c
 
 
 
 
 
 
 
 
 
 
 
e9acb28
c6d338c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28efd24
c6d338c
 
28efd24
c6d338c
28efd24
 
e9acb28
28efd24
e9acb28
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import os
import random

import requests
import streamlit as st
from mtranslate import translate

from prompts import PROMPT_LIST

# Hugging Face Inference API credentials.
# The token is read from the environment rather than hard-coded: a token
# committed to source control is effectively public (the previously embedded
# token should be revoked). When HF_API_TOKEN is unset or empty, no
# Authorization header is sent.
# NOTE(review): confirm the deployment sets HF_API_TOKEN (e.g. as a secret).
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}

# Display name shown in the model selectbox -> Inference API endpoint URL.
MODELS = {
    "GPT-2 Small": {
        "url": "https://api-inference.huggingface.co/models/flax-community/gpt2-small-indonesian"
    },
    "GPT-2 Medium": {
        "url": "https://api-inference.huggingface.co/models/flax-community/gpt2-medium-indonesian"
    },
}


def query(payload, model_name, timeout=60.0):
    """POST *payload* to the Inference API endpoint of *model_name*.

    Args:
        payload: JSON-serialisable request body.
        model_name: Key into ``MODELS`` selecting the endpoint URL.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The decoded JSON response, whatever structure the API returned.
        Error responses are returned as-is rather than raised, because the
        caller inspects them for an ``"error"`` key.

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``.
        requests.exceptions.RequestException: on connection failure or timeout.
        ValueError: (or a subclass) if the response body is not valid JSON.
    """
    url = MODELS[model_name]["url"]
    print("model url:", url)
    # ``json=`` serialises the payload and sets the JSON Content-Type header.
    # No raise_for_status(): non-2xx bodies are returned to the caller.
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    return response.json()


def process(text: str,
            model_name: str,
            max_len: int,
            temp: float,
            top_k: int,
            top_p: float,
            do_sample: bool = True,
            repetition_penalty: float = 2.0):
    """Build a text-generation request for *text* and send it via ``query``.

    Args:
        text: Prompt to continue.
        model_name: Key into ``MODELS`` selecting the endpoint.
        max_len: Sent as ``max_new_tokens``.
        temp: Sent as ``temperature``.
        top_k: Sent as ``top_k``.
        top_p: Sent as ``top_p``.
        do_sample: Sent as ``do_sample``; ``False`` requests greedy decoding.
        repetition_penalty: Sent as ``repetition_penalty`` (default 2.0).

    Returns:
        The decoded JSON response returned by ``query``.
    """
    payload = {
        "inputs": text,
        "parameters": {
            "max_new_tokens": max_len,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temp,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
        },
        "options": {
            "use_cache": True,
        }
    }
    return query(payload, model_name)

st.set_page_config(page_title="Indonesian GPT-2 Demo")

st.title("Indonesian GPT-2")

st.sidebar.subheader("Configurable parameters")

# Numeric widgets return correctly-typed values, so a non-numeric entry can no
# longer crash the int()/float() conversions done when "Run" is pressed.
max_len = st.sidebar.number_input(
    "Maximum length",
    min_value=1,
    value=100,
    step=1,
    help="The maximum length of the sequence to be generated."
)

temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0,
    help="The value used to module the next token probabilities."
)

top_k = st.sidebar.number_input(
    "Top k",
    min_value=1,
    value=50,
    step=1,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
)

top_p = st.sidebar.number_input(
    "Top p",
    min_value=0.0,
    max_value=1.0,
    value=0.95,
    step=0.05,
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
)

do_sample = st.sidebar.selectbox('Sampling?', (True, False), help="Whether or not to use sampling; use greedy decoding otherwise.")

st.markdown(
    """Indonesian GPT-2 demo. Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/)."""
)

model_name = st.selectbox('Model',(['GPT-2 Small', 'GPT-2 Medium']))

# Prompt categories from PROMPT_LIST plus a free-text option, which is the default.
ALL_PROMPTS = list(PROMPT_LIST.keys())+["Custom"]
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)

# NOTE(review): if Streamlit re-executes this script on each interaction,
# random.choice may pick a different example on the rerun triggered by "Run",
# so the text sent could differ from the text shown — confirm.
if prompt == "Custom":
    prompt_box = "Enter your text here"
else:
    prompt_box = random.choice(PROMPT_LIST[prompt])

text = st.text_area("Enter text", prompt_box)

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}, do_sample:{do_sample}")
        result = process(text=text,
                         model_name=model_name,
                         max_len=int(max_len),
                         temp=temp,
                         top_k=int(top_k),
                         top_p=float(top_p),
                         do_sample=do_sample)

        print("result:", result)
        # A successful response is indexed below as a list; an error is
        # detected as a dict carrying an "error" key.
        if isinstance(result, dict) and "error" in result:
            # estimated_time is not guaranteed to be present (presumably only
            # while a model is loading), so do not index it blindly.
            estimated_time = result.get("estimated_time")
            if estimated_time is None:
                st.write(f'{result["error"]}.')
            else:
                st.write(f'{result["error"]}. Please try it again in about {estimated_time:.0f} seconds')
        else:
            generated = result[0]["generated_text"]
            # Markdown needs two trailing spaces to render a hard line break.
            st.write(generated.replace("\n", "  \n"))
            st.text("English translation")
            # mtranslate.translate(text, to_language, from_language): id -> en.
            st.write(translate(generated, "en", "id").replace("\n", "  \n"))