Elle McFarlane committed
Commit 19ec9f5 · 1 Parent(s): 48c898d

update to multi-model selection

Files changed (3)
  1. app.py +59 -127
  2. model.py +211 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,146 +1,78 @@
 import gradio as gr
 import numpy as np
-import random
-from diffusers import DiffusionPipeline
-import torch

-device = "cuda" if torch.cuda.is_available() else "cpu"

-if torch.cuda.is_available():
-    torch.cuda.max_memory_allocated(device=device)
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-    pipe.enable_xformers_memory_efficient_attention()
-    pipe = pipe.to(device)
-else:
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
-    pipe = pipe.to(device)

-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024

-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):

-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        guidance_scale = guidance_scale,
-        num_inference_steps = num_inference_steps,
-        width = width,
-        height = height,
-        generator = generator
-    ).images[0]
-
-    return image

-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]

-css="""
-#col-container {
-    margin: 0 auto;
-    max-width: 520px;
-}
-"""

-if torch.cuda.is_available():
-    power_device = "GPU"
-else:
-    power_device = "CPU"

-with gr.Blocks(css=css) as demo:
-
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""
-        # Text-to-Image Gradio Template
-        Currently running on {power_device}.
-        """)
-
-        with gr.Row():
-
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0)
-
-        result = gr.Image(label="Result", show_label=False)

-        with gr.Accordion("Advanced Settings", open=False):
-
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
             with gr.Row():
-
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
-                )
-
             with gr.Row():
-
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=12,
-                    step=1,
-                    value=2,
                 )
-
-        gr.Examples(
-            examples = examples,
-            inputs = [prompt]
-        )

     run_button.click(
-        fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result]
     )

-demo.queue().launch()
+#!/usr/bin/env python
+
+from __future__ import annotations
+
 import gradio as gr
 import numpy as np

+from model import Model

+DESCRIPTION = "# [AvantGAN](https://github.com/ellemcfarlane/AvantGAN)"


+def get_sample_image_url(name: str) -> str:
+    sample_image_dir = "https://huggingface.co/spaces/ellemac/avantGAN/resolve/main/samples"
+    return f"{sample_image_dir}/{name}.png"


+def get_sample_image_markdown(name: str) -> str:
+    url = get_sample_image_url(name)
+    size = 128 if ("stylegan3" in name or "original" in name) else 64
+    return f"""
+    - size: {size}x{size}
+    ![sample images]({url})"""


+model = Model()

+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)

+    with gr.Tabs():
+        with gr.TabItem("App"):
             with gr.Row():
+                with gr.Column():
+                    model_name = gr.Dropdown(
+                        label="Model", choices=list(model.MODEL_DICT.keys()), value="stylegan3-abstract"
+                    )
+                    seed = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.uint32).max, step=1, value=0)
+                    run_button = gr.Button()
+                with gr.Column():
+                    result = gr.Image(label="Result", elem_id="result", width=300, height=300)
+                    print("RESULT", result, type(result), result.__dict__)
+
+        with gr.TabItem("Sample Images"):
             with gr.Row():
+                model_name2 = gr.Dropdown(
+                    [
+                        "stylegan3-abstract",
+                        "stylegan3-high-fidelity",
+                        "ada-dcgan",
+                        "original-training-data",
+                    ],
+                    value="stylegan3-abstract",
+                    label="Model",
                 )
+            with gr.Row():
+                text = get_sample_image_markdown(model_name2.value)
+                sample_images = gr.Markdown(text)

     run_button.click(
+        fn=model.set_model_and_generate_image,
+        inputs=[
+            model_name,
+            seed,
+        ],
+        outputs=result,
+        api_name="run",
+    )
+    model_name2.change(
+        fn=get_sample_image_markdown,
+        inputs=model_name2,
+        outputs=sample_images,
+        queue=False,
+        api_name=False,
     )

+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
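
Note that run_button.click wires only the dropdown value and the seed into model.set_model_and_generate_image, so the function's remaining parameters (truncation_psi, tx, ty, angle) always take their defaults. A minimal sketch of what the click event effectively runs, assuming model.py below is importable and the ellemac checkpoint repos are reachable:

    from model import Model

    model = Model()
    # same call the Run button makes with the default UI values
    image = model.set_model_and_generate_image("stylegan3-abstract", 0)
    # stylegan3 models yield an HxWx3 uint8 numpy array; the ada-dcgan
    # branch returns a PIL.Image instead, despite the ndarray annotation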
model.py ADDED
@@ -0,0 +1,211 @@
+# from https://huggingface.co/spaces/hysts/StyleGAN3/blob/main/model.py
+
+import pathlib
+import pickle
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+
+import torch
+import torchvision.utils as vutils
+import matplotlib.pyplot as plt
+from io import BytesIO
+from PIL import Image
+
+current_dir = pathlib.Path(__file__).parent
+submodule_dir = current_dir / "stylegan3"
+sys.path.insert(0, submodule_dir.as_posix())
+
+user = "ellemac"
+dcgan_z_dim = 100
+dcgan_gen_feats = 64
+ngf = 64
+dcgan_img_size = 64
+nc = 3
+
+
+class Generator(nn.Module):
+    def __init__(self, ngpu, nz):
+        super(Generator, self).__init__()
+        self.ngpu = ngpu
+        self.main = nn.Sequential(
+            # input is Z, going into a convolution
+            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
+            nn.BatchNorm2d(ngf * 8),
+            nn.LeakyReLU(0.2, inplace=True),
+            # state size. (ngf*8) x 4 x 4
+            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf * 4),
+            nn.LeakyReLU(0.2, inplace=True),
+            # state size. (ngf*4) x 8 x 8
+            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf * 2),
+            nn.LeakyReLU(0.2, inplace=True),
+            # state size. (ngf*2) x 16 x 16
+            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf),
+            nn.LeakyReLU(0.2, inplace=True),
+            # state size. (ngf) x 32 x 32
+            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+            nn.Tanh()
+            # state size. (nc) x 64 x 64
+        )
+
+    def forward(self, input):
+        return self.main(input)
+
+# class Generator(nn.Module):
+#     def __init__(self, n_gen_feats, n_gpu, z_dim, n_channels):
+#         super(Generator, self).__init__()
+#         self.n_gpu = n_gpu
+#         self.main = nn.Sequential(
+#             # input is Z, going into a convolution
+#             nn.ConvTranspose2d(z_dim, n_gen_feats * 8, 4, 1, 0, bias=False),
+#             nn.BatchNorm2d(n_gen_feats * 8),
+#             nn.LeakyReLU(0.2, inplace=True),
+#             # state size. (n_gen_feats*8) x 4 x 4
+#             nn.ConvTranspose2d(n_gen_feats * 8, n_gen_feats * 4, 4, 2, 1, bias=False),
+#             nn.BatchNorm2d(n_gen_feats * 4),
+#             nn.LeakyReLU(0.2, inplace=True),
+#             # state size. (n_gen_feats*4) x 8 x 8
+#             nn.ConvTranspose2d(n_gen_feats * 4, n_gen_feats * 2, 4, 2, 1, bias=False),
+#             nn.BatchNorm2d(n_gen_feats * 2),
+#             nn.LeakyReLU(0.2, inplace=True),
+#             # state size. (n_gen_feats*2) x 16 x 16
+#             nn.ConvTranspose2d(n_gen_feats * 2, n_gen_feats, 4, 2, 1, bias=False),
+#             nn.BatchNorm2d(n_gen_feats),
+#             nn.LeakyReLU(0.2, inplace=True),
+#             # state size. (n_gen_feats) x 32 x 32
+#             nn.ConvTranspose2d(n_gen_feats, n_channels, 4, 2, 1, bias=False),
+#             nn.Tanh()
+#             # state size. (n_channels) x 64 x 64
+#         )
+
+#     def forward(self, input):
+#         return self.main(input)
+
+class Model:
+    MODEL_DICT = {
+        "stylegan3-abstract": {"name": "abstract-560eps.pkl", "repo": "avantStyleGAN3"},
+        "stylegan3-high-fidelity": {"name": "high-fidelity-1120eps.pkl", "repo": "avantStyleGAN3"},
+        "ada-dcgan": {"name": "gen_6kepoch.pt", "repo": "avantGAN"},
+    }
+
+    def __init__(self):
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self._download_all_models()
+        self.model_name = "ada-dcgan"  # or "stylegan3-abstract"
+        self.model = self._load_model(self.model_name)
+
+    def _load_model(self, model_name: str) -> nn.Module:
+        file_name = self.MODEL_DICT[model_name]["name"]
+        repo = self.MODEL_DICT[model_name]["repo"]
+        path = hf_hub_download(f"{user}/{repo}", file_name)  # model repo-type
+        if "stylegan" in model_name:
+            with open(path, "rb") as f:
+                model = pickle.load(f)["G_ema"]
+        else:
+            # todo (elle): don't hardcode the config
+            # model = Generator(dcgan_gen_feats, 1, dcgan_z_dim, 3)
+            print("WAS HERE")
+            model = Generator(0, 100)
+
+        model.eval()
+        model.to(self.device)
+        return model
+
+    def set_model(self, model_name: str) -> None:
+        if model_name == self.model_name:
+            return
+        self.model_name = model_name
+        self.model = self._load_model(model_name)
+
+    def _download_all_models(self):
+        for name in self.MODEL_DICT.keys():
+            self._load_model(name)
+
+    @staticmethod
+    def make_transform(translate: tuple[float, float] = (0, 0), angle: float = 0) -> np.ndarray:
+        mat = np.eye(3)
+        sin = np.sin(angle / 360 * np.pi * 2)
+        cos = np.cos(angle / 360 * np.pi * 2)
+        mat[0][0] = cos
+        mat[0][1] = sin
+        mat[0][2] = translate[0]
+        mat[1][0] = -sin
+        mat[1][1] = cos
+        mat[1][2] = translate[1]
+        return mat
+
+    def generate_z(self, seed: int) -> torch.Tensor:
+        seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
+        z = np.random.RandomState(seed).randn(1, self.model.z_dim)
+        return torch.from_numpy(z).float().to(self.device)
+
+    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
+        tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+        return tensor.cpu().numpy()
+
+    def dcgan_postprocess(self, tensor: torch.Tensor) -> np.ndarray:
+        tensor = (tensor.permute(0, 2, 3, 1)).clamp(0, 255).to(torch.uint8)
+        return tensor.cpu().numpy()
+
+    def set_transform(self, tx: float = 0, ty: float = 0, angle: float = 0) -> None:
+        mat = self.make_transform((tx, ty), angle)
+        mat = np.linalg.inv(mat)
+        self.model.synthesis.input.transform.copy_(torch.from_numpy(mat))
+
+    @torch.inference_mode()
+    def generate(self, z: torch.Tensor, label: torch.Tensor, truncation_psi: float) -> torch.Tensor:
+        return self.model(z, label, truncation_psi=truncation_psi)
+
+    def generate_image(self, seed: int, truncation_psi: float = 0, tx: float = 0, ty: float = 0, angle: float = 0) -> np.ndarray:
+        self.set_transform(tx, ty, angle)
+
+        z = self.generate_z(seed)
+        label = torch.zeros([1, self.model.c_dim], device=self.device)
+
+        out = self.generate(z, label, truncation_psi)
+        out = self.postprocess(out)
+        return out[0]
+
+    def dcgan_generate_image(self, seed: int) -> np.ndarray:
+        dcgan_img_size = 64
+        dcgan_z_dim = 100
+
+        with torch.no_grad():
+            n_images = 1
+            z = torch.randn(n_images, dcgan_z_dim, 1, 1, device=self.device)
+            fake_images = self.model(z.to(self.device)).cpu()
+            fake_images = fake_images.view(fake_images.size(0), 3, dcgan_img_size, dcgan_img_size)
+
+        print('fake', fake_images)
+        print(fake_images.min(), fake_images.max())
+        # Create a grid of images
+        grid = vutils.make_grid(fake_images, normalize=True)
+        print('grid', grid)
+        # Plot the grid and save it to a buffer
+        fig, ax = plt.subplots()
+        ax.imshow(grid.permute(1, 2, 0))  # Convert from CHW to HWC for imshow
+        plt.axis('off')
+
+        # Save the plot to a buffer
+        buf = BytesIO()
+        plt.savefig(buf, format='png')
+        buf.seek(0)
+
+        # Load the buffer into a PIL Image
+        img = Image.open(buf)
+        return img
+
+    def set_model_and_generate_image(
+        self, model_name: str, seed: int, truncation_psi: float = 0, tx: float = 0, ty: float = 0, angle: float = 0
+    ) -> np.ndarray:
+        self.set_model(model_name)
+        if "stylegan3" in model_name:
+            return self.generate_image(seed, truncation_psi, tx, ty, angle)
+        else:
+            return self.dcgan_generate_image(seed)
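
generate_z makes StyleGAN3 sampling reproducible: the seed is clipped into uint32 range and fed to a dedicated np.random.RandomState, so a given (model, seed) pair always produces the same latent and hence the same image. A self-contained sketch of that behavior, with 512 standing in for the checkpoint-specific z_dim:

    import numpy as np
    import torch

    # mirrors Model.generate_z: clip the seed, then draw deterministically
    seed = int(np.clip(123, 0, np.iinfo(np.uint32).max))
    z1 = torch.from_numpy(np.random.RandomState(seed).randn(1, 512)).float()
    z2 = torch.from_numpy(np.random.RandomState(seed).randn(1, 512)).float()
    assert torch.equal(z1, z2)  # same seed, same latent, same image

Note that dcgan_generate_image never touches this path: it draws z with torch.randn and ignores its seed argument, so the DCGAN output is not reproducible from the seed slider.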
requirements.txt CHANGED
@@ -3,4 +3,4 @@ diffusers
 invisible_watermark
 torch
 transformers
-xformers
 invisible_watermark
 torch
 transformers
+xformers