grandestroyer committed on
Commit
79f1518
1 Parent(s): 96a2232

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -53
app.py CHANGED
@@ -6,19 +6,6 @@ cache_dir = os.getcwd() + '/cache'
6
  ai = aitextgen(model="grandestroyer/joefreaks", cache_dir=cache_dir)
7
 
8
 
9
def generate(n, temp, prompt, exclude_repetitions):
    """Generate `n` texts from the model, optionally seeded with `prompt`.

    Returns a list of `n` whitespace-stripped strings. An empty prompt means
    unconditional generation (the `prompt` kwarg is omitted entirely).
    """
    ngram_block = 2 if exclude_repetitions else 0
    print('Generating with params n={}, temp={}, prompt="{}", no_repeat_ngram_size={}'
          .format(n, temp, prompt, ngram_block))
    # Both branches share every sampling parameter; only `prompt` differs,
    # so build the kwargs once and add the prompt when it is non-empty.
    kwargs = dict(n=n, temperature=temp, top_p=0.9, top_k=40,
                  return_as_list=True, no_repeat_ngram_size=ngram_block,
                  max_length=500)
    if prompt != '':
        kwargs['prompt'] = prompt
    return [text.strip() for text in ai.generate(**kwargs)]
20
-
21
-
22
  def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max_length=500, exclude_repetitions=False):
23
  no_repeat_ngram_size = 2 if exclude_repetitions else 0
24
  temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
@@ -29,48 +16,14 @@ def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max
29
  no_repeat_ngram_size=no_repeat_ngram_size, max_length=max_length)]
30
 
31
 
32
def display_results(prompt, results):
    """Format a prompt and its generated results as gr.Chatbot history.

    Returns a list of (bot_message, user_message) pairs: one pair showing the
    prompt on the user side (or a placeholder when it is empty), followed by
    one pair per result on the bot side.
    """
    # FIX: the placeholder literal was HTML-entity-escaped ('&lt;...&gt;')
    # by the page extraction; restore the intended '<empty prompt>' text.
    shown_prompt = prompt if prompt != '' else '<empty prompt>'
    history = [(None, shown_prompt)]
    history.extend((res, None) for res in results)
    return history
41
-
42
-
43
def submit_input(n, temp, prompt, exclude_repetitions):
    """Handle a UI submit: run generation and format it for the chatbot widget."""
    return display_results(prompt, generate(n, temp, prompt, exclude_repetitions))
46
-
47
-
48
def generate_one(temp, prompt, exclude_repetitions):
    """Generate a single text, clamping temperature into [0.1, 1.9].

    FIX: this previously delegated to generate(), but that helper is removed
    by this commit, so the call raised NameError at runtime. Delegate to
    generate_from_full_params with n=1 instead (it applies the same sampling
    defaults: top_p=0.9, top_k=40, max_length=500).
    """
    temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
    # generate_from_full_params re-clamps temperature internally, which is
    # harmless for an already-clamped value.
    return generate_from_full_params(prompt=prompt, n=1, temp=temp_normalized,
                                     exclude_repetitions=exclude_repetitions)[0]
51
 
52
 
53
with gr.Blocks() as demo:
    # Hidden widgets: they exist only to expose named API endpoints, not UI.
    api_temp = gr.Number(visible=False, label='temp', value=0.7)
    api_top_p = gr.Number(visible=False, label='top_p', value=0.9)
    api_top_k = gr.Number(visible=False, label='top_k', value=40, precision=0)
    api_max_length = gr.Number(visible=False, label='max_length', value=500, precision=0)
    api_result = gr.Textbox(visible=False)
    api_btn = gr.Button(visible=False)
    api_full_btn = gr.Button(visible=False)

    # Visible UI: generation controls on the left, chat transcript on the right.
    with gr.Row():
        with gr.Column(scale=0.5):
            count_input = gr.Number(label='Number of generated strings', value=5, precision=0)
            temp_slider = gr.Slider(label='Temperature', value=0.7, minimum=0.1, maximum=2.0)
            no_repeat_box = gr.Checkbox(label='Reduce repetitions (if possible)', value=False)
            prompt_box = gr.Textbox(label='Prompt (optional)')
            submit_btn = gr.Button(value='Submit')
        with gr.Column(scale=0.5):
            chat_view = gr.Chatbot().style(height=500)

    submit_btn.click(submit_input,
                     inputs=[count_input, temp_slider, prompt_box, no_repeat_box],
                     outputs=[chat_view])
    api_btn.click(generate_one,
                  [api_temp, prompt_box, no_repeat_box],
                  [api_result],
                  api_name='generate')
    api_full_btn.click(generate_from_full_params,
                       [prompt_box, count_input, api_temp, api_top_p, api_top_k,
                        api_max_length, no_repeat_box],
                       [chat_view],
                       api_name='generateWithFullParams')

if __name__ == "__main__":
    demo.launch()
 
6
  ai = aitextgen(model="grandestroyer/joefreaks", cache_dir=cache_dir)
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max_length=500, exclude_repetitions=False):
10
  no_repeat_ngram_size = 2 if exclude_repetitions else 0
11
  temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
 
16
  no_repeat_ngram_size=no_repeat_ngram_size, max_length=max_length)]
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
def generate_one(temp, prompt, exclude_repetitions):
    """Generate a single text, clamping temperature into [0.1, 1.9].

    FIX: this previously delegated to generate(), but that helper is removed
    by this commit, so the call raised NameError at runtime. Delegate to
    generate_from_full_params with n=1 instead (it applies the same sampling
    defaults: top_p=0.9, top_k=40, max_length=500).
    """
    temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
    # generate_from_full_params re-clamps temperature internally, which is
    # harmless for an already-clamped value.
    return generate_from_full_params(prompt=prompt, n=1, temp=temp_normalized,
                                     exclude_repetitions=exclude_repetitions)[0]
22
 
23
 
24
# Expose the generator as a plain API-style Interface.
# FIX: the added code called gradio.Inteface(...) — 'Inteface' is a typo for
# 'Interface', and the name 'gradio' is not in scope (this commit adds only
# these lines, and the rest of the file uses the 'gr' alias, as in gr.Blocks).
gr_interface = gr.Interface(
    fn=generate_from_full_params,
    # One component per generate_from_full_params parameter, in order:
    # prompt, n, temp, top_p, top_k, max_length, exclude_repetitions.
    inputs=['text', 'number', 'number', 'number', 'number', 'number', 'checkbox'],
    outputs='json',
)
gr_interface.launch()