import gradio as gr

import nltk
import string
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig, set_seed
import random

nltk.download('punkt')

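# Number of new tokens the model generates for each response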
response_length = 200

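# NLTK's Punkt-based English sentence splitter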
sentence_detector = nltk.data.load('tokenizers/punkt/english.pickle')

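# GPT-2 tokenizer; truncate from the right so the beginning of the story is kept when the input gets too long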
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
tokenizer.truncation_side = 'right'

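# Load the story-generation model (GPT-2 architecture); the commented line loads a local checkpoint instead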
# model = GPT2LMHeadModel.from_pretrained('checkpoint-10000')
model = GPT2LMHeadModel.from_pretrained('coffeeee/nsfw-story-generator')
generation_config = GenerationConfig.from_pretrained('gpt2-medium')
generation_config.max_new_tokens = response_length
generation_config.pad_token_id = generation_config.eos_token_id
def generate_response(outputs, new_prompt):

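    # Context = earlier responses joined together, capped at the first few so the encoded prompt fits GPT-2's 1024-token window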
    story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)]) if outputs else ""

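    # Re-seed the RNG so identical prompts can produce different continuations,
    # then encode context + prompt, leaving room for response_length new tokens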
    set_seed(random.randint(0, 4000000000))
    inputs = tokenizer.encode(story_so_far + "\n" + new_prompt if story_so_far else new_prompt,
                              return_tensors='pt', truncation=True,
                              max_length=1024 - response_length)

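    # Sample a continuation using the shared generation config (max_new_tokens, pad token)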
    output = model.generate(inputs, do_sample=True, generation_config=generation_config)

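    # Decode, strip the prior story text from the front, and tidy the new passage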
    response = clean_paragraph(tokenizer.batch_decode(output)[0][(len(story_so_far) + 1 if story_so_far else 0):])
    outputs.append(response)
    return {
        user_outputs: outputs,
        story: (story_so_far + "\n" if story_so_far else "") + response,
        prompt: None
    }

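# Remove the most recent response and rebuild the story textbox contents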
def undo(outputs):

    outputs = outputs[:-1] if outputs else []
    return {
        user_outputs: outputs,
        story: "\n".join(outputs) if outputs else None
    }

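# Drop a trailing incomplete sentence from the last paragraph and capitalize the first letter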
def clean_paragraph(entry):
    paragraphs = entry.split('\n')

    for i in range(len(paragraphs)):
        split_sentences = nltk.tokenize.sent_tokenize(paragraphs[i], language='english')
        # If the last paragraph ends mid-sentence (no closing punctuation), drop the incomplete sentence
        if i == len(paragraphs) - 1 and split_sentences and split_sentences[-1][-1] not in string.punctuation:
            paragraphs[i] = " ".join(split_sentences[:-1])

    return capitalize_first_char("\n".join(paragraphs))

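# Clear both the saved responses and the story textbox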
def reset():
    return {
        user_outputs: [],
        story: None
    }

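# Upper-case the first alphabetic character in the text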
def capitalize_first_char(entry):
    for i in range(len(entry)):
        if entry[i].isalpha():
            return entry[:i] + entry[i].upper() + entry[i + 1:]
    return entry

with gr.Blocks(theme=gr.themes.Default(text_size='lg', font=[gr.themes.GoogleFont("Bitter"), "Arial", "sans-serif"])) as demo:

    placeholder_text = '''
    Disclaimer: everything this model generates is a work of fiction.
    This model WILL generate inappropriate and potentially offensive content.

    Use at your own discretion. Please respect the Hugging Face code of conduct.'''

    story = gr.Textbox(label="Story", interactive=False, lines=20, placeholder=placeholder_text)
    story.style(show_copy_button=True)

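    # Per-session list of generated responses; used to rebuild context and to support Undo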
    user_outputs = gr.State([])

    prompt = gr.Textbox(label="Prompt", placeholder="Start a new story, or continue your current one!", lines=3, max_lines=3)

    with gr.Row():
        gen_button = gr.Button('Generate')
        undo_button = gr.Button("Undo")
        res_button = gr.Button("Reset")

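    # Pressing Enter in the prompt box and clicking Generate run the same handler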
    prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
    gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
    undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
    res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)

demo.launch(inbrowser=True)