Johannes Stelzer committed on
Commit
5427bbe
1 Parent(s): 940cc9a
Files changed (1) hide show
  1. latentblending/gradio_ui.py +0 -153
latentblending/gradio_ui.py DELETED
@@ -1,153 +0,0 @@
1
- import os
2
- import torch
3
- torch.backends.cudnn.benchmark = False
4
- torch.set_grad_enabled(False)
5
- import numpy as np
6
- import warnings
7
- warnings.filterwarnings('ignore')
8
- from tqdm.auto import tqdm
9
- from PIL import Image
10
- import gradio as gr
11
- import shutil
12
- import uuid
13
- from diffusers import AutoPipelineForText2Image
14
- from latentblending.blending_engine import BlendingEngine
15
- import datetime
16
-
17
- warnings.filterwarnings('ignore')
18
- torch.set_grad_enabled(False)
19
- torch.backends.cudnn.benchmark = False
20
- import json
21
-
22
-
23
-
24
class BlendingFrontend():
    r"""
    Gradio helper class that collects UI state and records "takes" of
    latent-blending keyframes into a timestamped JSON movie file.
    """

    def __init__(
            self,
            be,
            share=False):
        r"""
        Gradio Helper Class to collect UI data and start latent blending.
        Args:
            be:
                BlendingEngine instance used to generate candidate images.
            share: bool
                Set true to get a shareable gradio link (e.g. for running a remote server)
        """
        self.be = be
        self.share = share

        # UI Defaults
        self.seed1 = 420
        self.seed2 = 420
        self.prompt1 = ""
        self.prompt2 = ""

        # Runtime state: prompt/negative_prompt stay None until the user
        # computes a batch of candidate images via compute_imgs().
        # (The original assigned negative_prompt twice — "" then None —
        # keeping only the effective final value here.)
        self.prompt = None
        self.negative_prompt = None
        self.list_seeds = []    # seeds of the last computed image batch
        self.list_images = []   # images of the last computed batch
        self.idx_movie = 0      # index of the next movie keyframe
        self.data = []          # accumulated JSON records for the movie file

    # Gradio buttons cannot pass arguments to their callbacks, hence one
    # thin wrapper per image slot.
    def take_image0(self):
        return self.take_image(0)

    def take_image1(self):
        return self.take_image(1)

    def take_image2(self):
        return self.take_image(2)

    def take_image3(self):
        return self.take_image(3)

    def take_image(self, id_img):
        r"""
        Record the image in slot *id_img* as the next movie keyframe.

        Appends a record (iteration, seed, prompt, negative prompt) to
        self.data, persists the whole list to self.fp_out as JSON, and
        resets self.prompt so the next take requires a fresh compute.

        Args:
            id_img: int
                Index (0-3) into self.list_seeds of the chosen candidate.

        Returns:
            List of five gradio output values [img0, img1, img2, img3,
            prompt], all cleared.
        """
        if self.prompt is None:
            print("Cannot take because no prompt was set!")
            return [None, None, None, None, ""]
        if self.idx_movie == 0:
            # First take: create the output file name and write the header
            # record describing the generation settings.
            current_time = datetime.datetime.now()
            self.fp_out = "movie_" + current_time.strftime("%y%m%d_%H%M") + ".json"
            # BUGFIX: was `bf.be.dh.width_img`, which relied on the
            # module-level global `bf`; use this instance's own engine.
            self.data.append({
                "settings": "sdxl",
                "width": self.be.dh.width_img,
                "height": self.be.dh.height_img,
                "num_inference_steps": self.be.dh.num_inference_steps})

        seed = self.list_seeds[id_img]

        self.data.append({
            "iteration": self.idx_movie,
            "seed": seed,
            "prompt": self.prompt,
            "negative_prompt": self.negative_prompt})

        # Persist after every take so a crash loses at most the current step.
        with open(self.fp_out, 'w') as f:
            json.dump(self.data, f, indent=4)

        self.idx_movie += 1
        self.prompt = None
        return [None, None, None, None, ""]

    def compute_imgs(self, prompt, negative_prompt):
        r"""
        Generate four candidate images for *prompt* with random seeds.

        Stores the prompt pair on the instance, seeds the engine with four
        random seeds, and returns the four generated images (also cached
        in self.list_images; the seeds in self.list_seeds).
        """
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.be.set_prompt1(prompt)
        self.be.set_prompt2(prompt)
        self.be.set_negative_prompt(negative_prompt)
        self.list_seeds = []
        self.list_images = []
        for i in range(4):
            seed = np.random.randint(0, 1000000000)
            self.be.seed1 = seed
            self.list_seeds.append(seed)
            img = self.be.compute_latents1(return_image=True)
            self.list_images.append(img)
        return self.list_images
104
-
105
-
106
-
107
-
108
- if __name__ == "__main__":
109
-
110
- width = 786
111
- height = 1024
112
- num_inference_steps = 4
113
-
114
- pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
115
- # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16")
116
- pipe.to("cuda")
117
-
118
- be = BlendingEngine(pipe)
119
- be.set_dimensions((width, height))
120
- be.set_num_inference_steps(num_inference_steps)
121
-
122
- bf = BlendingFrontend(be)
123
-
124
- with gr.Blocks() as demo:
125
-
126
- with gr.Row():
127
- prompt = gr.Textbox(label="prompt")
128
- negative_prompt = gr.Textbox(label="negative prompt")
129
-
130
- with gr.Row():
131
- b_compute = gr.Button('compute new images', variant='primary')
132
-
133
- with gr.Row():
134
- with gr.Column():
135
- img0 = gr.Image(label="seed1")
136
- b_take0 = gr.Button('take', variant='primary')
137
- with gr.Column():
138
- img1 = gr.Image(label="seed2")
139
- b_take1 = gr.Button('take', variant='primary')
140
- with gr.Column():
141
- img2 = gr.Image(label="seed3")
142
- b_take2 = gr.Button('take', variant='primary')
143
- with gr.Column():
144
- img3 = gr.Image(label="seed4")
145
- b_take3 = gr.Button('take', variant='primary')
146
-
147
- b_compute.click(bf.compute_imgs, inputs=[prompt, negative_prompt], outputs=[img0, img1, img2, img3])
148
- b_take0.click(bf.take_image0, outputs=[img0, img1, img2, img3, prompt])
149
- b_take1.click(bf.take_image1, outputs=[img0, img1, img2, img3, prompt])
150
- b_take2.click(bf.take_image2, outputs=[img0, img1, img2, img3, prompt])
151
- b_take3.click(bf.take_image3, outputs=[img0, img1, img2, img3, prompt])
152
-
153
- demo.launch(share=bf.share, inbrowser=True, inline=False, server_name="10.40.49.100")