hansyan's picture
Update app.py
49c4077 verified
raw
history blame contribute delete
No virus
3.92 kB
import spaces
import random
import gradio as gr
import numpy as np
import torch
from PIL import Image
def setup_seed(seed):
    """Seed Python, NumPy and torch RNGs so generation is reproducible.

    Args:
        seed: Integer seed. Python's ``random`` and torch receive it unchanged.
            NumPy's legacy ``np.random.seed`` only accepts ``0 <= seed < 2**32``,
            so it receives ``seed % 2**32``. That value is identical for
            in-range seeds, and it avoids a ``ValueError`` for negative or
            large user-entered seeds.
    """
    random.seed(seed)
    np.random.seed(seed % 2**32)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic kernels and disable its autotuner, which may
    # select different algorithms between runs and break reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Pick the first CUDA device when one is visible, otherwise the CPU.
# NOTE(review): `device` is not used below; the pipeline is moved to
# "cuda:0" unconditionally. Confirm that is intended (e.g. for the
# `spaces.GPU` runtime) before relying on the CPU fallback.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

### PeRFlow-T2I
from diffusers import StableDiffusionXLPipeline
from src.scheduler_perflow import PeRFlowScheduler

# Load the fp16 safetensors weights ("v0-fix" variant) of the PeRFlow
# SDXL DreamShaper checkpoint.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "hansyan/perflow-sdxl-dreamshaper",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="v0-fix",
)

# Replace the checkpoint's scheduler with PeRFlowScheduler, built from the
# existing scheduler config plus PeRFlow-specific overrides.
pipe.scheduler = PeRFlowScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="ddim_eps",
    num_time_windows=4,
)

pipe.to("cuda:0", torch.float16)
### gradio
def _tensor_to_pil(samples):
    """Convert a single-image float tensor batch into an RGB PIL image.

    Args:
        samples: Tensor of shape (1, C, H, W) with values expected in [0, 1]
            (the pipeline is called with ``output_type='pt'``).

    Returns:
        A ``PIL.Image.Image`` built from the first three channels.
    """
    # Upcast before leaving torch: the pipeline runs in float16.
    array = samples.squeeze(0).permute(1, 2, 0).float().cpu().numpy()
    # Clip and round to nearest instead of truncating with astype(). Plain
    # truncation biases every pixel down by up to one level; rounding
    # matches diffusers' own numpy_to_pil conversion.
    array = np.clip(array * 255.0, 0, 255).round().astype(np.uint8)
    return Image.fromarray(array[:, :, :3])


@spaces.GPU
def generate(text, num_inference_steps, cfg_scale, seed):
    """Generate one 1024x1024 image for ``text`` with the PeRFlow SDXL pipeline.

    Args:
        text: User prompt; a fixed photorealism prefix is prepended to it.
        num_inference_steps: Number of sampling steps (cast to ``int``).
        cfg_scale: Classifier-free guidance scale (cast to ``float``).
        seed: RNG seed (cast to ``int``) passed to ``setup_seed``.

    Returns:
        A ``PIL.Image.Image`` with the generated RGB image.
    """
    setup_seed(int(seed))
    num_inference_steps = int(num_inference_steps)
    cfg_scale = float(cfg_scale)

    prompt_prefix = "photorealistic, uhd, high resolution, high quality, highly detailed; "
    neg_prompt = "distorted, blur, low-quality, haze, out of focus"
    text = prompt_prefix + text

    samples = pipe(
        prompt=[text],
        negative_prompt=[neg_prompt],
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=cfg_scale,
        output_type='pt',
    ).images
    return _tensor_to_pil(samples)
# layout
# Stylesheet handed to gr.Blocks(css=css): centers h1/h2/h3 headings and
# caps the Gradio container at a maximum width of 768px.
css = """
h1 {
text-align: center;
display:block;
}
h2 {
text-align: center;
display:block;
}
h3 {
text-align: center;
display:block;
}
.gradio-container {
max-width: 768px !important;
}
"""
# Static text shown in the UI (kept verbatim; rendered as Markdown).
_HEADER_MD = """
# PeRFlow-SDXL
GitHub: [https://github.com/magic-research/piecewise-rectified-flow](https://github.com/magic-research/piecewise-rectified-flow) <br/>
Models: [https://huggingface.co/hansyan/perflow-sdxl-dreamshaper](https://huggingface.co/hansyan/perflow-sdxl-dreamshaper)
<br/>
"""
_EXAMPLES_MD = """
Here are some examples provided:
- “masterpiece, A closeup face photo of girl, wearing a rain coat, in the street, heavy rain, bokeh”
- “RAW photo, a handsome man, wearing a black coat, outside, closeup face”
- “RAW photo, a red luxury car, studio light”
- “masterpiece, A beautiful cat bask in the sun”
"""
_DEFAULT_PROMPT = "masterpiece, A closeup face photo of girl, wearing a rain coat, in the street, heavy rain, bokeh"

# Single-column app: prompt box, a row of sampling controls with the run
# button, then the generated image and a list of example prompts.
with gr.Blocks(title="PeRFlow-SDXL", css=css) as interface:
    gr.Markdown(_HEADER_MD)
    with gr.Column():
        text = gr.Textbox(label="Input Prompt", value=_DEFAULT_PROMPT)
        with gr.Row():
            num_inference_steps = gr.Dropdown(
                label='Num Inference Steps',
                choices=[4, 5, 6, 7, 8],
                value=6,
                interactive=True,
            )
            cfg_scale = gr.Dropdown(
                label='CFG scale',
                choices=[1.5, 2.0, 2.5],
                value=2.0,
                interactive=True,
            )
            seed = gr.Textbox(label="Random Seed", value=42)
            submit = gr.Button(scale=1, variant='primary')
        output_image = gr.Image(label='Generated Image')
        gr.Markdown(_EXAMPLES_MD)

    # Enter in the prompt box, Enter in the seed box, or a button click all
    # trigger the same generation with the same inputs and output.
    for _trigger in (text.submit, seed.submit, submit.click):
        _trigger(
            fn=generate,
            inputs=[text, num_inference_steps, cfg_scale, seed],
            outputs=[output_image],
        )
if __name__ == '__main__':
    # Enable request queuing (max_size=10), then start the server;
    # queue() returns the same Blocks instance, so the calls chain.
    interface.queue(max_size=10).launch()