miwojc committed on
Commit
ecfed43
1 Parent(s): 42f3ce1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -11
app.py CHANGED
@@ -1,11 +1,112 @@
1
- import gradio as gr
2
-
3
- title = "papuGaPT2 Demo"
4
- description = "Demo for polish GPT2 for text generation. To use it, simply add your text, or click one of the examples to load them."
5
- examples = [
6
- ['Dialogi niedobre… Bardzo niedobre dialogi są. W ogóle brak akcji jest. Nic się nie dzieje.'],
7
- ["Na przykład w jednym sklepie na półkach po rocznym leżeniu cukier ma 80 proc. cukru w cukrze,"],
8
- ["Moja jest tylko racja i to święta racja. Bo nawet jak jest twoja,"]
9
- ]
10
-
11
- gr.Interface.load("huggingface/flax-community/papuGaPT2", inputs=gr.inputs.Textbox(lines=5, label="Input Text"),title=title,description=description, examples=examples).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import requests
4
+ from mtranslate import translate
5
+ import streamlit as st
6
# Hugging Face Inference API endpoint serving the papuGaPT2 model.
MODEL_URL = "https://api-inference.huggingface.co/models/flax-community/papuGaPT2"
# Preset Polish prompts: display label -> list of prompt variants
# (one is chosen at random when the label is selected).
PROMPT_LIST = {
    "Najsmaczniejszy owoc to...": ["Najsmaczniejszy owoc to "],
    "Cześć, mam na imię...": ["Cześć, mam na imię "],
    "Największym polskim poetą był...": ["Największym polskim poetą był "],
}
12
def query(payload, model_url):
    """POST `payload` as JSON to the inference endpoint at `model_url`.

    Returns the decoded JSON response (a dict on error, a list of
    generation results on success).
    """
    data = json.dumps(payload)
    print("model url:", model_url)
    # Fix: the original request had no timeout, so a stalled endpoint
    # would hang the app indefinitely.
    response = requests.request(
        "POST", model_url, headers={}, data=data, timeout=60
    )
    return json.loads(response.content.decode("utf-8"))
19
def process(
    text: str, model_name: str, max_len: int, temp: float, top_k: int, top_p: float
):
    """Build a text-generation request for `text` and submit it to the model.

    Returns whatever `query` returns (decoded JSON from the API).
    """
    generation_params = {
        "max_new_tokens": max_len,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temp,
        "repetition_penalty": 2.0,
    }
    request_body = {
        "inputs": text,
        "parameters": generation_params,
        "options": {"use_cache": True},
    }
    return query(request_body, model_name)
36
# Page
st.set_page_config(page_title="papuGaPT2 (Polish GPT-2) Demo")
# Fix: the original title string was missing its closing parenthesis.
st.title("papuGaPT2 (Polish GPT-2)")
39
# Sidebar
st.sidebar.subheader("Configurable parameters")
# Upper bound on the number of new tokens to generate.
max_len = st.sidebar.number_input(
    "Maximum length",
    value=100,
    help="The maximum length of the sequence to be generated.",
)
# Softmax temperature; forwarded to the API as "temperature".
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0,
    help="The value used to module the next token probabilities.",
)
top_k = st.sidebar.number_input(
    "Top k",
    value=10,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
top_p = st.sidebar.number_input(
    "Top p",
    value=0.95,
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
# NOTE(review): do_sample is collected here but never forwarded into the
# API payload built by process() — confirm whether it should be included.
do_sample = st.sidebar.selectbox(
    "Sampling?",
    (True, False),
    help="Whether or not to use sampling; use greedy decoding otherwise.",
)
68
# Body
st.markdown(
    """
papuGaPT2 (Polish GPT-2) model trained from scratch on OSCAR dataset.

The models were trained with Jax and Flax using TPUs as part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organised by HuggingFace.
    """
)
# Fix: the original read `MODEL_name`, an undefined name (NameError at
# runtime); the endpoint constant defined at the top is MODEL_URL.
model_name = MODEL_URL
77
# Prompt picker: the preset Polish prompts plus a trailing "Custom"
# option, which is the default selection (index of the last entry).
ALL_PROMPTS = [*PROMPT_LIST, "Custom"]
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
# "Custom" shows placeholder text; presets pick one variant at random.
prompt_box = (
    "Enter your text here"
    if prompt == "Custom"
    else random.choice(PROMPT_LIST[prompt])
)
text = st.text_area("Enter text", prompt_box)
84
# Run generation on demand and render (or explain) the API response.
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        # Fix: the original debug print contained a garbled literal
        # ("maxlen:9,698") instead of interpolating the configured value.
        print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
        result = process(
            text=text,
            model_name=model_name,
            max_len=int(max_len),
            temp=temp,
            top_k=int(top_k),
            top_p=float(top_p),
        )
        print("result:", result)
        if "error" in result:
            # The API reports errors either as a single string (often with
            # an estimated_time while the model loads) or as a list.
            if isinstance(result["error"], str):
                st.write(f'{result["error"]}.', end=" ")
                if "estimated_time" in result:
                    st.write(
                        f'Please try again in about {result["estimated_time"]:.0f} seconds.'
                    )
            elif isinstance(result["error"], list):
                for error in result["error"]:
                    st.write(f"{error}")
        else:
            result = result[0]["generated_text"]
            st.write(result.replace("\n", "  \n"))
            st.text("English translation")
            # Fix: the source language was "es" (Spanish); the model's
            # output is Polish, so translate from "pl".
            st.write(translate(result, "en", "pl").replace("\n", "  \n"))