Spaces:
Running
Running
from aitextgen import aitextgen | |
import gradio as gr | |
import os | |
# Cache directory handed to aitextgen, resolved against the current working
# directory at import time. os.path.join avoids a doubled separator when the
# cwd is the filesystem root and uses the platform's separator.
cache_dir = os.path.join(os.getcwd(), 'cache')
# Model instance shared by generate() for every request.
ai = aitextgen(model="grandestroyer/joefreaks", cache_dir=cache_dir)
def generate(n, temp, prompt, exclude_repetitions):
    """Generate ``n`` text samples with the module-level ``ai`` model.

    Args:
        n: Number of samples to request; values below 1 are raised to 1.
        temp: Sampling temperature forwarded as ``temperature``.
        prompt: Optional prompt text. When it is ``''`` or ``None`` no
            ``prompt`` argument is passed to ``ai.generate``.
        exclude_repetitions: When truthy, ``no_repeat_ngram_size=2`` is used;
            otherwise ``no_repeat_ngram_size=0``.

    Returns:
        A list of generated strings with surrounding whitespace stripped.
    """
    # A request for zero or negative samples cannot be served; ask for one.
    n = max(1, int(n))
    no_repeat_ngram_size = 2 if exclude_repetitions else 0
    print('Generating with params n={}, temp={}, prompt="{}", no_repeat_ngram_size={}'
          .format(n, temp, prompt, no_repeat_ngram_size))
    # Sampling settings shared by the prompted and unprompted paths.
    gen_kwargs = dict(n=n, temperature=temp, top_p=0.9, top_k=40, return_as_list=True,
                      no_repeat_ngram_size=no_repeat_ngram_size, max_length=500)
    # Only forward a non-empty prompt; None (e.g. from an API caller) is
    # treated the same as an empty string instead of reaching ai.generate.
    if prompt:
        gen_kwargs['prompt'] = prompt
    return [txt.strip() for txt in ai.generate(**gen_kwargs)]
def display_results(prompt, results):
    """Build the message-pair list shown in the chat view.

    Args:
        prompt: The prompt text; ``''`` or ``None`` is shown as
            ``'<empty prompt>'``.
        results: Iterable of generated strings.

    Returns:
        A list of 2-tuples. The first tuple carries the prompt (or placeholder)
        in its second slot; each result follows in its own tuple's first slot.
    """
    # The placeholder keeps an empty/missing prompt visible in the chat log.
    header = (None, prompt if prompt else '<empty prompt>')
    return [header] + [(res, None) for res in results]
def submit_input(n, temp, prompt, exclude_repetitions):
    """Submit-button handler: generate samples, then format them for the chat view."""
    return display_results(prompt, generate(n, temp, prompt, exclude_repetitions))
def generate_one(temp, prompt, exclude_repetitions):
    """API handler: return a single generated string.

    ``temp`` values <= 0 are replaced by 0.1 and values >= 2 by 1.9; any other
    value is used unchanged.
    """
    if temp <= 0:
        safe_temp = 0.1
    elif temp >= 2:
        safe_temp = 1.9
    else:
        safe_temp = temp
    samples = generate(n=1, temp=safe_temp, prompt=prompt,
                       exclude_repetitions=exclude_repetitions)
    return samples[0]
with gr.Blocks() as demo:
    # Hidden components backing the "generate" API endpoint; not shown on the page.
    temp = gr.Number(visible=False)
    result = gr.Textbox(visible=False)
    apiBtn = gr.Button(visible=False)
    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            n_el = gr.Number(value=5, precision=0, label='Number of generated strings')
            temp_el = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label='Temperature')
            exclude_repetitions_el = gr.Checkbox(value=False, label='Reduce repetitions (if possible)')
            prompt_el = gr.Textbox(label='Prompt (optional)')
            btn = gr.Button(value='Submit')
        # Right column: chat-style view of the prompt and its generations.
        with gr.Column(scale=1):
            chatbox_el = gr.Chatbot(height=500)
    # UI path: Submit fills the chat view with n samples.
    btn.click(fn=submit_input,
              inputs=[n_el, temp_el, prompt_el, exclude_repetitions_el],
              outputs=[chatbox_el])
    # API path (api_name 'generate'): temperature comes from the hidden `temp`
    # field, while prompt and checkbox are shared with the visible form.
    apiBtn.click(fn=generate_one,
                 inputs=[temp, prompt_el, exclude_repetitions_el],
                 outputs=[result],
                 api_name='generate')

if __name__ == "__main__":
    demo.launch()