Spaces: Build error

modify phrase

- app.py +19 -4
- stable_diffusion.py +103 -30
app.py CHANGED
@@ -1,14 +1,14 @@
-from stable_diffusion import
+from stable_diffusion import Generate2img, Args
 import gradio as gr
 args = Args("", 5, None, 7.5, 512, 512, 443, "cpu", "./mdjrny-v4.pt")
-model =
+model = Generate2img.instance(args)
 def text2img_output(phrase):
     return model(phrase)

 readme = open("me.md","rb+").read().decode("utf-8")

 phrase = gr.components.Textbox(
-    value="
+    value="anthropomorphic cat portrait art")
 text2img_out = gr.components.Image(type="numpy")

 instance = gr.Blocks()
@@ -22,4 +22,19 @@ with instance:
     gr.Markdown(readme)


-instance.queue(concurrency_count=20).launch(share=False)
+instance.queue(concurrency_count=20).launch(share=False)
+#
+#
+# 1) anthropomorphic cat portrait art
+#
+# ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered.png)
+#
+# 2) anthropomorphic cat portrait art(mdjrny-v4.pt)
+#
+# ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered2.png)
+#
+# 3) Kung Fu Panda(weight: wd-1-3-penultimate-ucg-cont.pt, steps:50)
+#
+# ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered3.png)
+# ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered4.png)
+#
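Note (commentary, not part of the diff): the unchanged middle of app.py (lines 15-21, not shown in these hunks) is where the Textbox, the Image output and the callback get wired together inside the Blocks; only `with instance:` and `gr.Markdown(readme)` are confirmed by the hunk context above. A minimal sketch of such wiring, with the Button and layout being assumptions:

    with instance:
        phrase.render()                 # Textbox defined above
        run = gr.Button("generate")     # hypothetical trigger button
        text2img_out.render()           # numpy Image output defined above
        run.click(fn=text2img_output, inputs=phrase, outputs=text2img_out)
        gr.Markdown(readme)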
stable_diffusion.py CHANGED
@@ -12,13 +12,13 @@ from collections import namedtuple
 import numpy as np
 from tqdm import tqdm

-
-from torch.nn import Conv2d, Linear, Module,SiLU, UpsamplingNearest2d,ModuleList
+# ,
+from torch.nn import Conv2d, Linear, Module, SiLU, UpsamplingNearest2d,ModuleList,ZeroPad2d
 from torch import Tensor
 from torch.nn import functional as F
 from torch.nn.parameter import Parameter

-device = "
+device = "mps"

 def apply_seq(seqs, x):
     for seq in seqs:
@@ -31,6 +31,12 @@ def gelu(self):
 def quick_gelu(x):
     return x * torch.sigmoid(x * 1.702)

+# class SiLU(Module):
+#     def __init__(self):
+#         super(SiLU, self).__init__()
+#         self.gelu = quick_gelu
+#     def forward(self,x ):
+#         return self.gelu(x)
 class Normalize(Module):
     def __init__(self, in_channels, num_groups=32, name="normalize"):
         super(Normalize, self).__init__()
@@ -166,13 +172,13 @@ class Encoder(Module):
         self.down = ModuleList([
             ResnetBlock(128, 128, name=name + "_down_block_0_0_ResnetBlock"),
             ResnetBlock(128, 128, name=name + "_down_block_0_1_ResnetBlock"),
-            Conv2d(128, 128, 3, stride=2, padding=(0,
+            Conv2d(128, 128, 3, stride=2, padding=(0, 0)),
             ResnetBlock(128, 256, name=name + "_down_block_1_0_ResnetBlock"),
             ResnetBlock(256, 256, name=name + "_down_block_1_1_ResnetBlock"),
-            Conv2d(256, 256, 3, stride=2, padding=(0,
+            Conv2d(256, 256, 3, stride=2, padding=(0, 0)),
             ResnetBlock(256, 512, name=name + "_down_block_2_0_ResnetBlock"),
             ResnetBlock(512, 512, name=name + "_down_block_2_1_ResnetBlock"),
-            Conv2d(512, 512, 3, stride=2, padding=(0,
+            Conv2d(512, 512, 3, stride=2, padding=(0, 0)),
             ResnetBlock(512, 512, name=name + "_down_block_3_0_ResnetBlock"),
             ResnetBlock(512, 512, name=name + "_down_block_3_1_ResnetBlock"),
         ])
@@ -181,12 +187,17 @@ class Encoder(Module):
         self.norm_out = Normalize(512, name=name+"_norm_out_Normalize")
         self.conv_out = Conv2d(512, 8, 3, padding=1)
         self.name = name
+        self.zero_pad2d_0_1 = ZeroPad2d((0,1,0,1))

     def forward(self, x):
         x = self.conv_in(x)

         for l in self.down:
-            x = l(x)
+            # x = l(x)
+            if isinstance(l, Conv2d):
+                x = l(self.zero_pad2d_0_1(x))
+            else:
+                x = l(x)
         x = self.mid(x)
         return self.conv_out(F.silu(self.norm_out(x)))

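Note (commentary, not part of the commit): the change above moves the Encoder's downsampling from the Conv2d padding argument to an explicit ZeroPad2d((0, 1, 0, 1)), i.e. one extra column on the right and one extra row at the bottom, applied in forward() before every stride-2 Conv2d. With a 3x3 kernel and stride 2 this asymmetric padding halves the spatial size exactly, matching the asymmetric downsampling pad used by the original Stable Diffusion autoencoder. A standalone shape check, assuming nothing beyond torch:

    import torch
    from torch.nn import Conv2d, ZeroPad2d

    x = torch.randn(1, 128, 64, 64)                   # a 64x64 feature map
    pad = ZeroPad2d((0, 1, 0, 1))                     # (left, right, top, bottom)
    down = Conv2d(128, 128, 3, stride=2, padding=0)   # same shape as the layers in self.down
    print(down(pad(x)).shape)                         # torch.Size([1, 128, 32, 32])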
@@ -637,7 +648,8 @@ class CLIPTextTransformer(Module):
         self.encoder = CLIPEncoder(name=name+"_CLIPEncoder_0")
         self.final_layer_norm = Normalize(768, num_groups=None, name=name+"_CLIPTextTransformer_normalizer_0")
         # the upper triangle is all -inf values
-
+        triu = np.triu(np.ones((1, 1, 77, 77), dtype=np.float32) * -np.inf, k=1)
+        self.causal_attention_mask = Tensor(triu).to(device)
         self.name = name

     def forward(self, input_ids):
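Note (commentary, not part of the commit): the tensor built above is the usual CLIP causal attention mask for the 77-token context, shape (1, 1, 77, 77): -inf strictly above the diagonal and 0 elsewhere, so that once it is added to the attention logits each token can only attend to itself and to earlier tokens. A 4-token version shows the pattern:

    import numpy as np

    mask = np.triu(np.ones((4, 4), dtype=np.float32) * -np.inf, k=1)
    # [[  0. -inf -inf -inf]
    #  [  0.   0. -inf -inf]
    #  [  0.   0.   0. -inf]
    #  [  0.   0.   0.   0.]]    rows = query position, columns = key position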
@@ -804,7 +816,7 @@ class StableDiffusion(Module):


 class Args(object):
-    def __init__(self, phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file):
+    def __init__(self, phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file, input_image:str="", input_mask:str="", input_image_strength=0.5, unphrase=""):
         self.phrase = phrase
         self.steps = steps
         self.model_type = model_type
@@ -814,22 +826,41 @@ class Args(object):
         self.seed = seed
         self.device = device
         self.model_file = model_file
+        self.input_image = input_image
+        self.input_mask = input_mask
+        self.input_image_strength = input_image_strength
+        self.unphrase = unphrase

+from PIL import Image

-class
+class Generate2img(Module):
     _instance_lock = threading.Lock()
     def __init__(self, args: Args):
-        super(
+        super(Generate2img, self).__init__()
         self.is_load_model=False
         self.args = args
         self.model = StableDiffusion().instance()
+        self.get_input_image_tensor()
+        # self.get_input_mask_tensor()
+
+
+    def get_input_image_tensor(self):
+        if self.args.input_image!="":
+            input_image = Image.open(args.input_image).convert("RGB").resize((self.args.img_width, self.args.img_height), resample=Image.Resampling.LANCZOS)
+            self.input_image_array = torch.from_numpy(np.array(input_image)).to(device)
+            self.input_image_tensor = torch.from_numpy((np.array(input_image, dtype=np.float32)[None, ..., :3]/ 255.0*2.0-1))
+            self.input_image_tensor = self.input_image_tensor.permute(0, 3, 1, 2)  # bs, channel, height, width
+        else:
+            self.input_image_tensor = None
+        return self.input_image_tensor
+

     @classmethod
     def instance(cls, *args, **kwargs):
-        with
-        if not hasattr(
-
-        return
+        with Generate2img._instance_lock:
+            if not hasattr(Generate2img, "_instance"):
+                Generate2img._instance = Generate2img(*args, **kwargs)
+        return Generate2img._instance

     def load_model(self):
         if self.args.model_file != "" and self.is_load_model==False:
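Note (commentary, not part of the commit): get_input_image_tensor() converts the optional img2img input into what the VAE encoder expects: keep three RGB channels, add a batch dimension, rescale uint8 pixels from [0, 255] to [-1, 1], and permute NHWC to NCHW. A standalone restatement of that normalization (the blank 512x512 image is just a placeholder):

    import numpy as np
    import torch
    from PIL import Image

    img = Image.new("RGB", (512, 512))                    # placeholder input image
    arr = np.array(img, dtype=np.float32)[None, ..., :3]  # (1, H, W, 3)
    x = torch.from_numpy(arr / 255.0 * 2.0 - 1.0)         # [0, 255] -> [-1, 1]
    x = x.permute(0, 3, 1, 2)                             # (1, 3, H, W) for the encoder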
@@ -841,6 +872,7 @@ class Text2img(Module):
     def get_token_encode(self, phrase):
         tokenizer = ClipTokenizer().instance()
         phrase = tokenizer.encode(phrase)
+        # phrase = phrase + [49407] * (77 - len(phrase))
         with torch.no_grad():
             context = self.model.text_decoder(phrase)
         return context.to(self.args.device)
@@ -848,7 +880,7 @@ class Text2img(Module):
         self.set_seeds(True)
         self.load_model()
         context = self.get_token_encode(phrase)
-        unconditional_context = self.get_token_encode(
+        unconditional_context = self.get_token_encode(self.args.unphrase)

         timesteps = list(np.arange(1, 1000, 1000 // self.args.steps))
         print(f"running for {timesteps} timesteps")
@@ -857,9 +889,26 @@ class Text2img(Module):

         latent_width = int(self.args.img_width) // 8
         latent_height = int(self.args.img_height) // 8
+
+
+        input_image_latent = None
+        input_img_noise_t = None
+        if self.input_image_tensor!=None:
+            noise_index = int(len(timesteps) * self.args.input_image_strength)
+            if noise_index >= len(timesteps):
+                noise_index = noise_index - 1
+            input_img_noise_t = timesteps[noise_index]
+            with torch.no_grad():
+                filter = lambda x:x[:,:4,:,:] * 0.18215
+                input_image_latent = self.model.first_stage_model.encoder(self.input_image_tensor.to(device))
+                input_image_latent = self.model.first_stage_model.quant_conv(input_image_latent)
+                input_image_latent = filter(input_image_latent)  # only the means
+
         # start with random noise
-        latent =
+        latent = self.get_noise_latent( 1, latent_height, latent_width, input_image_latent, input_img_noise_t, None)
+
         latent = latent.to(self.args.device)
+
         with torch.no_grad():
             # this is diffusion
             for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
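Note (commentary, not part of the commit): for img2img the input picture goes through the VAE encoder and quant_conv, which produce 8 channels per latent pixel; the `filter` lambda keeps only the first 4 (the means, as the inline comment says) and multiplies by 0.18215, the latent scaling factor used by Stable Diffusion. input_image_strength then picks how far into the timestep schedule sampling starts. A restated sketch of that start-step logic using the same schedule as above:

    steps = 50
    timesteps = list(range(1, 1000, 1000 // steps))    # 1, 21, ..., 981
    strength = 0.5                                     # Args.input_image_strength
    noise_index = min(int(len(timesteps) * strength), len(timesteps) - 1)
    start_t = timesteps[noise_index]                   # noise level img2img starts from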
@@ -867,11 +916,14 @@ class Text2img(Module):
                 e_t = self.get_model_latent_output(latent.clone(), timestep, self.model.unet, context.clone(),
                                                    unconditional_context.clone())
                 x_prev, pred_x0 = self.get_x_prev_and_pred_x0(latent, e_t, index, alphas, alphas_prev)
+                latent = x_prev
                 # e_t_next = get_model_output(x_prev)
                 # e_t_prime = (e_t + e_t_next) / 2
                 # x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
-
-
+        decode = self.latent_decode(latent, latent_height, latent_width)
+
+        return decode
+

     def get_x_prev_and_pred_x0(self, x, e_t, index, alphas, alphas_prev):
         temperature = 1
@@ -900,6 +952,27 @@ class Text2img(Module):
         del unconditional_latent, latent, timesteps, context
         return e_t

+
+    def add_noise(self, x , t , noise=None ):
+        # batch_size, channel, h, w = x.shape
+        if noise is None:
+            noise = torch.normal(0,1, size=(x.shape))
+        # sqrt_alpha_prod = _ALPHAS_CUMPROD[t] ** 0.5
+        sqrt_alpha_prod = self.model.sqrt_alphas_cumprod[t]
+        sqrt_one_minus_alpha_prod = self.model.sqrt_one_minus_alphas_cumprod[t]
+        # sqrt_one_minus_alpha_prod = (1 - _ALPHAS_CUMPROD[t]) ** 0.5
+
+        return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise.to(device)
+
+    def get_noise_latent(self, batch_size, latent_height, latent_width, input_image_latent=None, input_img_noise_t=None, noise=None):
+
+        if input_image_latent is None:
+            latent = torch.normal(0,1, size=(batch_size, 4, latent_height, latent_width))
+            # latent = torch.randn((batch_size, 4, latent_height, latent_width))
+        else:
+            latent = self.add_noise(input_image_latent, input_img_noise_t, noise)
+        return latent.to(device)
+
     def latent_decode(self, latent, latent_height, latent_width):
         # upsample latent space to image with autoencoder
         # x = model.first_stage_model.post_quant_conv( 8* latent)
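Note (commentary, not part of the commit): add_noise() is the closed-form DDPM forward process: given a clean latent x0 and a timestep t it returns sqrt(alphas_cumprod[t]) * x0 + sqrt(1 - alphas_cumprod[t]) * eps with eps ~ N(0, 1); get_noise_latent() uses it so that img2img starts from a noised version of the encoded input instead of pure Gaussian noise. A standalone sketch of the same formula (the 0.7 coefficient is made up for illustration):

    import torch

    x0 = torch.randn(1, 4, 64, 64)                   # clean latent
    sqrt_ac = 0.7                                    # stands in for sqrt_alphas_cumprod[t]
    sqrt_1m_ac = (1 - sqrt_ac ** 2) ** 0.5           # the two coefficients square-sum to 1
    eps = torch.randn_like(x0)
    x_t = sqrt_ac * x0 + sqrt_1m_ac * eps            # same form as add_noise above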
@@ -915,8 +988,7 @@ class Text2img(Module):
         return decode_latent
     def decode_latent2img(self, decode_latent):
         # save image
-
-        img = Image.fromarray(decode_latent)
+        img = Image.fromarray(decode_latent, mode="RGB")
         return img

     def set_seeds(self, cuda):
@@ -925,11 +997,11 @@ class Text2img(Module):
         if cuda:
             torch.cuda.manual_seed_all(self.args.seed)
 @lru_cache()
-def
+def generate2img(phrase, steps, model_file, guidance_scale, img_width, img_height, seed, device, input_image, input_mask, input_image_strength=0.5, unphrase=""):
     try:
-        args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
-        im =
-        im =
+        args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file, input_image, input_mask, input_image_strength, unphrase)
+        im = Generate2img.instance(args).forward(args.phrase)
+        im = Generate2img.instance(args).decode_latent2img(im)
     finally:
         pass
     return im
@@ -943,19 +1015,20 @@ if __name__ == "__main__":

     parser = argparse.ArgumentParser(description='Run Stable Diffusion',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--steps', type=int, default=
+    parser.add_argument('--steps', type=int, default=50, help="Number of steps in diffusion")
     parser.add_argument('--phrase', type=str, default="anthropomorphic cat portrait art ", help="Phrase to render")
+    parser.add_argument('--unphrase', type=str, default="", help="unconditional Phrase to render")
     parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
     parser.add_argument('--scale', type=float, default=7.5, help="unconditional guidance scale")
-    parser.add_argument('--model_file', type=str, default="/
+    parser.add_argument('--model_file', type=str, default="../min-stable-diffusion-pt/mdjrny-v4.pt", help="model weight file")
     parser.add_argument('--img_width', type=int, default=512, help="output image width")
     parser.add_argument('--img_height', type=int, default=512, help="output image height")
     parser.add_argument('--seed', type=int, default=443, help="random seed")
-    parser.add_argument('--device_type', type=str, default="cpu", help="
+    parser.add_argument('--device_type', type=str, default="cpu", help="device type, support: cpu;cuda;mps")
+    parser.add_argument('--input_image', type=str, default="", help="input image file")
     args = parser.parse_args()
-
     device = args.device_type

-    im =
+    im = generate2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type, args.input_image, "", 1, args.unphrase)
     print(f"saving {args.out}")
     im.save(args.out)
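Usage note (commentary, not part of the commit): with the flags added above the script supports plain text-to-image and, via --input_image, image-to-image from the command line (in __main__ the strength passed to generate2img is hard-coded to 1). The weight path and input file below are placeholders:

    # text-to-image
    python stable_diffusion.py --phrase "anthropomorphic cat portrait art" --steps 50 \
        --model_file ../min-stable-diffusion-pt/mdjrny-v4.pt --device_type cpu --out /tmp/rendered.png

    # image-to-image, starting from an existing picture
    python stable_diffusion.py --phrase "anthropomorphic cat portrait art" --steps 50 \
        --input_image ./input.png --out /tmp/rendered_img2img.png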