TK192828 commited on
Commit
c69f3c6
1 Parent(s): 5294acb

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +12 -0
  2. app.py +100 -0
  3. requirements.txt +78 -0
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: "[NSFW] C0ffee's Erotic Story Generator 2"
3
+ emoji: 🍑
4
+ colorFrom: gray
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import nltk
4
+ import string
5
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig, set_seed
6
+ import random
7
+
8
+ nltk.download('punkt')
9
+
10
+ response_length = 200
11
+
12
+ sentence_detector = nltk.data.load('tokenizers/punkt/english.pickle')
13
+
14
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
15
+ tokenizer.truncation_side = 'right'
16
+
17
+ # model = GPT2LMHeadModel.from_pretrained('checkpoint-50000')
18
+ model = GPT2LMHeadModel.from_pretrained('coffeeee/nsfw-story-generator2')
19
+ generation_config = GenerationConfig.from_pretrained('gpt2-medium')
20
+ generation_config.max_new_tokens = response_length
21
+ generation_config.pad_token_id = generation_config.eos_token_id
22
+ def generate_response(outputs, new_prompt):
23
+
24
+ story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)]) if outputs else ""
25
+
26
+ set_seed(random.randint(0, 4000000000))
27
+ inputs = tokenizer.encode(story_so_far + "\n" + new_prompt if story_so_far else new_prompt,
28
+ return_tensors='pt', truncation=True,
29
+ max_length=1024 - response_length)
30
+
31
+ output = model.generate(inputs, do_sample=True, generation_config=generation_config)
32
+
33
+ response = clean_paragraph(tokenizer.batch_decode(output)[0][(len(story_so_far) + 1 if story_so_far else 0):])
34
+ outputs.append(response)
35
+ return {
36
+ user_outputs: outputs,
37
+ story: (story_so_far + "\n" if story_so_far else "") + response,
38
+ prompt: None
39
+ }
40
+
41
+ def undo(outputs):
42
+
43
+ outputs = outputs[:-1] if outputs else []
44
+ return {
45
+ user_outputs: outputs,
46
+ story: "\n".join(outputs) if outputs else None
47
+ }
48
+
49
+ def clean_paragraph(entry):
50
+ paragraphs = entry.split('\n')
51
+
52
+ for i in range(len(paragraphs)):
53
+ split_sentences = nltk.tokenize.sent_tokenize(paragraphs[i], language='english')
54
+ if i == len(paragraphs) - 1 and split_sentences[:1][-1] not in string.punctuation:
55
+ paragraphs[i] = " ".join(split_sentences[:-1])
56
+
57
+ return capitalize_first_char("\n".join(paragraphs))
58
+
59
+ def reset():
60
+ return {
61
+ user_outputs: [],
62
+ story: None
63
+ }
64
+
65
+ def capitalize_first_char(entry):
66
+ for i in range(len(entry)):
67
+ if entry[i].isalpha():
68
+ return entry[:i] + entry[i].upper() + entry[i + 1:]
69
+ return entry
70
+
71
+ with gr.Blocks(theme=gr.themes.Default(text_size='lg', font=[gr.themes.GoogleFont("Bitter"), "Arial", "sans-serif"])) as demo:
72
+
73
+ placeholder_text = '''
74
+ Disclaimer: everything this model generates is a work of fiction.
75
+ Content from this model WILL generate inappropriate and potentially offensive content.
76
+
77
+ Use at your own discretion. Please respect the Huggingface code of conduct.'''
78
+
79
+ story = gr.Textbox(label="Story", interactive=False, lines=20, placeholder=placeholder_text)
80
+ story.style(show_copy_button=True)
81
+
82
+ user_outputs = gr.State([])
83
+
84
+ prompt = gr.Textbox(label="Prompt", placeholder="Start a new story, or continue your current one!", lines=3, max_lines=3)
85
+
86
+ with gr.Row():
87
+ gen_button = gr.Button('Generate')
88
+ undo_button = gr.Button("Undo")
89
+ res_button = gr.Button("Reset")
90
+
91
+ prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
92
+ gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
93
+ undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
94
+ res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)
95
+
96
+ # for local server; comment out for deploy
97
+
98
+ demo.launch(inbrowser=True, server_name='0.0.0.0')
99
+
100
+
requirements.txt ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ altgraph==0.17.3
6
+ anyio==3.6.2
7
+ async-timeout==4.0.2
8
+ attrs==23.1.0
9
+ certifi==2022.12.7
10
+ charset-normalizer==3.1.0
11
+ click==8.1.3
12
+ colorama==0.4.6
13
+ contourpy==1.0.7
14
+ cycler==0.11.0
15
+ entrypoints==0.4
16
+ fastapi==0.95.1
17
+ ffmpy==0.3.0
18
+ filelock==3.12.0
19
+ fonttools==4.39.3
20
+ frozenlist==1.3.3
21
+ fsspec==2023.4.0
22
+ gradio==3.27.0
23
+ gradio_client==0.1.3
24
+ h11==0.14.0
25
+ httpcore==0.17.0
26
+ httpx==0.24.0
27
+ huggingface-hub==0.14.0
28
+ idna==3.4
29
+ Jinja2==3.1.2
30
+ joblib==1.2.0
31
+ jsonschema==4.17.3
32
+ kiwisolver==1.4.4
33
+ linkify-it-py==2.0.0
34
+ markdown-it-py==2.2.0
35
+ MarkupSafe==2.1.2
36
+ matplotlib==3.7.1
37
+ mdit-py-plugins==0.3.3
38
+ mdurl==0.1.2
39
+ mpmath==1.3.0
40
+ multidict==6.0.4
41
+ networkx==3.1
42
+ nltk==3.8.1
43
+ numpy==1.24.3
44
+ orjson==3.8.10
45
+ packaging==23.1
46
+ pandas==2.0.1
47
+ pefile==2023.2.7
48
+ Pillow==9.5.0
49
+ pydantic==1.10.7
50
+ pydub==0.25.1
51
+ pyinstaller==5.10.1
52
+ pyinstaller-hooks-contrib==2023.2
53
+ pyparsing==3.0.9
54
+ pyrsistent==0.19.3
55
+ python-dateutil==2.8.2
56
+ python-multipart==0.0.6
57
+ pytz==2023.3
58
+ pywin32-ctypes==0.2.0
59
+ PyYAML==6.0
60
+ regex==2023.3.23
61
+ requests==2.28.2
62
+ semantic-version==2.10.0
63
+ six==1.16.0
64
+ sniffio==1.3.0
65
+ starlette==0.26.1
66
+ sympy==1.11.1
67
+ tokenizers==0.13.3
68
+ toolz==0.12.0
69
+ torch==2.0.0
70
+ tqdm==4.65.0
71
+ transformers==4.28.1
72
+ typing_extensions==4.5.0
73
+ tzdata==2023.3
74
+ uc-micro-py==1.0.1
75
+ urllib3==1.26.15
76
+ uvicorn==0.21.1
77
+ websockets==11.0.2
78
+ yarl==1.9.1