joefreaks_api / app.py
grandestroyer's picture
Duplicate from grandestroyer/joefreaks_api_test
fa9b36e
raw history blame
No virus
3.73 kB
from aitextgen import aitextgen
import gradio as gr
import os
# Cache model downloads under ./cache relative to the working directory.
cache_dir = f"{os.getcwd()}/cache"
# Load the fine-tuned GPT model once at module import; reused by all handlers.
ai = aitextgen(model="grandestroyer/joefreaks", cache_dir=cache_dir)
def generate(n, temp, prompt, exclude_repetitions):
    """Generate *n* texts from the model and return them as stripped strings.

    Args:
        n: number of samples to generate.
        temp: sampling temperature (passed through unchanged).
        prompt: seed text; an empty string means unconditional generation.
        exclude_repetitions: when True, forbid repeated bigrams
            (no_repeat_ngram_size=2); when False, no restriction (0).

    Returns:
        list[str]: generated texts with surrounding whitespace stripped.
    """
    no_repeat_ngram_size = 2 if exclude_repetitions else 0
    print('Generating with params n={}, temp={}, prompt="{}", no_repeat_ngram_size={}'
          .format(n, temp, prompt, no_repeat_ngram_size))
    # Build the call once instead of duplicating it in two branches; the
    # prompt kwarg is only supplied when a prompt was actually given, which
    # preserves the original prompt-less call for empty input.
    kwargs = dict(n=n, temperature=temp, top_p=0.9, top_k=40, return_as_list=True,
                  no_repeat_ngram_size=no_repeat_ngram_size, max_length=500)
    if prompt != '':
        kwargs['prompt'] = prompt
    return [txt.strip() for txt in ai.generate(**kwargs)]
def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max_length=500, exclude_repetitions=False):
    """API handler exposing all sampling knobs; returns stripped generated texts.

    Temperature is guarded into (0, 2): non-positive values become 0.1 and
    values of 2 or more become 1.9; anything in between passes through as-is.
    """
    no_repeat_ngram_size = 2 if exclude_repetitions else 0
    # Clamp only the out-of-range extremes; in-range values are untouched.
    if temp <= 0:
        temp_normalized = 0.1
    elif temp >= 2:
        temp_normalized = 1.9
    else:
        temp_normalized = temp
    print(f'Generating with params prompt="{prompt}", n={n}, temp={temp_normalized}, '
          f'top_p={top_p}, top_k={top_k}, max_length={max_length}, '
          f'no_repeat_ngram_size={no_repeat_ngram_size}')
    texts = ai.generate(prompt=prompt, n=n, temperature=temp_normalized, top_p=top_p,
                        top_k=top_k, return_as_list=True,
                        no_repeat_ngram_size=no_repeat_ngram_size,
                        max_length=max_length)
    return [t.strip() for t in texts]
def display_results(prompt, results):
    """Format a prompt and its generated texts as gradio Chatbot history.

    Args:
        prompt: the user's prompt; an empty string is shown as the
            '&lt;empty prompt&gt;' placeholder (HTML-escaped for the widget).
        results: iterable of generated strings.

    Returns:
        list[tuple]: one (None, prompt) pair followed by a (text, None)
        pair per result, matching the Chatbot component's pair format.
    """
    # Build the list once instead of re-concatenating it per iteration
    # (the original `history = history + [...]` pattern is quadratic).
    shown_prompt = prompt if prompt != '' else '&lt;empty prompt&gt;'
    history = [(None, shown_prompt)]
    history.extend((res, None) for res in results)
    return history
def submit_input(n, temp, prompt, exclude_repetitions):
    """Submit-button handler: generate texts and render them as chat history."""
    return display_results(prompt, generate(n, temp, prompt, exclude_repetitions))
def generate_one(temp, prompt, exclude_repetitions):
    """API handler: return a single generated string.

    Temperature is guarded the same way as in generate_from_full_params:
    non-positive -> 0.1, >= 2 -> 1.9, otherwise passed through.
    """
    if temp <= 0:
        safe_temp = 0.1
    elif temp >= 2:
        safe_temp = 1.9
    else:
        safe_temp = temp
    results = generate(n=1, temp=safe_temp, prompt=prompt,
                       exclude_repetitions=exclude_repetitions)
    return results[0]
# UI layout plus hidden components that exist only to expose named API
# endpoints (generate / generateWithFullParams) via gradio's API.
with gr.Blocks() as demo:
    # Invisible inputs/outputs wired to the API-only buttons below; they
    # mirror the full parameter set of generate_from_full_params.
    temp = gr.Number(visible=False, label='temp', value=0.7)
    top_p_el = gr.Number(visible=False, label='top_p', value=0.9)
    top_k_el = gr.Number(visible=False, label='top_k', value=40, precision=0)
    max_length_el = gr.Number(visible=False, label='max_length', value=500, precision=0)
    result = gr.Textbox(visible=False)
    apiBtn = gr.Button(visible=False)
    apiFullBtn = gr.Button(visible=False)
    # Visible UI: controls on the left, chat-style output on the right.
    with gr.Row():
        with gr.Column(scale=0.5):
            n_el = gr.Number(label='Number of generated strings', value=5, precision=0)
            temp_el = gr.Slider(label='Temperature', value=0.7, minimum=0.1, maximum=2.0)
            exclude_repetitions_el = gr.Checkbox(label='Reduce repetitions (if possible)', value=False)
            prompt_el = gr.Textbox(label='Prompt (optional)')
            btn = gr.Button(value='Submit')
        with gr.Column(scale=0.5):
            # NOTE(review): .style(height=...) is the legacy gradio 3.x API —
            # confirm the pinned gradio version before upgrading.
            chatbox_el = gr.Chatbot().style(height=500)
    # Visible Submit button renders results into the chatbot.
    btn.click(submit_input, inputs=[n_el, temp_el, prompt_el, exclude_repetitions_el], outputs=[chatbox_el])
    # Hidden buttons register the named API endpoints.
    apiBtn.click(generate_one, [temp, prompt_el, exclude_repetitions_el], [result], api_name='generate')
    apiFullBtn.click(generate_from_full_params, [prompt_el, n_el, temp, top_p_el, top_k_el, max_length_el, exclude_repetitions_el], [chatbox_el], api_name='generateWithFullParams')
# Launch the app only when run as a script (HF Spaces entry point).
if __name__ == "__main__":
    demo.launch()