Guinnessgshep Catmeow commited on
Commit
424e8fd
0 Parent(s):

Duplicate from Catmeow/AI_story_writing

Browse files

Co-authored-by: XY Chen <Catmeow@users.noreply.huggingface.co>

Files changed (4) hide show
  1. .gitattributes +33 -0
  2. README.md +13 -0
  3. app.py +44 -0
  4. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.npy filter=lfs diff=lfs merge=lfs -text
14
+ *.npz filter=lfs diff=lfs merge=lfs -text
15
+ *.onnx filter=lfs diff=lfs merge=lfs -text
16
+ *.ot filter=lfs diff=lfs merge=lfs -text
17
+ *.parquet filter=lfs diff=lfs merge=lfs -text
18
+ *.pb filter=lfs diff=lfs merge=lfs -text
19
+ *.pickle filter=lfs diff=lfs merge=lfs -text
20
+ *.pkl filter=lfs diff=lfs merge=lfs -text
21
+ *.pt filter=lfs diff=lfs merge=lfs -text
22
+ *.pth filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
25
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tflite filter=lfs diff=lfs merge=lfs -text
28
+ *.tgz filter=lfs diff=lfs merge=lfs -text
29
+ *.wasm filter=lfs diff=lfs merge=lfs -text
30
+ *.xz filter=lfs diff=lfs merge=lfs -text
31
+ *.zip filter=lfs diff=lfs merge=lfs -text
32
+ *.zst filter=lfs diff=lfs merge=lfs -text
33
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AI Story Writing
3
+ emoji: 📚
4
+ colorFrom: pink
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.8
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: Catmeow/AI_story_writing
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ title = "tory Generator"
4
+
5
+ # gpt-neo-2.7B gpt-j-6B
6
+
7
+ def generate(text,the_model,max_length,temperature,repetition_penalty):
8
+ generator = pipeline('text-generation', model=the_model)
9
+ result = generator(text, num_return_sequences=3,
10
+ max_length=max_length,
11
+ temperature=temperature,
12
+ repetition_penalty = repetition_penalty,
13
+ no_repeat_ngram_size=2,early_stopping=False)
14
+ return result[0]["generated_text"],result[1]["generated_text"],result[2]["generated_text"]
15
+
16
+
17
+ def complete_with_gpt(text,context,the_model,max_length,temperature,repetition_penalty):
18
+ # Use the last [context] characters of the text as context
19
+ max_length = max_length+context
20
+ return generate(text[-context:],the_model,max_length,temperature,repetition_penalty)
21
+
22
+ def send(text1,context,text2):
23
+ if len(text1)<context:
24
+ return text1 + text2[len(text1):]
25
+ else:
26
+ return text1 + text2[context:]
27
+
28
+ with gr.Blocks() as demo:
29
+ textbox = gr.Textbox(placeholder="Type here and press enter...", lines=4)
30
+ btn = gr.Button("Generate")
31
+ context = gr.Slider(value=200,label="Truncate input text length (AI's memory)",minimum=1,maximum=500)
32
+ the_model = gr.Dropdown(choices=['gpt2','gpt2-medium','gpt2-large','gpt2-xl','EleutherAI/gpt-neo-2.7B','EleutherAI/gpt-j-6B'],value = 'gpt2',label="Choose model")
33
+ max_length = gr.Slider(value=20,label="Max Generate Length",minimum=1,maximum=50)
34
+ temperature = gr.Slider(value=0.9,label="Temperature",minimum=0.0,maximum=1.0,step=0.05)
35
+ repetition_penalty = gr.Slider(value=1.5,label="Repetition penalty",minimum=0.2,maximum=2,step=0.1)
36
+ output1 = gr.Textbox(lines=4,label='1')
37
+ send1 = gr.Button("Send1 to Origin Textbox").click(send,inputs=[textbox,context,output1],outputs=textbox)
38
+ output2 = gr.Textbox(lines=4,label='2')
39
+ send2 = gr.Button("Send2 to Origin Textbox").click(send,inputs=[textbox,context,output2],outputs=textbox)
40
+ output3 = gr.Textbox(lines=4,label='3')
41
+ send3 = gr.Button("Send3 to Origin Textbox").click(send,inputs=[textbox,context,output3],outputs=textbox)
42
+ btn.click(complete_with_gpt,inputs=[textbox,context,the_model,max_length,temperature,repetition_penalty], outputs=[output1,output2,output3])
43
+
44
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ gradio
3
+ torch