# StyleForge / app.py
# Sambhavnoobcoder's picture
# Rename app (1).py to app.py
# a423523
import gradio as gr
import torch
from transformers import logging
import random
from PIL import Image
from Utils import MingleModel
logging.set_verbosity_error()
def get_concat_h(images):
    """Concatenate images horizontally into a single RGB image.

    The images are pasted left to right, each aligned to the top edge. The
    canvas is as wide as the sum of the image widths and as tall as the
    tallest image; any area not covered by a shorter image is left at the
    default fill of ``Image.new`` (no colour is passed).

    Args:
        images: Iterable of PIL images (objects exposing ``.size`` as
            ``(width, height)`` that ``Image.paste`` accepts). The iterable
            is materialized once, so a generator is accepted too.

    Returns:
        A new image created with ``Image.new('RGB', ...)``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    # Materialize first: the sizes are read in one pass and the images are
    # pasted in a second one, which would exhaust a one-shot iterator.
    images = list(images)
    if not images:
        raise ValueError("get_concat_h() requires at least one image")

    widths, heights = zip(*(im.size for im in images))
    dst = Image.new('RGB', (sum(widths), max(heights)))

    x_offset = 0
    for im, width in zip(images, widths):
        dst.paste(im, (x_offset, 0))
        x_offset += width
    return dst
# Single module-level MingleModel instance, created when this module is
# loaded and reused by every call to mingle_prompts().
mingle_model = MingleModel()
def mingle_prompts(first_prompt, second_prompt):
    """Generate an image from an equal blend of two prompts' text embeddings.

    Each prompt is passed through ``mingle_model.do_tokenizer`` and then
    ``mingle_model.get_text_encoder`` (the encoder calls run under
    ``torch.no_grad()``). The two encoder outputs are combined as a weighted
    sum and handed to ``mingle_model.generate_with_embs``.

    Args:
        first_prompt: Prompt whose embedding is weighted by the mix factor.
        second_prompt: Prompt whose embedding is weighted by
            ``1 - mix factor``.

    Returns:
        The result of ``get_concat_h`` over the generated images; with the
        single configured mix factor that is one generated image.
    """
    tokens_a = mingle_model.do_tokenizer(first_prompt)
    tokens_b = mingle_model.do_tokenizer(second_prompt)

    # Only the text-encoder calls are wrapped in no_grad.
    with torch.no_grad():
        emb_a = mingle_model.get_text_encoder(tokens_a)
        emb_b = mingle_model.get_text_encoder(tokens_b)

    # One random integer in [1, 2048] is drawn per call and shared by every
    # mix factor.
    # NOTE(review): presumably used as the RNG seed inside
    # generate_with_embs -- confirm in Utils.MingleModel.
    seed = random.randint(1, 2048)

    # Generation settings, identical for every mix factor.
    num_steps = 20
    cfg_scale = 8.0

    # A single 50/50 blend. A sweep such as 0.1, 0.3, 0.5, 0.7, 0.9 would
    # yield several images placed side by side by get_concat_h.
    weights = (0.5,)

    def render(weight):
        # Weighted sum of the two prompt embeddings.
        blended = emb_a * weight + emb_b * (1 - weight)
        return mingle_model.generate_with_embs(
            blended,
            seed,
            num_inference_steps=num_steps,
            guidance_scale=cfg_scale,
        )

    return get_concat_h([render(weight) for weight in weights])
# Build the Gradio page. Components are declared top to bottom in the order
# they appear on the page, so the statement order below is the layout.
# The triple-quoted HTML snippets are kept at column 0 so their literal text
# is unchanged.
with gr.Blocks() as demo:
    # Page title.
    gr.Markdown(
'''
<h1 style="text-align: center;"> Fashion Generator GAN</h1>
''')
    # Notice about slow CPU inference (user-facing text, left verbatim).
    gr.Markdown(
'''
<h3 style="text-align: center;"> Note : the gan is extremely resource extensive, so it running the inference on cpu takes long time . kindly wait patiently while the model generates the output. </h3>
''')
    # One-line description of what the app produces.
    gr.Markdown(
'''
<p style="text-align: center;">generated an image as an average of 2 prompts inserted !!</p>
''')
    # The two text prompts that mingle_prompts() blends.
    first_prompt = gr.Textbox(label="first_prompt")
    second_prompt = gr.Textbox(label="second_prompt")
    greet_btn = gr.Button("Submit")
    gr.Markdown("# Output Results")
    # NOTE(review): the `shape` argument is believed to exist only in older
    # Gradio releases (removed in 4.x) -- confirm against the pinned version.
    output = gr.Image(shape=(512,512))
    # Clicking the button runs mingle_prompts on both textbox values and
    # shows the returned image in `output`.
    greet_btn.click(fn=mingle_prompts, inputs=[first_prompt, second_prompt], outputs=[output])
# Launched unconditionally: this also runs whenever the module is imported,
# since there is no `if __name__ == "__main__"` guard.
demo.launch()