Spaces:
Running
Running
grandestroyer
commited on
Commit
•
79f1518
1
Parent(s):
96a2232
Update app.py
Browse files
app.py
CHANGED
@@ -6,19 +6,6 @@ cache_dir = os.getcwd() + '/cache'
|
|
6 |
ai = aitextgen(model="grandestroyer/joefreaks", cache_dir=cache_dir)
|
7 |
|
8 |
|
9 |
-
def generate(n, temp, prompt, exclude_repetitions):
|
10 |
-
no_repeat_ngram_size = 2 if exclude_repetitions else 0
|
11 |
-
print('Generating with params n={}, temp={}, prompt="{}", no_repeat_ngram_size={}'
|
12 |
-
.format(n, temp, prompt, no_repeat_ngram_size))
|
13 |
-
if prompt == '':
|
14 |
-
return [txt.strip() for txt in ai.generate(n=n, temperature=temp, top_p=0.9, top_k=40, return_as_list=True,
|
15 |
-
no_repeat_ngram_size=no_repeat_ngram_size, max_length=500)]
|
16 |
-
else:
|
17 |
-
return [txt.strip() for txt in
|
18 |
-
ai.generate(prompt=prompt, n=n, temperature=temp, top_p=0.9, top_k=40, return_as_list=True,
|
19 |
-
no_repeat_ngram_size=no_repeat_ngram_size, max_length=500)]
|
20 |
-
|
21 |
-
|
22 |
def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max_length=500, exclude_repetitions=False):
|
23 |
no_repeat_ngram_size = 2 if exclude_repetitions else 0
|
24 |
temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
|
@@ -29,48 +16,14 @@ def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max
|
|
29 |
no_repeat_ngram_size=no_repeat_ngram_size, max_length=max_length)]
|
30 |
|
31 |
|
32 |
-
def display_results(prompt, results):
|
33 |
-
history = []
|
34 |
-
if prompt != '':
|
35 |
-
history = history + [(None, prompt)]
|
36 |
-
else:
|
37 |
-
history = history + [(None, '<empty prompt>')]
|
38 |
-
for res in results:
|
39 |
-
history = history + [(res, None)]
|
40 |
-
return history
|
41 |
-
|
42 |
-
|
43 |
-
def submit_input(n, temp, prompt, exclude_repetitions):
|
44 |
-
results = generate(n, temp, prompt, exclude_repetitions)
|
45 |
-
return display_results(prompt, results)
|
46 |
-
|
47 |
-
|
48 |
def generate_one(temp, prompt, exclude_repetitions):
|
49 |
temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
|
50 |
return generate(n=1, temp=temp_normalized, prompt=prompt, exclude_repetitions=exclude_repetitions)[0]
|
51 |
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
apiBtn = gr.Button(visible=False)
|
60 |
-
apiFullBtn = gr.Button(visible=False)
|
61 |
-
|
62 |
-
with gr.Row():
|
63 |
-
with gr.Column(scale=0.5):
|
64 |
-
n_el = gr.Number(label='Number of generated strings', value=5, precision=0)
|
65 |
-
temp_el = gr.Slider(label='Temperature', value=0.7, minimum=0.1, maximum=2.0)
|
66 |
-
exclude_repetitions_el = gr.Checkbox(label='Reduce repetitions (if possible)', value=False)
|
67 |
-
prompt_el = gr.Textbox(label='Prompt (optional)')
|
68 |
-
btn = gr.Button(value='Submit')
|
69 |
-
with gr.Column(scale=0.5):
|
70 |
-
chatbox_el = gr.Chatbot().style(height=500)
|
71 |
-
btn.click(submit_input, inputs=[n_el, temp_el, prompt_el, exclude_repetitions_el], outputs=[chatbox_el])
|
72 |
-
apiBtn.click(generate_one, [temp, prompt_el, exclude_repetitions_el], [result], api_name='generate')
|
73 |
-
apiFullBtn.click(generate_from_full_params, [prompt_el, n_el, temp, top_p_el, top_k_el, max_length_el, exclude_repetitions_el], [chatbox_el], api_name='generateWithFullParams')
|
74 |
-
|
75 |
-
if __name__ == "__main__":
|
76 |
-
demo.launch()
|
|
|
6 |
ai = aitextgen(model="grandestroyer/joefreaks", cache_dir=cache_dir)
|
7 |
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max_length=500, exclude_repetitions=False):
|
10 |
no_repeat_ngram_size = 2 if exclude_repetitions else 0
|
11 |
temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
|
|
|
16 |
no_repeat_ngram_size=no_repeat_ngram_size, max_length=max_length)]
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def generate_one(temp, prompt, exclude_repetitions):
|
20 |
temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
|
21 |
return generate(n=1, temp=temp_normalized, prompt=prompt, exclude_repetitions=exclude_repetitions)[0]
|
22 |
|
23 |
|
24 |
+
gr_interface = gradio.Inteface(
|
25 |
+
fn = generate_from_full_params,
|
26 |
+
inputs = ['text', 'number', 'number', 'number', 'number', 'number', 'checkbox'],
|
27 |
+
outputs = 'json'
|
28 |
+
)
|
29 |
+
gr_interface.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|