Galuh commited on
Commit
c389ccc
1 Parent(s): 638e90b

Fix prompt reloading bug; add new finetuned model; replace api

Browse files
Files changed (1) hide show
  1. app.py +93 -93
app.py CHANGED
@@ -4,19 +4,32 @@ from mtranslate import translate
4
  from prompts import PROMPT_LIST
5
  import streamlit as st
6
  import random
 
 
7
  import fasttext
 
8
 
9
- headers = {}
10
  LOGO = "huggingwayang.png"
 
11
  MODELS = {
12
- "GPT-2 Small": {
13
- "url": "https://api-inference.huggingface.co/models/flax-community/gpt2-small-indonesian"
14
- },
15
- "GPT-2 Medium": {
16
- "url": "https://api-inference.huggingface.co/models/flax-community/gpt2-medium-indonesian"
17
- },
18
  }
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def get_image(text: str):
22
  url = "https://wikisearch.uncool.ai/get_image/"
@@ -33,45 +46,12 @@ def get_image(text: str):
33
  image = ""
34
  return image
35
 
36
- def query(payload, model_name):
37
- data = json.dumps(payload)
38
- # print("model url:", MODELS[model_name]["url"])
39
- response = requests.request("POST", MODELS[model_name]["url"], headers=headers, data=data)
40
- return json.loads(response.content.decode("utf-8"))
41
-
42
-
43
- def process(text: str,
44
- model_name: str,
45
- max_len: int,
46
- temp: float,
47
- top_k: int,
48
- top_p: float):
49
-
50
- payload = {
51
- "inputs": text,
52
- "parameters": {
53
- "max_new_tokens": max_len,
54
- "top_k": top_k,
55
- "top_p": top_p,
56
- "temperature": temp,
57
- "repetition_penalty": 2.0,
58
- },
59
- "options": {
60
- "use_cache": True,
61
- }
62
- }
63
- return query(payload, model_name)
64
-
65
  st.set_page_config(page_title="Indonesian GPT-2 Demo")
66
- st.title("Indonesian GPT-2")
67
 
68
- try:
69
- token = st.secrets["flax_community_token"]
70
- headers = {"Authorization": f"Bearer {token}"}
71
- except FileNotFoundError:
72
- print(f"Token is not found")
73
 
74
  ft_model = fasttext.load_model('lid.176.ftz')
 
75
  # Sidebar
76
  st.sidebar.image(LOGO)
77
  st.sidebar.subheader("Configurable parameters")
