miwojc committed on
Commit
d2b205a
1 Parent(s): a90542c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -9
app.py CHANGED
@@ -1,9 +1,116 @@
1
- import gradio as gr
2
- title = "papuGaPT2 (Polish GPT2) language model demo"
3
- description = "This demo showcases the text generation capabilities for papuGaPT2, a GPT2 model pre-trained from scratch using Polish subset of the multilingual Oscar corpus. To use it, add your text to 'Input Text' box, or click one of the below examples to load them and clik 'Submit' button. The model will generate text based on the entered text (prompt). For more information including dataset, training and evaluation procedure, intended use, limitations and bias analysis see the model card at https://huggingface.co/flax-community/papuGaPT2)"
4
- examples = [
5
- ["Najsmaczniejszy owoc to"],
6
- ["Największym polskim poetą był"],
7
- ["Cześć mam na imię"]
8
- ]
9
- gr.Interface.load("huggingface/flax-community/papuGaPT2", inputs=gr.inputs.Textbox(lines=5, label="Input Text"),title=title, description=description, examples=examples).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import requests
4
+ from mtranslate import translate
5
+ import streamlit as st
6
# Hosted inference endpoint for the papuGaPT2 model.
MODEL_URL = "https://api-inference.huggingface.co/models/flax-community/papuGaPT2"

# Sentence openers offered as ready-made prompts.
_PROMPT_STEMS = (
    "Najsmaczniejszy owoc to",
    "Cześć, mam na imię",
    "Największym polskim poetą był",
)

# Maps the label shown in the prompt selector ("<stem>...") to the list of
# candidate texts for that label ("<stem> ", trailing space included).
PROMPT_LIST = {f"{stem}...": [f"{stem} "] for stem in _PROMPT_STEMS}
12
def query(payload, model_url, timeout=30.0):
    """POST ``payload`` as JSON to ``model_url`` and return the decoded reply.

    Args:
        payload: JSON-serialisable request body.
        model_url: Full URL of the endpoint to call.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely.

    Returns:
        The decoded JSON response. When the request itself fails, or the
        response body is not valid UTF-8 JSON, a dict of the form
        ``{"error": "<message>"}`` is returned instead of raising, so callers
        that already check for an ``"error"`` key handle these cases too.
    """
    data = json.dumps(payload)
    print("model url:", model_url)
    try:
        response = requests.post(model_url, data=data, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        return {"error": f"Request to the model API failed: {exc}"}
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # Covers both invalid JSON and undecodable bytes (JSONDecodeError and
        # UnicodeDecodeError are ValueError subclasses), e.g. an HTML error
        # page returned by a gateway.
        return {
            "error": (
                "The model API returned a non-JSON response "
                f"(HTTP {response.status_code})"
            )
        }
19
def process(
    text: str,
    model_name: str,
    max_len: int,
    temp: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float = 2.0,
):
    """Build a text-generation request for ``text`` and send it via ``query``.

    Args:
        text: Prompt sent as the request's ``inputs``.
        model_name: Passed unchanged to ``query`` as the URL to call.
        max_len: Sent as the ``max_new_tokens`` parameter.
        temp: Sent as the ``temperature`` parameter.
        top_k: Sent as the ``top_k`` parameter.
        top_p: Sent as the ``top_p`` parameter.
        repetition_penalty: Sent as the ``repetition_penalty`` parameter;
            defaults to the previously hard-coded value of 2.0.

    Returns:
        Whatever ``query`` returns for the assembled payload.
    """
    payload = {
        "inputs": text,
        "parameters": {
            "max_new_tokens": max_len,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temp,
            "repetition_penalty": repetition_penalty,
        },
        "options": {
            "use_cache": True,
        },
    }
    return query(payload, model_name)
36
# Page
st.set_page_config(page_title="papuGaPT2 (Polish GPT-2) Demo")
st.title("papuGaPT2 (Polish GPT-2)")

# Sidebar: generation parameters chosen by the user.
st.sidebar.subheader("Configurable parameters")
max_len = st.sidebar.number_input(
    "Maximum length",
    value=100,
    help="The maximum length of the sequence to be generated.",
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0,
    help="The value used to modulate the next token probabilities.",
)
top_k = st.sidebar.number_input(
    "Top k",
    value=10,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
# top_p is a cumulative probability, so restrict the widget to [0, 1].
top_p = st.sidebar.number_input(
    "Top p",
    value=0.95,
    min_value=0.0,
    max_value=1.0,
    help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
# NOTE(review): this value is collected but the Run handler does not pass it
# to process() -- confirm whether it should be sent with the request.
do_sample = st.sidebar.selectbox(
    "Sampling?",
    (True, False),
    help="Whether or not to use sampling; use greedy decoding otherwise.",
)
68
# Body: short description of the model shown above the prompt controls.
st.markdown(
    """
papuGaPT2 (Polish GPT-2) model trained from scratch on OSCAR dataset.

The models were trained with Jax and Flax using TPUs as part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organised by HuggingFace.
"""
)

model_name = MODEL_URL

# Prompt picker: the predefined prompts plus a free-text "Custom" entry,
# which is selected by default (it is the last item).
ALL_PROMPTS = list(PROMPT_LIST.keys()) + ["Custom"]
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
if prompt == "Custom":
    prompt_box = "Enter your text here"
else:
    prompt_box = random.choice(PROMPT_LIST[prompt])
text = st.text_area("Enter text", prompt_box)

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
        result = process(
            text=text,
            model_name=model_name,
            max_len=int(max_len),
            temp=temp,
            top_k=int(top_k),
            top_p=float(top_p),
        )
        print("result:", result)
        if "error" in result:
            error = result["error"]
            if isinstance(error, str):
                st.write(f"{error}.")
                if "estimated_time" in result:
                    st.write(
                        f'Please try again in about {result["estimated_time"]:.0f} seconds.'
                    )
            elif isinstance(error, list):
                for item in error:
                    st.write(f"{item}")
        elif (
            isinstance(result, list)
            and result
            and isinstance(result[0], dict)
            and "generated_text" in result[0]
        ):
            generated = result[0]["generated_text"]
            # Markdown only renders a line break when the line ends with two
            # spaces, so add them before every newline.
            st.write(generated.replace("\n", "  \n"))
            st.text("English translation")
            # The model generates Polish, so translate from "pl" to "en".
            st.write(translate(generated, "en", "pl").replace("\n", "  \n"))
        else:
            # Neither an error report nor a generation result.
            st.write("Unexpected response from the model API.")