BigSalmon committed
Commit 414d64e • 1 Parent(s): bf70e0e

Update app.py

Files changed (1)
  1. app.py +94 -98
app.py CHANGED
@@ -1,100 +1,96 @@
- import json
- import requests
  import streamlit as st
  import random
- headers = {}
- MODELS = {
-     "GPT-2 Base": {
-         "url": "https://api-inference.huggingface.co/models/BigSalmon/InformalToFormalLincoln14"
-     }
- }
- def query(payload, model_name):
-     data = json.dumps(payload)
-     print("model url:", MODELS[model_name]["url"])
-     response = requests.request(
-         "POST", MODELS[model_name]["url"], headers=headers, data=data)
-     return json.loads(response.content.decode("utf-8"))
- def process(text: str,
-             model_name: str,
-             max_len: int,
-             temp: float,
-             top_k: int,
-             num_return_sequences:int,
-             top_p: float):
-     payload = {
-         "inputs": text,
-         "parameters": {
-             "max_new_tokens": max_len,
-             "top_k": top_k,
-             "top_p": top_p,
-             "temperature": temp,
-             "num_return_sequences": num_return_sequences,
-             "repetition_penalty": 2.0,
-         },
-         "options": {
-             "use_cache": True,
-         }
-     }
-     return query(payload, model_name)
- st.set_page_config(page_title="Thai GPT2 Demo")
- st.title("🐘 Thai GPT2")
- st.sidebar.subheader("Configurable parameters")
- max_len = st.sidebar.text_input(
-     "Maximum length",
-     value=5,
-     help="The maximum length of the sequence to be generated."
- )
- temp = st.sidebar.slider(
-     "Temperature",
-     value=1.0,
-     min_value=0.7,
-     max_value=1.5,
-     help="The value used to module the next token probabilities."
- )
- top_k = st.sidebar.text_input(
-     "Top k",
-     value=50,
-     help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
- )
- num_return_sequences = st.sidebar.text_input(
-     "Returns",
-     value=5,
-     help="Number to return."
- )
- top_p = st.sidebar.text_input(
-     "Top p",
-     value=0.95,
-     help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
- )
- do_sample = st.sidebar.selectbox(
-     'Sampling?', (True, False), help="Whether or not to use sampling; use greedy decoding otherwise.")
- st.markdown(
-     """Thai GPT-2 demo. Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/)."""
- )
- model_name = st.selectbox('Model', (['GPT-2 Base']))
- hello = ['Custom', 'Yellow']
- prompt = st.selectbox('Prompt', options = hello)
- if prompt == "Custom":
-     prompt_box = "Enter your text here"
- text = st.text_area("Enter text", prompt_box)
- if st.button("Run"):
-     with st.spinner(text="Getting results..."):
-         st.subheader("Result")
-         print(f"maxlen:17,026, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
-         result = process(text=text,
-                          model_name=model_name,
-                          max_len=int(max_len),
-                          temp=temp,
-                          num_return_sequences = int(num_return_sequences),
-                          top_k=int(top_k),
-                          top_p=float(top_p))
-         st.write(result)
-         if "error" in result:
-             if type(result["error"]) is str:
-                 st.write(f'{result["error"]}. Please try it again in about {result["estimated_time"]:.0f} seconds')
-             else:
-                 if type(result["error"]) is list:
-                     for error in result["error"]:
-                         st.write(f'{error}')
-         else:
-             print("hey")
+ import argparse
+ import re
+ import os
  import streamlit as st
  import random
+ import numpy as np
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import tokenizers
+ #os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ random.seed(None)
+ suggested_text_list = ['פעם אחת, לפני שנים רבות','שלום, קוראים לי דורון ואני','בוקר טוב לכולם','ואז הפרתי את כל כללי הטקס כש']
+ @st.cache(hash_funcs={tokenizers.Tokenizer: id, tokenizers.AddedToken: id})
+ def load_model(model_name):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     return model, tokenizer
+ def extend(input_text, max_size=20, top_k=50, top_p=0.95):
+     if len(input_text) == 0:
+         input_text = ""
+     encoded_prompt = tokenizer.encode(
+         input_text, add_special_tokens=False, return_tensors="pt")
+     encoded_prompt = encoded_prompt.to(device)
+     if encoded_prompt.size()[-1] == 0:
+         input_ids = None
+     else:
+         input_ids = encoded_prompt
+
+     output_sequences = model.generate(
+         input_ids=input_ids,
+         max_length=max_size + len(encoded_prompt[0]),
+         top_k=top_k,
+         top_p=top_p,
+         do_sample=True,
+         num_return_sequences=1)
+     # Remove the batch dimension when returning multiple sequences
+     if len(output_sequences.shape) > 2:
+         output_sequences.squeeze_()
+     generated_sequences = []
+     for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
+         generated_sequence = generated_sequence.tolist()
+         # Decode text
+         text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+         # Remove all text after the stop token
+         text = text[: text.find(stop_token) if stop_token else None]
+         # Remove all text after 3 newlines
+         text = text[: text.find(new_lines) if new_lines else None]
+         # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
+         total_sequence = (
+             input_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
+         )
+         generated_sequences.append(total_sequence)
+
+     parsed_text = total_sequence.replace("<|startoftext|>", "").replace("\r","").replace("\n\n", "\n")
+     if len(parsed_text) == 0:
+         parsed_text = "שגיאה"
+     return parsed_text
+ if __name__ == "__main__":
+     st.title("Hebrew GPT Neo (Small)")
+     pre_model_path = "Norod78/hebrew-gpt_neo-small"
+     model, tokenizer = load_model(pre_model_path)
+     stop_token = "<|endoftext|>"
+     new_lines = "\n\n\n"
+     np.random.seed(None)
+     random_seed = np.random.randint(10000, size=1)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = 0 if torch.cuda.is_available()==False else torch.cuda.device_count()
+     torch.manual_seed(random_seed)
+     if n_gpu > 0:
+         torch.cuda.manual_seed_all(random_seed)
+     model.to(device)
+     text_area = st.text_area("Enter the first few words (or leave blank), tap on \"Generate Text\" below. Tapping again will produce a different result.", 'האיש האחרון בעולם ישב לבד בחדרו כשלפתע נשמעה נקישה')
+     st.sidebar.subheader("Configurable parameters")
+     max_len = st.sidebar.slider("Max-Length", 0, 256, 192, help="The maximum length of the sequence to be generated.")
+     top_k = st.sidebar.slider("Top-K", 0, 100, 40, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
+     top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.92, help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.")
+     if st.button("Generate Text"):
+         with st.spinner(text="Generating results..."):
+             st.subheader("Result")
+             print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, maxlen:17,026, top_k:{top_k}, top_p:{top_p}")
+             if len(text_area.strip()) == 0:
+                 text_area = random.choice(suggested_text_list)
+             result = extend(input_text=text_area,
+                             max_size=int(max_len),
+                             top_k=int(top_k),
+                             top_p=float(top_p))
+             print("Done length: " + str(len(result)) + " bytes")
+             #<div class="rtl" dir="rtl" style="text-align:right;">
+             st.markdown(f"<p dir=\"rtl\" style=\"text-align:right;\"> {result} </p>", unsafe_allow_html=True)
+             st.write("\n\nResult length: " + str(len(result)) + " bytes")
+             print(f"\"{result}\"")
+
+     st.markdown(
+         """Hebrew text generation model (125M parameters) based on EleutherAI's gpt-neo architecture. Originally trained on a TPUv3-8 which was made available to me via the [TPU Research Cloud Program](https://sites.research.google/trc/)."""
+     )
+     st.markdown("<footer><hr><p style=\"font-size:14px\">Enjoy</p><p style=\"font-size:12px\">Created by <a href=\"https://linktr.ee/Norod78\">Doron Adler</a></p></footer> ", unsafe_allow_html=True)
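For context on what this commit replaces: the removed app.py did not run a model locally, it POSTed each request to the Hugging Face hosted Inference API. Below is a minimal standalone sketch of that call pattern, reusing the endpoint and payload shape from the removed code; the function name and example prompt are illustrative, and a real call may also need an Authorization header, which the removed code left as an empty headers dict.

import json
import requests

# Endpoint used by the removed app.py
API_URL = "https://api-inference.huggingface.co/models/BigSalmon/InformalToFormalLincoln14"

def query_inference_api(prompt, max_new_tokens=5, temperature=1.0, top_k=50, top_p=0.95):
    # Same payload shape the removed process() built
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "num_return_sequences": 1,
            "repetition_penalty": 2.0,
        },
        "options": {"use_cache": True},
    }
    # The removed code sent empty headers; a {"Authorization": "Bearer <token>"} entry is usually required
    response = requests.post(API_URL, headers={}, data=json.dumps(payload))
    return json.loads(response.content.decode("utf-8"))

print(query_inference_api("informal english: i am very happy today."))  # illustrative prompt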
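The added app.py instead loads the checkpoint with transformers and samples from it in-process. A minimal sketch of that generation flow outside Streamlit, assuming the same model name, sampling defaults, and stop-token trimming as the added code; the prompt is one of its suggested_text_list entries.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Norod78/hebrew-gpt_neo-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

prompt = "שלום, קוראים לי דורון ואני"  # one of the suggested prompts in the new app.py
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

# Top-k / top-p sampling, mirroring the sidebar defaults (Max-Length 192, Top-K 40, Top-P 0.92)
output = model.generate(
    input_ids=input_ids,
    max_length=input_ids.shape[-1] + 192,
    do_sample=True,
    top_k=40,
    top_p=0.92,
    num_return_sequences=1,
)

text = tokenizer.decode(output[0], clean_up_tokenization_spaces=True)
# Trim at the end-of-text marker, as extend() does
stop_token = "<|endoftext|>"
if stop_token in text:
    text = text[: text.find(stop_token)]
print(text)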