import random
import torch
import gc
import gradio as gr
from PIL import Image
from diffusers import (DPMSolverMultistepScheduler, StableDiffusionPipeline,
                       StableDiffusionXLPipeline, StableDiffusionUpscalePipeline,
                       DiffusionPipeline)
from utils import *
from style import custom_css, beta_header_html
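# Names such as DIFFUSION_CHECKPOINTS, APP_THEME, display_gpu_info and
# nearest_divisible_by_8 are expected to come from the wildcard `utils` import above.
# Judging from its usage below, each DIFFUSION_CHECKPOINTS entry presumably holds a
# checkpoint path, a pipeline class name, and a type flag ("pretrained" loads via
# from_pretrained, anything else via from_single_file).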

def gen_image(prompt, negative_prompt, width, height, num_steps,
              mode, seed, guidance_scale, device):
    """
    Run diffusion model to generate image
    """
    use_adapter = True
    device = f"cuda:{device.split('GPU')[1][1]}"
    guidance_scale = float(guidance_scale)
    generator = torch.Generator(device).manual_seed(int(seed))
    model_path = DIFFUSION_CHECKPOINTS[mode]["path"]
    # Resolve the pipeline class by name from the checkpoint registry, then load it
    # either as a pretrained model directory/repo or as a single checkpoint file.
    Text2Image_class = globals()[DIFFUSION_CHECKPOINTS[mode]["pipeline"]]
    if DIFFUSION_CHECKPOINTS[mode]["type"] == "pretrained":
        pipeline = Text2Image_class.from_pretrained(model_path)
    else:
        pipeline = Text2Image_class.from_single_file(model_path)

    # Optionally attach a LoRA adapter on top of the SDXL base model.
    if use_adapter and "SDXL 1.0" in mode:
        print("Load LoRA model")
        pipeline.load_lora_weights("../checkpoints", weight_name="mod2.safetensors")

    # Swap the default scheduler for DPM-Solver multistep, keeping the existing config.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    try:
        pipeline = pipeline.to(device)
        # Width/height are snapped to multiples of 8, as the diffusion models require.
        image = pipeline(prompt=prompt,
                         negative_prompt=negative_prompt,
                         width=nearest_divisible_by_8(int(width)),
                         height=nearest_divisible_by_8(int(height)),
                         num_inference_steps=int(num_steps),
                         generator=generator,
                         guidance_scale=guidance_scale).images[0]
    except Exception as e:
        # Show a placeholder image if generation fails (e.g. CUDA out of memory).
        print(e)
        image = Image.open("stuffs/serverdown.jpg")
    finally:
        # Release GPU memory so repeated requests do not accumulate allocations.
        del pipeline
        torch.cuda.empty_cache()
        gc.collect()
    return image
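
# Gradio UI: prompt and sampler controls in the left column, generated image on the right.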

with gr.Blocks(title="(Beta) TonAI Creative", theme=APP_THEME) as interface1:
    gr.HTML(beta_header_html)
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt",
                                placeholder="Tell me what you want to generate",
                                container=True)
            negative_prompt = gr.Textbox(label="Negative Prompt",
                                         placeholder="Describe what the image should not include",
                                         container=True)
            with gr.Row():
                width = gr.Textbox(label="Image Width", value=768)
                height = gr.Textbox(label="Image Height", value=768)
            with gr.Row():
                seed = gr.Textbox(label="RNG Seed", value=0, scale=1)
                guidance_scale = gr.Textbox(label="CFG Scale", value=7, scale=1)
            with gr.Row():
                num_steps = gr.components.Slider(
                                minimum=5, maximum=60, value=20, step=1,
                                label="Inference Steps"
                            )
                mode = gr.Dropdown(choices=list(DIFFUSION_CHECKPOINTS.keys()), label="Mode",
                                   value=list(DIFFUSION_CHECKPOINTS.keys())[1])
            device_choices = display_gpu_info()
            device = gr.Dropdown(choices=device_choices, label="Device", value=device_choices[0])
            generate_btn = gr.Button("Generate")
        with gr.Column(scale=2):
            # Clicking "Generate" runs gen_image and renders the result in the output image.
            generate_btn.click(
                fn=gen_image,
                inputs=[prompt, negative_prompt, width, height, num_steps, mode, seed, guidance_scale, device],
                outputs=gr.Image(label="Generated Image", format="png"),
                concurrency_limit=5
            )
    # On page load, randomize the seed and refresh the list of available GPUs.
    interface1.load(lambda: gr.update(value=random.randint(0, 99999)), None, seed)
    interface1.load(lambda: gr.update(choices=display_gpu_info(), value=display_gpu_info()[0]), None, device)

# interface = gr.TabbedInterface([interface1, iface2], ["Text-to-text", "image-to-text"])
# Static files that Gradio is allowed to serve.
allowed_paths = ["stuffs/tonai_research_logo.png"]
interface1.queue(default_concurrency_limit=5)
interface1.launch(share=False,
                  allowed_paths=allowed_paths,
                  max_threads=5)