File size: 1,358 Bytes
fa9b36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0116f84
79f1518
457981c
79f1518
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from aitextgen import aitextgen
import gradio as gr
import os

# Directory where aitextgen stores downloaded model files: "cache" under the
# current working directory. os.path.join keeps the separator portable.
cache_dir = os.path.join(os.getcwd(), 'cache')
# Load the model once at import time so every request handled below reuses the
# same instance. The model string is presumably a Hugging Face Hub id — confirm.
ai = aitextgen(model="grandestroyer/joefreaks", cache_dir=cache_dir)


def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max_length=500, exclude_repetitions=False):
    """Sample ``n`` texts from the module-level ``ai`` model and return them stripped.

    Args:
        prompt: Text the generation starts from.
        n: Number of samples to request.
        temp: Requested temperature; values <= 0 become 0.1 and values >= 2
            become 1.9 before sampling.
        top_p: Nucleus-sampling threshold, forwarded unchanged.
        top_k: Top-k cutoff, forwarded unchanged.
        max_length: Maximum generation length, forwarded unchanged.
        exclude_repetitions: When true, ``no_repeat_ngram_size=2`` is passed to
            the generator; otherwise 0 is passed.

    Returns:
        A list with each generated text passed through ``str.strip()``.
    """
    # Keep the temperature inside a range the generator accepts.
    if temp <= 0:
        temperature = 0.1
    elif temp >= 2:
        temperature = 1.9
    else:
        temperature = temp

    ngram_block = 2 if exclude_repetitions else 0

    # Log the effective parameters (after clamping) for each request.
    template = 'Generating with params prompt="{}", n={}, temp={}, top_p={}, top_k={}, max_length={}, no_repeat_ngram_size={}'
    print(template.format(prompt, n, temperature, top_p, top_k, max_length, ngram_block))

    samples = ai.generate(
        prompt=prompt,
        n=n,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        return_as_list=True,
        no_repeat_ngram_size=ngram_block,
        max_length=max_length,
    )
    return [sample.strip() for sample in samples]


def generate_one(temp, prompt, exclude_repetitions):
    """Generate a single text sample and return it as a string.

    Args:
        temp: Sampling temperature (clamped by ``generate_from_full_params``).
        prompt: Text the generation starts from.
        exclude_repetitions: Whether repeated 2-grams should be blocked.

    Returns:
        The first (and only) stripped sample.
    """
    # The original called an undefined ``generate`` and raised NameError on
    # every call. It also clamped ``temp`` itself, but generate_from_full_params
    # applies the same idempotent clamp, so the value is passed through as-is.
    samples = generate_from_full_params(
        n=1,
        temp=temp,
        prompt=prompt,
        exclude_repetitions=exclude_repetitions,
    )
    return samples[0]


# Web UI: each input maps positionally onto generate_from_full_params'
# parameters (prompt, n, temp, top_p, top_k, max_length, exclude_repetitions).
# Gradio supplies a value for every input component, so the function's Python
# default arguments are never used from the UI. Pre-fill the numeric fields
# with those same defaults so an untouched form produces sensible requests.
gr_interface = gr.Interface(
    fn=generate_from_full_params,
    inputs=[
        'text',  # prompt
        gr.Number(value=1, precision=0, label='n'),
        gr.Number(value=0.7, label='temp'),
        gr.Number(value=0.9, label='top_p'),
        gr.Number(value=40, precision=0, label='top_k'),
        gr.Number(value=500, precision=0, label='max_length'),
        'checkbox',  # exclude_repetitions
    ],
    # The function returns a list of strings, rendered as JSON.
    outputs='json',
)
gr_interface.launch()