Galuh Sahid commited on
Commit
f0d4713
1 Parent(s): 2786013
Files changed (2) hide show
  1. app.py +79 -52
  2. requirements.txt +0 -4
app.py CHANGED
@@ -4,33 +4,25 @@ from mtranslate import translate
4
  from prompts import PROMPT_LIST
5
  import streamlit as st
6
  import random
7
- import transformers
8
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
9
  import fasttext
10
  import SessionState
11
 
 
 
12
  LOGO = "huggingwayang.png"
13
 
14
  MODELS = {
15
- "GPT-2 Small": "flax-community/gpt2-small-indonesian",
16
- "GPT-2 Medium": "flax-community/gpt2-medium-indonesian",
17
- "GPT-2 Small finetuned on Indonesian academic journals": "Galuh/id-journal-gpt2"
 
 
 
 
 
 
18
  }
19
 
20
- headers = {}
21
-
22
- @st.cache(show_spinner=False)
23
- def load_gpt(model_type):
24
- model = GPT2LMHeadModel.from_pretrained(MODELS[model_type])
25
-
26
- return model
27
-
28
- @st.cache(show_spinner=False, hash_funcs={transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer: lambda _: None})
29
- def load_gpt_tokenizer(model_type):
30
- tokenizer = GPT2Tokenizer.from_pretrained(MODELS[model_type])
31
-
32
- return tokenizer
33
-
34
  def get_image(text: str):
35
  url = "https://wikisearch.uncool.ai/get_image/"
36
  try:
@@ -46,10 +38,44 @@ def get_image(text: str):
46
  image = ""
47
  return image
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  st.set_page_config(page_title="Indonesian GPT-2 Demo")
50
 
51
  st.title("Indonesian GPT-2")
52
 
 
 
 
 
 
 
53
  ft_model = fasttext.load_model('lid.176.ftz')
54
 
55
  # Sidebar
@@ -138,41 +164,42 @@ if st.button("Run"):
138
  text = translate(session_state.text, "id", lang)
139
 
140
  st.subheader("Result")
141
- model = load_gpt(model_name)
142
- tokenizer = load_gpt_tokenizer(model_name)
143
-
144
- input_ids = tokenizer.encode(text, return_tensors='pt')
145
- output = model.generate(input_ids=input_ids,
146
- max_length=max_len,
147
- temperature=temp,
148
- top_k=top_k,
149
- top_p=top_p,
150
- repetition_penalty=2.0)
151
-
152
- text = tokenizer.decode(output[0],
153
- skip_special_tokens=True)
154
- st.write(text.replace("\n", " \n"))
155
-
156
- st.text("Translation")
157
- translation = translate(text, "en", "id")
158
-
159
- if lang == "id":
160
- st.write(translation.replace("\n", " \n"))
161
-
162
- else:
163
- st.write(translate(text, lang, "id").replace("\n", " \n"))
164
-
165
- image_cat = "https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif"
166
- image = get_image(translation.replace("\"", "'"))
167
-
168
- if image is not "":
169
- st.image(image, width=400)
170
-
171
  else:
172
- # display cat image if no image found
173
- st.image(image_cat, width=400)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  # Reset state
176
  session_state.prompt = None
177
  session_state.prompt_box = None
178
- session_state.text = None
4
  from prompts import PROMPT_LIST
5
  import streamlit as st
6
  import random
 
 
7
  import fasttext
8
  import SessionState
9
 
10
+ headers = {}
11
+
12
  LOGO = "huggingwayang.png"
13
 
14
  MODELS = {
15
+ "GPT-2 Small": {
16
+ "url": "https://api-inference.huggingface.co/models/flax-community/gpt2-small-indonesian"
17
+ },
18
+ "GPT-2 Medium": {
19
+ "url": "https://api-inference.huggingface.co/models/flax-community/gpt2-medium-indonesian"
20
+ },
21
+ "GPT-2 Small finetuned on Indonesian academic journals": {
22
+ "url": "https://api-inference.huggingface.co/models/Galuh/id-journal-gpt2"
23
+ },
24
  }
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def get_image(text: str):
27
  url = "https://wikisearch.uncool.ai/get_image/"
28
  try:
38
  image = ""
39
  return image
40
 
41
+ def query(payload, model_name):
42
+ data = json.dumps(payload)
43
+ # print("model url:", MODELS[model_name]["url"])
44
+ response = requests.request("POST", MODELS[model_name]["url"], headers=headers, data=data)
45
+ return json.loads(response.content.decode("utf-8"))
46
+
47
+ def process(text: str,
48
+ model_name: str,
49
+ max_len: int,
50
+ temp: float,
51
+ top_k: int,
52
+ top_p: float):
53
+
54
+ payload = {
55
+ "inputs": text,
56
+ "parameters": {
57
+ "max_new_tokens": max_len,
58
+ "top_k": top_k,
59
+ "top_p": top_p,
60
+ "temperature": temp,
61
+ "repetition_penalty": 2.0,
62
+ },
63
+ "options": {
64
+ "use_cache": True,
65
+ }
66
+ }
67
+ return query(payload, model_name)
68
+
69
  st.set_page_config(page_title="Indonesian GPT-2 Demo")
70
 
71
  st.title("Indonesian GPT-2")
72
 
73
+ try:
74
+ token = st.secrets["flax_community_token"]
75
+ headers = {"Authorization": f"Bearer {token}"}
76
+ except FileNotFoundError:
77
+ print(f"Token is not found")
78
+
79
  ft_model = fasttext.load_model('lid.176.ftz')
80
 
81
  # Sidebar
164
  text = translate(session_state.text, "id", lang)
165
 
166
  st.subheader("Result")
167
+ result = process(text=text,
168
+ model_name=model_name,
169
+ max_len=int(max_len),
170
+ temp=temp,
171
+ top_k=int(top_k),
172
+ top_p=float(top_p))
173
+
174
+ # print("result:", result)
175
+ if "error" in result:
176
+ if type(result["error"]) is str:
177
+ st.write(f'{result["error"]}.', end=" ")
178
+ if "estimated_time" in result:
179
+ st.write(f'Please try it again in about {result["estimated_time"]:.0f} seconds')
180
+ else:
181
+ if type(result["error"]) is list:
182
+ for error in result["error"]:
183
+ st.write(f'{error}')
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  else:
185
+ result = result[0]["generated_text"]
186
+ st.write(result.replace("\n", " \n"))
187
+ st.text("Translation")
188
+ translation = translate(result, "en", "id")
189
+ if lang == "id":
190
+ st.write(translation.replace("\n", " \n"))
191
+ else:
192
+ st.write(translate(result, lang, "id").replace("\n", " \n"))
193
+
194
+ image_cat = "https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif"
195
+ image = get_image(translation.replace("\"", "'"))
196
+ if image is not "":
197
+ st.image(image, width=400)
198
+ else:
199
+ # display cat image if no image found
200
+ st.image(image_cat, width=400)
201
 
202
  # Reset state
203
  session_state.prompt = None
204
  session_state.prompt_box = None
205
+ session_state.text = None
requirements.txt CHANGED
@@ -1,9 +1,5 @@
1
- transformers
2
  streamlit
3
  requests==2.24.0
4
  requests-toolbelt==0.9.1
5
  mtranslate
6
- -f https://download.pytorch.org/whl/torch_stable.html
7
- torch==1.7.1+cpu; sys_platform == 'linux'
8
- torch==1.7.1; sys_platform == 'darwin'
9
  fasttext
 
1
  streamlit
2
  requests==2.24.0
3
  requests-toolbelt==0.9.1
4
  mtranslate
 
 
 
5
  fasttext