mlnotes commited on
Commit
db02454
1 Parent(s): 09cf298

ADD app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io, os, base64
2
+ from PIL import Image
3
+ import gradio as gr
4
+ import shortuuid
5
+ from transformers import pipeline
6
+
7
+
8
+ text_generation_model = "pranavpsv/gpt2-genre-story-generator"
9
+ text_generation = pipeline("text-generation", text_generation_model)
10
+ latent = gr.Interface.load("spaces/multimodalart/latentdiffusion")
11
+
12
+
13
+ def get_story(user_input, genre="sci_fi"):
14
+ prompt = f"<BOS> <{genre}> "
15
+ stories = text_generation(f"{prompt}{user_input}", max_length=32, num_return_sequences=1)
16
+ story = stories[0]["generated_text"]
17
+ story_without_prompt = story[len(prompt):]
18
+ return story_without_prompt
19
+
20
+
21
+ def text2image_latent(text, steps, width, height, images, diversity):
22
+ print(text)
23
+ results = latent(text, steps, width, height, images, diversity)
24
+ image_paths = []
25
+ for image in results[1]:
26
+ image_str = image[0]
27
+ image_str = image_str.replace("data:image/png;base64,","")
28
+ decoded_bytes = base64.decodebytes(bytes(image_str, "utf-8"))
29
+ img = Image.open(io.BytesIO(decoded_bytes))
30
+ url = shortuuid.uuid()
31
+ temp_dir = './tmp'
32
+ if not os.path.exists(temp_dir):
33
+ os.makedirs(temp_dir, exist_ok=True)
34
+ image_path = f'{temp_dir}/{url}.png'
35
+ img.save(f'{temp_dir}/{url}.png')
36
+ image_paths.append(image_path)
37
+ return(image_paths)
38
+
39
+
40
+ with gr.Blocks() as demo:
41
+ with gr.Row():
42
+ with gr.Column():
43
+ user_input = gr.inputs.Textbox(placeholder="Type your prompt to generate an image", label="Prompt - try adding increments to your prompt such as 'a painting of', 'in the style of Picasso'", default="A giant mecha robot in Rio de Janeiro, oil on canvas")
44
+ genre_input = gr.Dropdown(["superhero","action","drama","horror","thriller","sci_fi",])
45
+ generated_story = gr.Textbox()
46
+ with gr.Row():
47
+ button_generate_story = gr.Button("Generate Story")
48
+ with gr.Column():
49
+ steps = gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=50,maximum=50,minimum=1,step=1)
50
+ width = gr.inputs.Slider(label="Width", default=256, step=32, maximum=256, minimum=32)
51
+ height = gr.inputs.Slider(label="Height", default=256, step=32, maximum = 256, minimum=32)
52
+ images = gr.inputs.Slider(label="Images - How many images you wish to generate", default=4, step=1, minimum=1, maximum=4)
53
+ diversity = gr.inputs.Slider(label="Diversity scale - How different from one another you wish the images to be",default=15.0, minimum=1.0, maximum=15.0)
54
+ with gr.Column():
55
+ gallery = gr.Gallery(label="Individual images")
56
+ with gr.Row():
57
+ get_image_latent = gr.Button("Generate Image", css={"margin-top": "1em"})
58
+ with gr.Row():
59
+ gr.Markdown("<a href='https://huggingface.co/spaces/merve/GPT-2-story-gen' target='_blank'>Story generation with GPT-2</a>, and text to image by <a href='https://huggingface.co/spaces/multimodalart/latentdiffusion' target='_blank'>Latent Diffusion</a>.")
60
+
61
+
62
+ button_generate_story.click(get_story, inputs=[user_input, genre_input], outputs=generated_story)
63
+ get_image_latent.click(text2image_latent, inputs=[generated_story,steps,width,height,images,diversity], outputs=gallery)
64
+
65
+
66
+ demo.launch(enable_queue=False)