cahya commited on
Commit
7d0ffd9
1 Parent(s): b2d148b

add the max length of text

Browse files
Files changed (1) hide show
  1. app/app.py +11 -4
app/app.py CHANGED
@@ -14,16 +14,16 @@ model_name = "cahya/gpt2-small-indonesian-story"
14
 
15
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
16
  def get_generator():
17
- st.write("Loading the GPT2 model...")
18
  text_generator = pipeline('text-generation', model=model_name)
19
  return text_generator
20
 
21
 
22
  #@st.cache(suppress_st_warning=True)
23
  def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
24
- temperature: float = 1.0, max_time: float = None):
25
  st.write("Cache miss: process")
26
- set_seed(42)
27
  result = text_generator(text, max_length=max_length, do_sample=do_sample,
28
  top_k=top_k, top_p=top_p, temperature=temperature, max_time=max_time)
29
  return result
@@ -58,6 +58,13 @@ else:
58
 
59
  session_state.text = st.text_area("Enter text", session_state.prompt_box)
60
 
 
 
 
 
 
 
 
61
  temp = st.sidebar.slider(
62
  "Temperature",
63
  value=1.0,
@@ -80,7 +87,7 @@ if st.button("Run"):
80
  with st.spinner(text="Getting results..."):
81
  st.subheader("Result")
82
  time_start = time.time()
83
- result = process(text=session_state.text, top_k=int(top_k), top_p=float(top_p))
84
  time_end = time.time()
85
  time_diff = time_end-time_start
86
  #print(f"Text generated in {time_diff} seconds")
 
14
 
15
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
16
  def get_generator():
17
+ st.write(f"Loading the GPT2 model {model_name}, please wait...")
18
  text_generator = pipeline('text-generation', model=model_name)
19
  return text_generator
20
 
21
 
22
  #@st.cache(suppress_st_warning=True)
23
  def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
24
+ temperature: float = 1.0, max_time: float = None, seed=42):
25
  st.write("Cache miss: process")
26
+ set_seed(seed)
27
  result = text_generator(text, max_length=max_length, do_sample=do_sample,
28
  top_k=top_k, top_p=top_p, temperature=temperature, max_time=max_time)
29
  return result
 
58
 
59
  session_state.text = st.text_area("Enter text", session_state.prompt_box)
60
 
61
+ max_length = st.sidebar.number_input(
62
+ "Maximum length",
63
+ value=100,
64
+ max_value=512,
65
+ help="The maximum length of the sequence to be generated."
66
+ )
67
+
68
  temp = st.sidebar.slider(
69
  "Temperature",
70
  value=1.0,
 
87
  with st.spinner(text="Getting results..."):
88
  st.subheader("Result")
89
  time_start = time.time()
90
+ result = process(text=session_state.text, max_length=int(max_length), top_k=int(top_k), top_p=float(top_p))
91
  time_end = time.time()
92
  time_diff = time_end-time_start
93
  #print(f"Text generated in {time_diff} seconds")