File size: 2,190 Bytes
f7606e4
9d294d5
49e904b
 
f7606e4
9d294d5
 
70ab7c2
f7606e4
0cdb2ad
9d294d5
23aa67b
9d294d5
 
23aa67b
9d294d5
f7606e4
9d294d5
 
f7606e4
9d294d5
f7606e4
9d294d5
 
 
 
 
 
 
 
 
 
f7606e4
9d294d5
f7606e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cdb2ad
 
f7606e4
 
 
0cdb2ad
 
f7606e4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import OVModelForCausalLM


# Hugging Face repo of the 1.4B-parameter Pythia fine-tune used by this demo.
# NOTE(review): AutoModelForCausalLM is imported but never used below.
model_name = "DarwinAnim8or/Pythia-Greentext-1.4b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# export=True converts the PyTorch checkpoint to OpenVINO IR at startup so
# inference runs through optimum-intel's OpenVINO runtime (CPU-optimized).
model = OVModelForCausalLM.from_pretrained(model_name, export=True)

def generate(text, length=100, penalty=3, temperature=0.8, topk=40):
    """Generate a greentext continuation of *text* using the global OV model.

    Args:
        text: Seed line(s) for the story; a leading ">" is prepended if missing.
        length: Maximum number of NEW tokens to sample (the prompt is not counted).
        penalty: Forwarded as ``no_repeat_ngram_size`` — blocks repeating n-grams
            of this size, which curbs degenerate loops.
        temperature: Sampling temperature; higher values are more random.
        topk: Top-k sampling cutoff.

    Returns:
        The decoded continuation with the prompt stripped off.
    """
    input_text = "Write a greentext from 4chan.org. The story should be like a bullet-point list using > as the start of each line. Most greentexts are humorous or absurd in nature. Most greentexts have a twist near the end.\n"

    # Normalise the seed so it starts with ">" and the prompt ends with a
    # fresh ">" for the model to continue from.
    if not text.startswith(">"):
        input_text += ">" + text + "\n>"
    else:
        input_text += text + "\n>"

    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    # Drop the trailing ">" token so the model re-generates it itself.
    # Assumes ">" encodes to exactly one token — TODO confirm for this tokenizer.
    input_ids = input_ids[:, :-1]

    output = model.generate(
        input_ids,
        # max_new_tokens bounds only the continuation, replacing the previous
        # hand-rolled "max_length = length + prompt_len" arithmetic.
        max_new_tokens=length,
        temperature=temperature,
        top_k=topk,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=penalty,
        # early_stopping was removed: it applies only to beam search and is a
        # warning-emitting no-op under do_sample=True.
    )

    # Keep only the newly generated tail (everything past the prompt tokens).
    generated_text = tokenizer.decode(output[:, input_ids.size(1):][0], skip_special_tokens=True)
    return generated_text

# Example prompts surfaced in the Gradio UI; each inner list supplies the
# "Input Text" field (the opening line of a greentext).
examples = [
    ["be me"],
    ["be going to heaven"],
    #["be going to work"],
    #["be baking a pie"],
    #["come home after another tiring day"],
    ["be a plague doctor"]
]

# Wire the generator into a simple Gradio form.
# gr.inputs.* / gr.outputs.* were deprecated in Gradio 3.0 and removed in
# Gradio 4; the supported spelling is the top-level components with value=
# (instead of default=) for the initial slider positions.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=5, label="Input Text"),
        gr.Slider(5, 200, value=100, step=5, label='Length'),
        gr.Slider(1, 10, value=2, step=1, label='no repeat ngram size'),
        gr.Slider(0.0, 1.0, value=0.2, step=0.1, label='Temperature - control randomness'),
        gr.Slider(10, 100, value=40, step=10, label="top_k")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    examples=examples,
    title="Pythia-Greentext Playground",
    description="Using the 1.4b size model. You may need to run it a few times in order to get something good!"
)

demo.launch()