# --- Hugging Face Spaces file-viewer residue (not part of the script) ---
# Space status: Runtime error. File size: 3,565 bytes.
# (Commit-hash blame column and gutter line numbers 1-99 omitted.)
import gradio as gr
import nltk
import string
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig, set_seed
import random
# Fetch the Punkt tokenizer data NLTK needs for sentence splitting.
nltk.download('punkt')

# Number of new tokens to generate per user request.
response_length = 200

# NOTE(review): sentence_detector appears unused in this file —
# clean_paragraph() calls nltk.tokenize.sent_tokenize directly. Confirm
# before removing.
sentence_detector = nltk.data.load('tokenizers/punkt/english.pickle')

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
# When the encoded input exceeds max_length, drop tokens from the END of
# the text rather than the beginning.
tokenizer.truncation_side = 'right'

# model = GPT2LMHeadModel.from_pretrained('checkpoint-10000')
model = GPT2LMHeadModel.from_pretrained('coffeeee/nsfw-story-generator')

generation_config = GenerationConfig.from_pretrained('gpt2-medium')
generation_config.max_new_tokens = response_length
# GPT-2 defines no pad token; reuse EOS so generate() can pad batches.
generation_config.pad_token_id = generation_config.eos_token_id
def generate_response(outputs, new_prompt):
    """Generate the next story chunk from the accumulated outputs plus the
    user's new prompt.

    outputs    -- per-session list of previously generated chunks (gr.State).
    new_prompt -- text the user just submitted.

    Returns a Gradio update dict keyed by the module-level components:
    the appended outputs list, the full story text, and a cleared prompt box.
    """
    # Keep at most int(1024 / 200) + 1 = 6 prior chunks as context.
    # NOTE(review): this slice keeps the FIRST chunks of the story, not the
    # most recent ones — confirm that is intended for story continuation.
    story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)]) if outputs else ""
    # Re-seed per request so resubmitting the same prompt varies the output.
    set_seed(random.randint(0, 4000000000))
    # Reserve response_length tokens of the 1024-token window for generation.
    # NOTE(review): with truncation_side='right' (set at module level), a long
    # story can truncate away new_prompt itself, since it sits at the end of
    # the encoded text — verify this is the intended behavior.
    inputs = tokenizer.encode(story_so_far + "\n" + new_prompt if story_so_far else new_prompt,
                              return_tensors='pt', truncation=True,
                              max_length=1024 - response_length)
    output = model.generate(inputs, do_sample=True, generation_config=generation_config)
    # Strip the echoed prior story (plus its trailing newline) by CHARACTER
    # offset into the decoded text; the echoed prompt stays in the response.
    response = clean_paragraph(tokenizer.batch_decode(output)[0][(len(story_so_far) + 1 if story_so_far else 0):])
    outputs.append(response)
    # Dict keys are the Gradio components defined at module level.
    return {
        user_outputs: outputs,
        story: (story_so_far + "\n" if story_so_far else "") + response,
        prompt: None
    }
def undo(outputs):
    """Discard the most recently generated chunk and rebuild the story text.

    Returns a Gradio update dict: the trimmed session outputs and the
    re-joined story (None when nothing remains, which restores the
    placeholder text in the story box).
    """
    if outputs:
        trimmed = outputs[:-1]
    else:
        trimmed = []
    rebuilt_story = "\n".join(trimmed) if trimmed else None
    return {
        user_outputs: trimmed,
        story: rebuilt_story
    }
def clean_paragraph(entry):
    """Tidy generated text: drop a trailing incomplete sentence from the
    final paragraph and capitalize the first alphabetic character.

    entry -- raw decoded model output (may span multiple newline-separated
             paragraphs).

    Returns the cleaned text as a single string.
    """
    paragraphs = entry.split('\n')
    for i in range(len(paragraphs)):
        split_sentences = nltk.tokenize.sent_tokenize(paragraphs[i], language='english')
        # Only the final paragraph can end mid-sentence (generation stops at
        # the token limit). BUGFIX: the original tested
        # `split_sentences[:1][-1]` — i.e. the whole FIRST sentence — against
        # string.punctuation, which is almost never a match, so a complete
        # last sentence was discarded anyway. Check the last CHARACTER of the
        # LAST sentence instead, and guard against an empty sentence list.
        if (i == len(paragraphs) - 1 and split_sentences
                and split_sentences[-1][-1] not in string.punctuation):
            paragraphs[i] = " ".join(split_sentences[:-1])
    return capitalize_first_char("\n".join(paragraphs))
def reset():
    """Wipe the session: empty the stored chunk list and clear the story box
    (None restores its placeholder/disclaimer text)."""
    return {user_outputs: [], story: None}
def capitalize_first_char(entry):
    """Upper-case the first alphabetic character of *entry*, leaving all
    preceding non-letter characters (digits, whitespace, punctuation) and the
    rest of the string untouched. Strings with no letters come back as-is."""
    for index, char in enumerate(entry):
        if char.isalpha():
            return f"{entry[:index]}{char.upper()}{entry[index + 1:]}"
    return entry
# UI layout and event wiring. Component variables (story, user_outputs,
# prompt) are module-level names referenced by the handler functions above.
with gr.Blocks(theme=gr.themes.Default(text_size='lg', font=[gr.themes.GoogleFont("Bitter"), "Arial", "sans-serif"])) as demo:
    # Shown in the story box while it is empty.
    placeholder_text = '''
    Disclaimer: everything this model generates is a work of fiction.
    Content from this model WILL generate inappropriate and potentially offensive content.
    Use at your own discretion. Please respect the Huggingface code of conduct.'''
    story = gr.Textbox(label="Story", interactive=False, lines=20, placeholder=placeholder_text)
    # NOTE(review): Component.style() was deprecated in Gradio 3.x and removed
    # in 4.x — confirm the Space pins a compatible Gradio version.
    story.style(show_copy_button=True)
    # Per-session list of generated chunks (server-side state, one per user).
    user_outputs = gr.State([])
    prompt = gr.Textbox(label="Prompt", placeholder="Start a new story, or continue your current one!", lines=3, max_lines=3)
    with gr.Row():
        gen_button = gr.Button('Generate')
        undo_button = gr.Button("Undo")
        res_button = gr.Button("Reset")
    # Enter in the prompt box and the Generate button share the same handler.
    prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
    gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
    undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
    res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)
demo.launch(inbrowser=True)
|