@@ -85,25 +65,23 @@ max_len = st.sidebar.number_input(
85
  temp = st.sidebar.slider(
86
  "Temperature",
87
  value=1.0,
88
- min_value=0.1,
89
  max_value=100.0,
90
  help="The value used to module the next token probabilities."
91
  )
92
 
93
  top_k = st.sidebar.number_input(
94
  "Top k",
95
- value=10,
96
  help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
97
  )
98
 
99
  top_p = st.sidebar.number_input(
100
  "Top p",
101
- value=0.95,
102
  help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
103
  )
104
 
105
- # do_sample = st.sidebar.selectbox('Sampling?', (True, False), help="Whether or not to use sampling; use greedy decoding otherwise.")
106
-
107
  st.markdown(
108
  """
109
  This demo uses the [small](https://huggingface.co/flax-community/gpt2-small-indonesian) and
@@ -111,68 +89,90 @@ st.markdown(
111
  trained on the Indonesian [Oscar](https://huggingface.co/datasets/oscar), [MC4](https://huggingface.co/datasets/mc4)
112
  and [Wikipedia](https://huggingface.co/datasets/wikipedia) dataset. We created it as part of the
113
  [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
114
-
115
  The demo supports "multi language" ;-), feel free to try a prompt on your language. We are also experimenting with
116
  the sentence based image search using Wikipedia passages encoded with distillbert, and search the encoded sentence
117
  in the encoded passages using Facebook's Faiss.
118
  """
119
  )
120
 
121
- model_name = st.selectbox('Model',(['GPT-2 Small', 'GPT-2 Medium']))
122
 
123
- ALL_PROMPTS = list(PROMPT_LIST.keys())+["Custom"]
124
- prompt = st.selectbox('Please choose a predefined prompt or create your custom text.', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
 
 
125
 
126
- if prompt == "Custom":
127
- prompt_box = "Feel free to write text in any language"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  else:
129
- prompt_box = random.choice(PROMPT_LIST[prompt])
 
130
 
131
- text = st.text_area("Enter text", prompt_box)
132
 
133
  if st.button("Run"):
134
  with st.spinner(text="Getting results..."):
135
- lang_predictions, lang_probability = ft_model.predict(text.replace("\n", " "), k=3)
136
- # print(f"lang: {lang_predictions}, {lang_probability}")
137
  if "__label__id" in lang_predictions:
138
  lang = "id"
 
139
  else:
140
  lang = lang_predictions[0].replace("__label__", "")
141
- text = translate(text, "id", lang)
142
- # print(f"{lang}: {text}")
143
  st.subheader("Result")
144
- # print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
145
- result = process(text=text,
146
- model_name=model_name,
147
- max_len=int(max_len),
148
- temp=temp,
149
- top_k=int(top_k),
150
- top_p=float(top_p))
151
-
152
- # print("result:", result)
153
- if "error" in result:
154
- if type(result["error"]) is str:
155
- st.write(f'{result["error"]}.', end=" ")
156
- if "estimated_time" in result:
157
- st.write(f'Please try it again in about {result["estimated_time"]:.0f} seconds')
158
- else:
159
- if type(result["error"]) is list:
160
- for error in result["error"]:
161
- st.write(f'{error}')
 
 
 
162
  else:
163
- result = result[0]["generated_text"]
164
- st.write(result.replace("\n", " \n"))
165
- st.text("Translation")
166
- translation = translate(result, "en", "id")
167
- if lang == "id":
168
- st.write(translation.replace("\n", " \n"))
169
- else:
170
- st.write(translate(result, lang, "id").replace("\n", " \n"))
171
-
172
- image_cat = "https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif"
173
- image = get_image(translation.replace("\"", "'"))
174
- if image is not "":
175
- st.image(image, width=400)
176
- else:
177
- # display cat image if no image found
178
- st.image(image_cat, width=400)
 
4
  from prompts import PROMPT_LIST
5
  import streamlit as st
6
  import random
7
+ import transformers
8
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
9
  import fasttext
10
+ import SessionState
11
 
 
12
  LOGO = "huggingwayang.png"
13
+
14
  MODELS = {
15
+ "GPT-2 Small": "flax-community/gpt2-small-indonesian",
16
+ "GPT-2 Medium": "flax-community/gpt2-medium-indonesian",
17
+ "GPT-2 Small finetuned on Indonesian academic journals": "Galuh/id-journal-gpt2"
 
 
 
18
  }
19
 
20
+ headers = {}
21
+
22
+ @st.cache(show_spinner=False)
23
+ def load_gpt(model_type):
24
+ model = GPT2LMHeadModel.from_pretrained(MODELS[model_type])
25
+
26
+ return model
27
+
28
+ @st.cache(show_spinner=False, hash_funcs={transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer: lambda _: None})
29
+ def load_gpt_tokenizer(model_type):
30
+ tokenizer = GPT2Tokenizer.from_pretrained(MODELS[model_type])
31
+
32
+ return tokenizer
33
 
34
  def get_image(text: str):
35
  url = "https://wikisearch.uncool.ai/get_image/"
 
46
  image = ""
47
  return image
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  st.set_page_config(page_title="Indonesian GPT-2 Demo")
 
50
 
51
+ st.title("Indonesian GPT-2")
 
 
 
 
52
 
53
  ft_model = fasttext.load_model('lid.176.ftz')
54
+
55
  # Sidebar
56
  st.sidebar.image(LOGO)
57
  st.sidebar.subheader("Configurable parameters")
 
65
  temp = st.sidebar.slider(
66
  "Temperature",
67
  value=1.0,
68
+ min_value=0.0,
69
  max_value=100.0,
70
  help="The value used to module the next token probabilities."
71
  )
72
 
73
  top_k = st.sidebar.number_input(
74
  "Top k",
75
+ value=50,
76
  help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
77
  )
78
 
79
  top_p = st.sidebar.number_input(
80
  "Top p",
81
+ value=1.0,
82
  help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
83
  )
84
 
 
 
85
  st.markdown(
86
  """
87
  This demo uses the [small](https://huggingface.co/flax-community/gpt2-small-indonesian) and
 
89
  trained on the Indonesian [Oscar](https://huggingface.co/datasets/oscar), [MC4](https://huggingface.co/datasets/mc4)
90
  and [Wikipedia](https://huggingface.co/datasets/wikipedia) dataset. We created it as part of the
91
  [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
92
+
93
  The demo supports "multi language" ;-), feel free to try a prompt on your language. We are also experimenting with
94
  the sentence based image search using Wikipedia passages encoded with distillbert, and search the encoded sentence
95
  in the encoded passages using Facebook's Faiss.
96
  """
97
  )
98
 
99
+ model_name = st.selectbox('Model',(['GPT-2 Small', 'GPT-2 Medium', 'GPT-2 Small finetuned on Indonesian academic journals']))
100
 
101
+ if model_name in ["GPT-2 Small", "GPT-2 Medium"]:
102
+ prompt_group_name = "GPT-2"
103
+ elif model_name in ["GPT-2 Small finetuned on Indonesian academic journals"]:
104
+ prompt_group_name = "Indonesian Journals"
105
 
106
+ session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
107
+
108
+ ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
109
+ prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
110
+
111
+ # Update prompt
112
+ if session_state.prompt is None:
113
+ session_state.prompt = prompt
114
+ elif session_state.prompt is not None and (prompt != session_state.prompt):
115
+ session_state.prompt = prompt
116
+ session_state.prompt_box = None
117
+ session_state.text = None
118
+ else:
119
+ session_state.prompt = prompt
120
+
121
+ # Update prompt box
122
+ if session_state.prompt == "Custom":
123
+ session_state.prompt_box = "Enter your text here"
124
  else:
125
+ if session_state.prompt is not None and session_state.prompt_box is None:
126
+ session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
127
 
128
+ session_state.text = st.text_area("Enter text", session_state.prompt_box)
129
 
130
  if st.button("Run"):
131
  with st.spinner(text="Getting results..."):
132
+ lang_predictions, lang_probability = ft_model.predict(session_state.text.replace("\n", " "), k=3)
 
133
  if "__label__id" in lang_predictions:
134
  lang = "id"
135
+ text = session_state.text
136
  else:
137
  lang = lang_predictions[0].replace("__label__", "")
138
+ text = translate(session_state.text, "id", lang)
139
+
140
  st.subheader("Result")
141
+ model = load_gpt(model_name)
142
+ tokenizer = load_gpt_tokenizer(model_name)
143
+
144
+ input_ids = tokenizer.encode(text, return_tensors='pt')
145
+ output = model.generate(input_ids=input_ids,
146
+ max_length=max_len,
147
+ temperature=temp,
148
+ top_k=top_k,
149
+ top_p=top_p,
150
+ repetition_penalty=2.0)
151
+
152
+ text = tokenizer.decode(output[0],
153
+ skip_special_tokens=True)
154
+ st.write(text.replace("\n", " \n"))
155
+
156
+ st.text("Translation")
157
+ translation = translate(text, "en", "id")
158
+
159
+ if lang == "id":
160
+ st.write(translation.replace("\n", " \n"))
161
+
162
  else:
163
+ st.write(translate(text, lang, "id").replace("\n", " \n"))
164
+
165
+ image_cat = "https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif"
166
+ image = get_image(translation.replace("\"", "'"))
167
+
168
+ if image is not "":
169
+ st.image(image, width=400)
170
+
171
+ else:
172
+ # display cat image if no image found
173
+ st.image(image_cat, width=400)
174
+
175
+ # Reset state
176
+ session_state.prompt = None
177
+ session_state.prompt_box = None
178
+ session_state.text = None