# File size: 2,292 Bytes
# 8c97ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import torch
from transformers import logging
import random
from PIL import Image
from Utils import MingleModel

# Restrict the transformers library's log output to error-level messages.
logging.set_verbosity_error()


def get_concat_h(images):
    """Paste the given images side by side, left to right, on a new RGB canvas.

    The canvas is as wide as all image widths combined and as tall as the
    tallest image. Every image is pasted with its top edge at y = 0, so any
    area below a shorter image keeps the canvas's initial fill.

    Args:
        images: Re-iterable collection of objects exposing a ``size``
            attribute of ``(width, height)`` that ``Image.paste`` accepts.

    Returns:
        The newly created ``Image`` holding the horizontal strip.

    Raises:
        ValueError: If ``images`` is empty.
    """
    sizes = [picture.size for picture in images]
    canvas_width = sum(width for width, _ in sizes)
    canvas_height = max(height for _, height in sizes)

    canvas = Image.new('RGB', (canvas_width, canvas_height))

    # Walk the images again, advancing the paste position by each width.
    cursor = 0
    for picture, (width, _) in zip(images, sizes):
        canvas.paste(picture, (cursor, 0))
        cursor += width
    return canvas


# Instantiate the model wrapper once, at import time; mingle_prompts reuses
# this single module-level instance on every call.
mingle_model = MingleModel()


def mingle_prompts(first_prompt, second_prompt):
    """Generate an image from an equal-weight blend of two prompts.

    Both prompts are tokenized and encoded through the module-level
    ``mingle_model``; the two embeddings are averaged (weight 0.5 each) and
    handed to ``mingle_model.generate_with_embs``. The generated image(s) are
    then joined horizontally with ``get_concat_h``.

    Args:
        first_prompt: Text of the first prompt.
        second_prompt: Text of the second prompt.

    Returns:
        The image returned by ``get_concat_h`` for the generated results
        (a single image wide, since only one blend weight is used).
    """
    tokens_a = mingle_model.do_tokenizer(first_prompt)
    tokens_b = mingle_model.do_tokenizer(second_prompt)

    # Encoding is inference only, so gradient tracking is disabled.
    with torch.no_grad():
        embedding_a = mingle_model.get_text_encoder(tokens_a)
        embedding_b = mingle_model.get_text_encoder(tokens_b)

    # A fresh random integer per call, shared by every blend below.
    # Presumably consumed as the RNG seed -- confirm in generate_with_embs.
    random_draw = random.randint(1, 2048)

    # Fixed generation settings.
    num_steps = 20
    scale = 8.0
    # Weight applied to the first prompt; the second gets the complement.
    blend_weights = [0.5]

    renders = [
        mingle_model.generate_with_embs(
            embedding_a * weight + embedding_b * (1 - weight),
            random_draw,
            num_inference_steps=num_steps,
            guidance_scale=scale,
        )
        for weight in blend_weights
    ]

    return get_concat_h(renders)


# Gradio UI: three header/notice Markdown blocks, two prompt text boxes, a
# submit button, and an image output wired to mingle_prompts.
with gr.Blocks() as demo:
    # Page title.
    gr.Markdown(
        '''
        <h1 style="text-align: center;"> Fashion Generator GAN</h1>
        ''')

    # Notice to the user that CPU inference is slow.
    gr.Markdown(
        '''
        <h3 style="text-align: center;"> Note : the gan is extremely resource extensive, so it running the inference on cpu takes long time . kindly wait patiently while the model generates the output. </h3>
        ''')
    
    # Short description of what the app produces.
    gr.Markdown(
        '''
        <p style="text-align: center;">generated an image as an average of 2 prompts inserted !!</p>
        ''')

    # Inputs: the two prompts passed, in this order, to mingle_prompts.
    first_prompt = gr.Textbox(label="first_prompt")
    second_prompt = gr.Textbox(label="second_prompt")
    greet_btn = gr.Button("Submit")

    gr.Markdown("# Output Results")
    # NOTE(review): the `shape` keyword may not be accepted by newer Gradio
    # releases -- confirm against the Gradio version this app is pinned to.
    output = gr.Image(shape=(512,512))

    # Clicking the button calls mingle_prompts with both text boxes' values
    # and shows its return value in the image component.
    greet_btn.click(fn=mingle_prompts, inputs=[first_prompt, second_prompt], outputs=[output])

# Launches the app whenever this module is executed; there is no
# `if __name__ == "__main__"` guard, so importing the module also launches it.
demo.launch()