File size: 4,209 Bytes
c5d0e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
Hugging Face's logo
Hugging Face
Search models, datasets, users...
Models
Datasets
Resources
Solutions
Pricing

Space:
Flax Community's picture
flax-community
/
papuGaPT2 Copied
Runtime error
App
Files and versions
Settings
papuGaPT2
/
app.py
miwojc's picture
miwojc
Update app.py
d4fb97b
2 minutes ago
raw
history
blame
edit
3,870 Bytes
import json
import random
import requests
from mtranslate import translate
import streamlit as st
# Hosted Inference API endpoint for the papuGaPT2 checkpoint.
MODEL_URL = (
    "https://api-inference.huggingface.co/models/"
    "flax-community/papuGaPT2"
)

# Ready-made Polish sentence openings offered in the prompt picker. The menu
# label is the opening with "..." in place of the trailing space; the value is
# a one-element list holding the opening itself (trailing space kept so the
# model continues the sentence).
_PROMPT_OPENINGS = (
    "Najsmaczniejszy owoc to ",
    "Cześć, mam na imię ",
    "Największym polskim poetą był ",
)
PROMPT_LIST = {opening.rstrip() + "...": [opening] for opening in _PROMPT_OPENINGS}
def query(payload, model_url, timeout=60.0):
    """POST ``payload`` as JSON to the Inference API endpoint ``model_url``.

    Args:
        payload: JSON-serialisable request body.
        model_url: Full URL of the hosted-inference endpoint.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled endpoint would block the Streamlit run forever.

    Returns:
        The decoded JSON response. Transport failures and bodies that are not
        valid UTF-8 JSON are reported as ``{"error": <message>}`` instead of
        raising, so they flow through the same error display as API errors.
    """
    data = json.dumps(payload)
    print("model url:", model_url)
    try:
        response = requests.post(model_url, data=data, timeout=timeout)
    except requests.RequestException as exc:
        # Covers timeouts, DNS/connection failures, etc.
        return {"error": f"Request to the model failed: {exc}"}
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError,
        # e.g. when a gateway returns an HTML error page instead of JSON.
        return {
            "error": f"Unexpected response from the model (HTTP {response.status_code})"
        }
def process(
    text: str,
    model_name: str,
    max_len: int,
    temp: float,
    top_k: int,
    top_p: float,
    do_sample: "bool | None" = None,
    repetition_penalty: float = 2.0,
):
    """Ask the hosted model to continue ``text`` via ``query``.

    Args:
        text: Prompt to continue.
        model_name: Endpoint URL; passed unchanged to ``query`` as the URL.
        max_len: Sent as ``max_new_tokens``.
        temp: Sent as ``temperature``.
        top_k: Sent as ``top_k``.
        top_p: Sent as ``top_p``.
        do_sample: When not ``None``, sent as ``do_sample``; ``None`` leaves
            the key out of the request entirely (the previous behaviour).
        repetition_penalty: Sent as ``repetition_penalty`` (defaults to the
            previously hard-coded 2.0).

    Returns:
        Whatever ``query`` returns for the request.
    """
    parameters = {
        "max_new_tokens": max_len,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temp,
        "repetition_penalty": repetition_penalty,
    }
    if do_sample is not None:
        parameters["do_sample"] = do_sample
    payload = {
        "inputs": text,
        "parameters": parameters,
        # NOTE(review): with use_cache=True the API may serve a cached answer
        # for an identical request, so re-running a sampled prompt might not
        # yield new text — confirm against the Inference API docs.
        "options": {
            "use_cache": True,
        },
    }
    return query(payload, model_name)


def _markdown_linebreaks(text: str) -> str:
    """Turn single newlines into markdown hard breaks (two trailing spaces)."""
    return text.replace("\n", "  \n")


def _render_api_error(result: dict) -> None:
    """Display the ``error`` entry of an API response dict.

    ``error`` is either a single message (possibly accompanied by an
    ``estimated_time`` retry hint) or a list of messages.
    """
    error = result["error"]
    if isinstance(error, list):
        for item in error:
            st.write(f"{item}")
        return
    message = f"{error}."
    if "estimated_time" in result:
        message += f' Please try again in about {result["estimated_time"]:.0f} seconds.'
    # One st.write call keeps the hint next to the error; st.write has no
    # print-style ``end`` argument.
    st.write(message)


def _render_generation(result) -> None:
    """Show the generated text and its English translation."""
    if not (
        isinstance(result, list)
        and result
        and isinstance(result[0], dict)
        and "generated_text" in result[0]
    ):
        st.write(f"Unexpected response from the model: {result}")
        return
    generated = result[0]["generated_text"]
    st.write(_markdown_linebreaks(generated))
    st.text("English translation")
    # mtranslate.translate(text, to_language, from_language); the model writes Polish.
    st.write(_markdown_linebreaks(translate(generated, "en", "pl")))


# Page
st.set_page_config(page_title="papuGaPT2 (Polish GPT-2) Demo")
st.title("papuGaPT2 (Polish GPT-2)")

# Sidebar
st.sidebar.subheader("Configurable parameters")
max_len = st.sidebar.number_input(
    "Maximum length",
    value=100,
    min_value=1,
    help="The maximum number of new tokens to generate.",
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0,
    help="The value used to modulate the next token probabilities.",
)
top_k = st.sidebar.number_input(
    "Top k",
    value=10,
    min_value=1,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
top_p = st.sidebar.number_input(
    "Top p",
    value=0.95,
    min_value=0.0,
    max_value=1.0,
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
do_sample = st.sidebar.selectbox(
    "Sampling?",
    (True, False),
    help="Whether or not to use sampling; use greedy decoding otherwise.",
)

# Body
st.markdown(
    """
    papuGaPT2 (Polish GPT-2) model trained from scratch on OSCAR dataset.
    
    The models were trained with Jax and Flax using TPUs as part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organised by HuggingFace.
    """
)
model_name = MODEL_URL
ALL_PROMPTS = [*PROMPT_LIST, "Custom"]
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
if prompt == "Custom":
    prompt_box = "Enter your text here"
else:
    prompt_box = random.choice(PROMPT_LIST[prompt])
text = st.text_area("Enter text", prompt_box)
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        print(
            f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}, do_sample:{do_sample}"
        )
        result = process(
            text=text,
            model_name=model_name,
            max_len=int(max_len),
            temp=temp,
            top_k=int(top_k),
            top_p=float(top_p),
            do_sample=do_sample,
        )
        print("result:", result)
        if isinstance(result, dict) and "error" in result:
            _render_api_error(result)
        else:
            _render_generation(result)