File size: 3,511 Bytes
db02454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0f1d46
 
 
db02454
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import io, os, base64
from PIL import Image
import gradio as gr
import shortuuid
from transformers import pipeline


# Genre-conditioned GPT-2 checkpoint; the pipeline built from it is what
# get_story calls to write the story text.
text_generation_model = "pranavpsv/gpt2-genre-story-generator"
text_generation = pipeline(task="text-generation", model=text_generation_model)
# Remote Latent Diffusion Space wrapped as a callable; text2image_latent
# invokes it with the prompt and the slider settings.
latent = gr.Interface.load("spaces/multimodalart/latentdiffusion")


def get_story(user_input, genre="sci_fi", max_length=32):
    """Continue ``user_input`` as a short story using the genre-conditioned GPT-2 model.

    The model is prompted with a control prefix of the form ``<BOS> <genre> ``
    followed by the user's text. That prefix is removed from the result, so
    only the user's text plus its continuation is returned.

    Args:
        user_input: Opening text for the story.
        genre: Genre tag placed in the control prefix. A falsy value (e.g.
            ``None`` from a Dropdown with nothing selected) falls back to
            ``"sci_fi"``.
        max_length: Forwarded to the text-generation pipeline as
            ``max_length``. The default of 32 matches the previous hard-coded value.

    Returns:
        The generated text without the control prefix.
    """
    # A Gradio Dropdown with no selection passes None, which would otherwise
    # produce the meaningless control tag "<None>".
    genre = genre or "sci_fi"
    prompt = f"<BOS> <{genre}> "
    stories = text_generation(
        f"{prompt}{user_input}", max_length=max_length, num_return_sequences=1
    )
    story = stories[0]["generated_text"]
    # Strip the prefix only if it is really there; slicing blindly would cut
    # genuine characters if the decoded text ever differs from the prompt.
    if story.startswith(prompt):
        story = story[len(prompt):]
    return story
    

def text2image_latent(text, steps, width, height, images, diversity):
    """Render ``text`` with the remote Latent Diffusion Space and save the images.

    Args:
        text: Prompt for the image model (in this app, the generated story).
        steps, width, height, images, diversity: Forwarded positionally and
            unchanged to the ``latent`` interface.

    Returns:
        A list of file paths ``./tmp/<shortuuid>.png``, one per returned image,
        in the order the Space returned them.
    """
    print(text)  # echo the prompt to stdout / the Space log
    results = latent(text, steps, width, height, images, diversity)
    temp_dir = "./tmp"
    # exist_ok=True already tolerates an existing directory, so no separate
    # os.path.exists() check is needed; create it once, not once per image.
    os.makedirs(temp_dir, exist_ok=True)
    # NOTE(review): assumes results[1] is the Space's gallery output and that
    # the first element of each entry is a base64 (data-URI) image; confirm
    # against the Space's output spec.
    return [_save_base64_image(entry[0], temp_dir) for entry in results[1]]


def _save_base64_image(image_str, temp_dir):
    """Decode a base64 image string and save it as ``<temp_dir>/<uuid>.png``; return the path."""
    # Drop any "data:<mime>;base64," header (base64 never contains a comma),
    # so non-PNG data URIs decode too; a bare payload is kept whole.
    payload = image_str.rpartition(",")[2]
    decoded_bytes = base64.b64decode(payload)
    image_path = f"{temp_dir}/{shortuuid.uuid()}.png"
    # Re-encode through PIL so the file on disk is PNG whatever format was
    # sent; the context manager releases the image once it is written.
    with Image.open(io.BytesIO(decoded_bytes)) as img:
        img.save(image_path)
    return image_path
    
    
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: prompt + genre -> generated story.
        with gr.Column():
            user_input = gr.Textbox(
                value="A giant mecha robot in Rio de Janeiro, oil on canvas",
                placeholder="Type your prompt to generate an image",
                label="Prompt - try adding increments to your prompt such as 'a painting of', 'in the style of Picasso'",
            )
            # An explicit default keeps get_story from receiving None (which
            # would become the control tag "<None>") when nothing is selected.
            genre_input = gr.Dropdown(
                ["superhero", "action", "drama", "horror", "thriller", "sci_fi"],
                value="sci_fi",
                label="Genre",
            )
            # Filled by get_story; its text is also the prompt sent to the image model.
            generated_story = gr.Textbox(label="Story")
            with gr.Row():
                button_generate_story = gr.Button("Generate Story")
        # Middle column: Latent Diffusion settings, forwarded by text2image_latent.
        with gr.Column():
            steps = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                label="Steps - more steps can increase quality but will take longer to generate",
            )
            width = gr.Slider(minimum=32, maximum=256, value=256, step=32, label="Width")
            height = gr.Slider(minimum=32, maximum=256, value=256, step=32, label="Height")
            images = gr.Slider(
                minimum=1,
                maximum=4,
                value=4,
                step=1,
                label="Images - How many images you wish to generate",
            )
            diversity = gr.Slider(
                minimum=1.0,
                maximum=15.0,
                value=15.0,
                label="Diversity scale - How different from one another you wish the images to be",
            )
        # Right column: output gallery.
        with gr.Column():
            gallery = gr.Gallery(label="Individual images")
            with gr.Row():
                get_image_latent = gr.Button("Generate Image", css={"margin-top": "1em"})
    with gr.Row():
        gr.Markdown("<a href='https://huggingface.co/spaces/merve/GPT-2-story-gen' target='_blank'>Story generation with GPT-2</a>, and text to image by <a href='https://huggingface.co/spaces/multimodalart/latentdiffusion' target='_blank'>Latent Diffusion</a>.")
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=gradio-blocks_latent_gpt2_story)")

    # Story first, then images generated from the (possibly edited) story text.
    button_generate_story.click(get_story, inputs=[user_input, genre_input], outputs=generated_story)
    get_image_latent.click(
        text2image_latent,
        inputs=[generated_story, steps, width, height, images, diversity],
        outputs=gallery,
    )


demo.launch(enable_queue=False)