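"""Gradio demo that generates logo candidates with Stable Diffusion v1.5.

Three optional LoRA adapters (detailed, black & white, cartoon) restyle the
base model; each request renders four 512x512 images from a shared seed.
"""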
import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline
import torch
from PIL import Image
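# Expected dependencies (see imports above): gradio, numpy, torch, diffusers,
# Pillow, and (on CUDA machines) xformers for memory-efficient attention.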
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Stable Diffusion v1.5; use fp16 weights and xformers attention on GPU.
if torch.cuda.is_available():
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    pipe.enable_xformers_memory_efficient_attention()
else:
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
pipe = pipe.to(device)
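# Upper bound for the seed slider; MAX_IMAGE_SIZE is currently unused
# (generation below is fixed at 512x512).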
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def infer(prompt, negative_prompt, seed, randomize_seed, image_style, cfg, lora_scale):
    # Reset to the base model, then attach the LoRA matching the chosen style
    # and append its trigger words to the prompt.
    pipe.unload_lora_weights()
    if image_style == "style_detailed":
        pipe.load_lora_weights("ohkarim/LoRA_logos", weight_name="OH_logos.safetensors", adapter_name="OH_logos")
        prompt += ", detailed, close up, unique background color, OH_logos"
    elif image_style == "style_blacknwhite":
        pipe.load_lora_weights("ohkarim/lora_logo_blacknwhite", weight_name="bel_blacknwhite_lora.safetensors", adapter_name="bel_blacknwhite")
        prompt += ", black and white, minimalist, unique background color, bel_blacknwhite"
    elif image_style == "style_cartoon":
        pipe.load_lora_weights("ohkarim/lora_logo_modern", weight_name="lora_modern.safetensors", adapter_name="oh_bel_modern")
        prompt += ", simple, modern, unique background color, oh_bel_modern"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Generate four candidates; the shared generator advances between calls,
    # so each image differs while the batch stays reproducible from the seed.
    images = []
    for _ in range(4):
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            cross_attention_kwargs={"scale": lora_scale},
            guidance_scale=cfg,
            num_inference_steps=20,
            width=512,
            height=512,
            generator=generator,
        ).images[0]
        images.append(image)
    return images
css="""
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
if torch.cuda.is_available():
power_device = "GPU"
else:
power_device = "CPU"
# Function to read CSS from file
def read_css_from_file(filename):
with open(filename, "r") as file:
return file.read()
# Read CSS from file
css = read_css_from_file("style.css")
PTI_SD_DESCRIPTION = '''
<div id="content_align">
    <span style="color:darkred;font-size:32px;font-weight:bold">
        Create your own logo now!
    </span>
</div>
<div id="content_align">
    <span style="color:blue;font-size:16px;font-weight:bold">
        There are three styles so far: detailed, black & white, and cartoon.
    </span>
</div>
<div id="content_align" style="margin-top: 10px;"></div>
'''
# Create the Gradio interface.
with gr.Blocks(css=css) as demo:
    gr.Markdown(PTI_SD_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            text_prompt = gr.Textbox(label="Input Prompt", placeholder="Example: logo of a coffee shop, cup of coffee, mug, brown shades", lines=2)
            negative_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="Example: blurry, unfocused, complicated", lines=2)
            image_style = gr.Dropdown(label="Choose a style", choices=["No style", "style_detailed", "style_blacknwhite", "style_cartoon"], value="No style")
        with gr.Column():
            cfg = gr.Slider(label="Prompt guidance", minimum=1, maximum=20, step=1, value=4)
            lora_scale = gr.Slider(label="Style guidance", minimum=0, maximum=2, step=0.01, value=0.9)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
    generate_button = gr.Button("Generate Image", variant="primary")
    with gr.Row():
        output_image1 = gr.Image(type="pil", label="Image 1", width=256, height=400)
        output_image2 = gr.Image(type="pil", label="Image 2", width=256, height=400)
        output_image3 = gr.Image(type="pil", label="Image 3", width=256, height=400)
        output_image4 = gr.Image(type="pil", label="Image 4", width=256, height=400)
    generate_button.click(
        fn=infer,
        inputs=[text_prompt, negative_prompt, seed, randomize_seed, image_style, cfg, lora_scale],
        outputs=[output_image1, output_image2, output_image3, output_image4],
        show_progress=True,
    )

# queue() serializes requests so concurrent users share the single pipeline.
demo.queue().launch()