xfh commited on
Commit
2f38bfd
1 Parent(s): 46dc4a6

modify phrase

Browse files
Files changed (2) hide show
  1. app.py +19 -4
  2. stable_diffusion.py +103 -30
app.py CHANGED
@@ -1,14 +1,14 @@
1
- from stable_diffusion import Text2img, Args
2
  import gradio as gr
3
  args = Args("", 5, None, 7.5, 512, 512, 443, "cpu", "./mdjrny-v4.pt")
4
- model = Text2img.instance(args)
5
  def text2img_output(phrase):
6
  return model(phrase)
7
 
8
  readme = open("me.md","rb+").read().decode("utf-8")
9
 
10
  phrase = gr.components.Textbox(
11
- value="a very beautiful young anime tennis girl, full body, long wavy blond hair, sky blue eyes, full round face, short smile, bikini, miniskirt, highly detailed, cinematic wallpaper by stanley artgerm lau ")
12
  text2img_out = gr.components.Image(type="numpy")
13
 
14
  instance = gr.Blocks()
@@ -22,4 +22,19 @@ with instance:
22
  gr.Markdown(readme)
23
 
24
 
25
- instance.queue(concurrency_count=20).launch(share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stable_diffusion import Generate2img, Args
2
  import gradio as gr
3
  args = Args("", 5, None, 7.5, 512, 512, 443, "cpu", "./mdjrny-v4.pt")
4
+ model = Generate2img.instance(args)
5
  def text2img_output(phrase):
6
  return model(phrase)
7
 
8
  readme = open("me.md","rb+").read().decode("utf-8")
9
 
10
  phrase = gr.components.Textbox(
11
+ value="anthropomorphic cat portrait art")
12
  text2img_out = gr.components.Image(type="numpy")
13
 
14
  instance = gr.Blocks()
 
22
  gr.Markdown(readme)
23
 
24
 
25
+ instance.queue(concurrency_count=20).launch(share=False)
26
+ #
27
+ #
28
+ # 1) anthropomorphic cat portrait art
29
+ #
30
+ # ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered.png)
31
+ #
32
+ # 2) anthropomorphic cat portrait art(mdjrny-v4.pt)
33
+ #
34
+ # ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered2.png)
35
+ #
36
+ # 3) Kung Fu Panda(weight: wd-1-3-penultimate-ucg-cont.pt, steps:50)
37
+ #
38
+ # ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered3.png)
39
+ # ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered4.png)
40
+ #
stable_diffusion.py CHANGED
@@ -12,13 +12,13 @@ from collections import namedtuple
12
  import numpy as np
13
  from tqdm import tqdm
14
 
15
-
16
- from torch.nn import Conv2d, Linear, Module,SiLU, UpsamplingNearest2d,ModuleList
17
  from torch import Tensor
18
  from torch.nn import functional as F
19
  from torch.nn.parameter import Parameter
20
 
21
- device = "cpu"
22
 
23
  def apply_seq(seqs, x):
24
  for seq in seqs:
@@ -31,6 +31,12 @@ def gelu(self):
31
  def quick_gelu(x):
32
  return x * torch.sigmoid(x * 1.702)
33
 
 
 
 
 
 
 
34
  class Normalize(Module):
35
  def __init__(self, in_channels, num_groups=32, name="normalize"):
36
  super(Normalize, self).__init__()
@@ -166,13 +172,13 @@ class Encoder(Module):
166
  self.down = ModuleList([
167
  ResnetBlock(128, 128, name=name + "_down_block_0_0_ResnetBlock"),
168
  ResnetBlock(128, 128, name=name + "_down_block_0_1_ResnetBlock"),
169
- Conv2d(128, 128, 3, stride=2, padding=(0, 1, 0, 1)),
170
  ResnetBlock(128, 256, name=name + "_down_block_1_0_ResnetBlock"),
171
  ResnetBlock(256, 256, name=name + "_down_block_1_1_ResnetBlock"),
172
- Conv2d(256, 256, 3, stride=2, padding=(0, 1, 0, 1)),
173
  ResnetBlock(256, 512, name=name + "_down_block_2_0_ResnetBlock"),
174
  ResnetBlock(512, 512, name=name + "_down_block_2_1_ResnetBlock"),
175
- Conv2d(512, 512, 3, stride=2, padding=(0, 1, 0, 1)),
176
  ResnetBlock(512, 512, name=name + "_down_block_3_0_ResnetBlock"),
177
  ResnetBlock(512, 512, name=name + "_down_block_3_1_ResnetBlock"),
178
  ])
@@ -181,12 +187,17 @@ class Encoder(Module):
181
  self.norm_out = Normalize(512, name=name+"_norm_out_Normalize")
182
  self.conv_out = Conv2d(512, 8, 3, padding=1)
183
  self.name = name
 
184
 
185
  def forward(self, x):
186
  x = self.conv_in(x)
187
 
188
  for l in self.down:
189
- x = l(x)
 
 
 
 
190
  x = self.mid(x)
191
  return self.conv_out(F.silu(self.norm_out(x)))
192
 
@@ -637,7 +648,8 @@ class CLIPTextTransformer(Module):
637
  self.encoder = CLIPEncoder(name=name+"_CLIPEncoder_0")
638
  self.final_layer_norm = Normalize(768, num_groups=None, name=name+"_CLIPTextTransformer_normalizer_0")
639
  # 上三角都是 -inf 值
640
- self.causal_attention_mask = Tensor(np.triu(np.ones((1, 1, 77, 77), dtype=np.float32) * -np.inf, k=1)).to(device)
 
641
  self.name = name
642
 
643
  def forward(self, input_ids):
@@ -804,7 +816,7 @@ class StableDiffusion(Module):
804
 
805
 
806
  class Args(object):
807
- def __init__(self, phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file):
808
  self.phrase = phrase
809
  self.steps = steps
810
  self.model_type = model_type
@@ -814,22 +826,41 @@ class Args(object):
814
  self.seed = seed
815
  self.device = device
816
  self.model_file = model_file
 
 
 
 
817
 
 
818
 
819
- class Text2img(Module):
820
  _instance_lock = threading.Lock()
821
  def __init__(self, args: Args):
822
- super(Text2img, self).__init__()
823
  self.is_load_model=False
824
  self.args = args
825
  self.model = StableDiffusion().instance()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
826
 
827
  @classmethod
828
  def instance(cls, *args, **kwargs):
829
- with Text2img._instance_lock:
830
- if not hasattr(Text2img, "_instance"):
831
- Text2img._instance = Text2img(*args, **kwargs)
832
- return Text2img._instance
833
 
834
  def load_model(self):
835
  if self.args.model_file != "" and self.is_load_model==False:
@@ -841,6 +872,7 @@ class Text2img(Module):
841
  def get_token_encode(self, phrase):
842
  tokenizer = ClipTokenizer().instance()
843
  phrase = tokenizer.encode(phrase)
 
844
  with torch.no_grad():
845
  context = self.model.text_decoder(phrase)
846
  return context.to(self.args.device)
@@ -848,7 +880,7 @@ class Text2img(Module):
848
  self.set_seeds(True)
849
  self.load_model()
850
  context = self.get_token_encode(phrase)
851
- unconditional_context = self.get_token_encode("")
852
 
853
  timesteps = list(np.arange(1, 1000, 1000 // self.args.steps))
854
  print(f"running for {timesteps} timesteps")
@@ -857,9 +889,26 @@ class Text2img(Module):
857
 
858
  latent_width = int(self.args.img_width) // 8
859
  latent_height = int(self.args.img_height) // 8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
860
  # start with random noise
861
- latent = torch.randn(1, 4, latent_height, latent_width)
 
862
  latent = latent.to(self.args.device)
 
863
  with torch.no_grad():
864
  # this is diffusion
865
  for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
@@ -867,11 +916,14 @@ class Text2img(Module):
867
  e_t = self.get_model_latent_output(latent.clone(), timestep, self.model.unet, context.clone(),
868
  unconditional_context.clone())
869
  x_prev, pred_x0 = self.get_x_prev_and_pred_x0(latent, e_t, index, alphas, alphas_prev)
 
870
  # e_t_next = get_model_output(x_prev)
871
  # e_t_prime = (e_t + e_t_next) / 2
872
  # x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
873
- latent = x_prev
874
- return self.latent_decode(latent, latent_height, latent_width)
 
 
875
 
876
  def get_x_prev_and_pred_x0(self, x, e_t, index, alphas, alphas_prev):
877
  temperature = 1
@@ -900,6 +952,27 @@ class Text2img(Module):
900
  del unconditional_latent, latent, timesteps, context
901
  return e_t
902
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903
  def latent_decode(self, latent, latent_height, latent_width):
904
  # upsample latent space to image with autoencoder
905
  # x = model.first_stage_model.post_quant_conv( 8* latent)
@@ -915,8 +988,7 @@ class Text2img(Module):
915
  return decode_latent
916
  def decode_latent2img(self, decode_latent):
917
  # save image
918
- from PIL import Image
919
- img = Image.fromarray(decode_latent)
920
  return img
921
 
922
  def set_seeds(self, cuda):
@@ -925,11 +997,11 @@ class Text2img(Module):
925
  if cuda:
926
  torch.cuda.manual_seed_all(self.args.seed)
927
  @lru_cache()
928
- def text2img(phrase, steps, model_file, guidance_scale, img_width, img_height, seed, device):
929
  try:
930
- args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
931
- im = Text2img.instance(args).forward(args.phrase)
932
- im = Text2img.instance(args).decode_latent2img(im)
933
  finally:
934
  pass
935
  return im
@@ -943,19 +1015,20 @@ if __name__ == "__main__":
943
 
944
  parser = argparse.ArgumentParser(description='Run Stable Diffusion',
945
  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
946
- parser.add_argument('--steps', type=int, default=25, help="Number of steps in diffusion")
947
  parser.add_argument('--phrase', type=str, default="anthropomorphic cat portrait art ", help="Phrase to render")
 
948
  parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
949
  parser.add_argument('--scale', type=float, default=7.5, help="unconditional guidance scale")
950
- parser.add_argument('--model_file', type=str, default="/tmp/mdjrny-v4.pt", help="model weight file")
951
  parser.add_argument('--img_width', type=int, default=512, help="output image width")
952
  parser.add_argument('--img_height', type=int, default=512, help="output image height")
953
  parser.add_argument('--seed', type=int, default=443, help="random seed")
954
- parser.add_argument('--device_type', type=str, default="cpu", help="random seed")
 
955
  args = parser.parse_args()
956
-
957
  device = args.device_type
958
 
959
- im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
960
  print(f"saving {args.out}")
961
  im.save(args.out)
 
12
  import numpy as np
13
  from tqdm import tqdm
14
 
15
+ # ,
16
+ from torch.nn import Conv2d, Linear, Module, SiLU, UpsamplingNearest2d,ModuleList,ZeroPad2d
17
  from torch import Tensor
18
  from torch.nn import functional as F
19
  from torch.nn.parameter import Parameter
20
 
21
+ device = "mps"
22
 
23
  def apply_seq(seqs, x):
24
  for seq in seqs:
 
31
  def quick_gelu(x):
32
  return x * torch.sigmoid(x * 1.702)
33
 
34
+ # class SiLU(Module):
35
+ # def __init__(self):
36
+ # super(SiLU, self).__init__()
37
+ # self.gelu = quick_gelu
38
+ # def forward(self,x ):
39
+ # return self.gelu(x)
40
  class Normalize(Module):
41
  def __init__(self, in_channels, num_groups=32, name="normalize"):
42
  super(Normalize, self).__init__()
 
172
  self.down = ModuleList([
173
  ResnetBlock(128, 128, name=name + "_down_block_0_0_ResnetBlock"),
174
  ResnetBlock(128, 128, name=name + "_down_block_0_1_ResnetBlock"),
175
+ Conv2d(128, 128, 3, stride=2, padding=(0, 0)),
176
  ResnetBlock(128, 256, name=name + "_down_block_1_0_ResnetBlock"),
177
  ResnetBlock(256, 256, name=name + "_down_block_1_1_ResnetBlock"),
178
+ Conv2d(256, 256, 3, stride=2, padding=(0, 0)),
179
  ResnetBlock(256, 512, name=name + "_down_block_2_0_ResnetBlock"),
180
  ResnetBlock(512, 512, name=name + "_down_block_2_1_ResnetBlock"),
181
+ Conv2d(512, 512, 3, stride=2, padding=(0, 0)),
182
  ResnetBlock(512, 512, name=name + "_down_block_3_0_ResnetBlock"),
183
  ResnetBlock(512, 512, name=name + "_down_block_3_1_ResnetBlock"),
184
  ])
 
187
  self.norm_out = Normalize(512, name=name+"_norm_out_Normalize")
188
  self.conv_out = Conv2d(512, 8, 3, padding=1)
189
  self.name = name
190
+ self.zero_pad2d_0_1 = ZeroPad2d((0,1,0,1))
191
 
192
  def forward(self, x):
193
  x = self.conv_in(x)
194
 
195
  for l in self.down:
196
+ # x = l(x)
197
+ if isinstance(l, Conv2d):
198
+ x = l(self.zero_pad2d_0_1(x))
199
+ else:
200
+ x = l(x)
201
  x = self.mid(x)
202
  return self.conv_out(F.silu(self.norm_out(x)))
203
 
 
648
  self.encoder = CLIPEncoder(name=name+"_CLIPEncoder_0")
649
  self.final_layer_norm = Normalize(768, num_groups=None, name=name+"_CLIPTextTransformer_normalizer_0")
650
  # 上三角都是 -inf 值
651
+ triu = np.triu(np.ones((1, 1, 77, 77), dtype=np.float32) * -np.inf, k=1)
652
+ self.causal_attention_mask = Tensor(triu).to(device)
653
  self.name = name
654
 
655
  def forward(self, input_ids):
 
816
 
817
 
818
  class Args(object):
819
+ def __init__(self, phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file, input_image:str="", input_mask:str="", input_image_strength=0.5, unphrase=""):
820
  self.phrase = phrase
821
  self.steps = steps
822
  self.model_type = model_type
 
826
  self.seed = seed
827
  self.device = device
828
  self.model_file = model_file
829
+ self.input_image = input_image
830
+ self.input_mask = input_mask
831
+ self.input_image_strength = input_image_strength
832
+ self.unphrase = unphrase
833
 
834
+ from PIL import Image
835
 
836
+ class Generate2img(Module):
837
  _instance_lock = threading.Lock()
838
  def __init__(self, args: Args):
839
+ super(Generate2img, self).__init__()
840
  self.is_load_model=False
841
  self.args = args
842
  self.model = StableDiffusion().instance()
843
+ self.get_input_image_tensor()
844
+ # self.get_input_mask_tensor()
845
+
846
+
847
+ def get_input_image_tensor(self):
848
+ if self.args.input_image!="":
849
+ input_image = Image.open(args.input_image).convert("RGB").resize((self.args.img_width, self.args.img_height), resample=Image.Resampling.LANCZOS)
850
+ self.input_image_array = torch.from_numpy(np.array(input_image)).to(device)
851
+ self.input_image_tensor = torch.from_numpy((np.array(input_image, dtype=np.float32)[None, ..., :3]/ 255.0*2.0-1))
852
+ self.input_image_tensor = self.input_image_tensor.permute(0, 3, 1, 2) # bs, channel, height, width
853
+ else:
854
+ self.input_image_tensor = None
855
+ return self.input_image_tensor
856
+
857
 
858
  @classmethod
859
  def instance(cls, *args, **kwargs):
860
+ with Generate2img._instance_lock:
861
+ if not hasattr(Generate2img, "_instance"):
862
+ Generate2img._instance = Generate2img(*args, **kwargs)
863
+ return Generate2img._instance
864
 
865
  def load_model(self):
866
  if self.args.model_file != "" and self.is_load_model==False:
 
872
  def get_token_encode(self, phrase):
873
  tokenizer = ClipTokenizer().instance()
874
  phrase = tokenizer.encode(phrase)
875
+ # phrase = phrase + [49407] * (77 - len(phrase))
876
  with torch.no_grad():
877
  context = self.model.text_decoder(phrase)
878
  return context.to(self.args.device)
 
880
  self.set_seeds(True)
881
  self.load_model()
882
  context = self.get_token_encode(phrase)
883
+ unconditional_context = self.get_token_encode(self.args.unphrase)
884
 
885
  timesteps = list(np.arange(1, 1000, 1000 // self.args.steps))
886
  print(f"running for {timesteps} timesteps")
 
889
 
890
  latent_width = int(self.args.img_width) // 8
891
  latent_height = int(self.args.img_height) // 8
892
+
893
+
894
+ input_image_latent = None
895
+ input_img_noise_t = None
896
+ if self.input_image_tensor!=None:
897
+ noise_index = int(len(timesteps) * self.args.input_image_strength)
898
+ if noise_index >= len(timesteps):
899
+ noise_index = noise_index - 1
900
+ input_img_noise_t = timesteps[noise_index]
901
+ with torch.no_grad():
902
+ filter = lambda x:x[:,:4,:,:] * 0.18215
903
+ input_image_latent = self.model.first_stage_model.encoder(self.input_image_tensor.to(device))
904
+ input_image_latent = self.model.first_stage_model.quant_conv(input_image_latent)
905
+ input_image_latent = filter(input_image_latent) # only the means
906
+
907
  # start with random noise
908
+ latent = self.get_noise_latent( 1, latent_height, latent_width, input_image_latent, input_img_noise_t, None)
909
+
910
  latent = latent.to(self.args.device)
911
+
912
  with torch.no_grad():
913
  # this is diffusion
914
  for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
 
916
  e_t = self.get_model_latent_output(latent.clone(), timestep, self.model.unet, context.clone(),
917
  unconditional_context.clone())
918
  x_prev, pred_x0 = self.get_x_prev_and_pred_x0(latent, e_t, index, alphas, alphas_prev)
919
+ latent = x_prev
920
  # e_t_next = get_model_output(x_prev)
921
  # e_t_prime = (e_t + e_t_next) / 2
922
  # x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
923
+ decode = self.latent_decode(latent, latent_height, latent_width)
924
+
925
+ return decode
926
+
927
 
928
  def get_x_prev_and_pred_x0(self, x, e_t, index, alphas, alphas_prev):
929
  temperature = 1
 
952
  del unconditional_latent, latent, timesteps, context
953
  return e_t
954
 
955
+
956
+ def add_noise(self, x , t , noise=None ):
957
+ # batch_size, channel, h, w = x.shape
958
+ if noise is None:
959
+ noise = torch.normal(0,1, size=(x.shape))
960
+ # sqrt_alpha_prod = _ALPHAS_CUMPROD[t] ** 0.5
961
+ sqrt_alpha_prod = self.model.sqrt_alphas_cumprod[t]
962
+ sqrt_one_minus_alpha_prod = self.model.sqrt_one_minus_alphas_cumprod[t]
963
+ # sqrt_one_minus_alpha_prod = (1 - _ALPHAS_CUMPROD[t]) ** 0.5
964
+
965
+ return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise.to(device)
966
+
967
+ def get_noise_latent(self, batch_size, latent_height, latent_width, input_image_latent=None, input_img_noise_t=None, noise=None):
968
+
969
+ if input_image_latent is None:
970
+ latent = torch.normal(0,1, size=(batch_size, 4, latent_height, latent_width))
971
+ # latent = torch.randn((batch_size, 4, latent_height, latent_width))
972
+ else:
973
+ latent = self.add_noise(input_image_latent, input_img_noise_t, noise)
974
+ return latent.to(device)
975
+
976
  def latent_decode(self, latent, latent_height, latent_width):
977
  # upsample latent space to image with autoencoder
978
  # x = model.first_stage_model.post_quant_conv( 8* latent)
 
988
  return decode_latent
989
  def decode_latent2img(self, decode_latent):
990
  # save image
991
+ img = Image.fromarray(decode_latent, mode="RGB")
 
992
  return img
993
 
994
  def set_seeds(self, cuda):
 
997
  if cuda:
998
  torch.cuda.manual_seed_all(self.args.seed)
999
  @lru_cache()
1000
+ def generate2img(phrase, steps, model_file, guidance_scale, img_width, img_height, seed, device, input_image, input_mask, input_image_strength=0.5, unphrase=""):
1001
  try:
1002
+ args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file, input_image, input_mask, input_image_strength, unphrase)
1003
+ im = Generate2img.instance(args).forward(args.phrase)
1004
+ im = Generate2img.instance(args).decode_latent2img(im)
1005
  finally:
1006
  pass
1007
  return im
 
1015
 
1016
  parser = argparse.ArgumentParser(description='Run Stable Diffusion',
1017
  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
1018
+ parser.add_argument('--steps', type=int, default=50, help="Number of steps in diffusion")
1019
  parser.add_argument('--phrase', type=str, default="anthropomorphic cat portrait art ", help="Phrase to render")
1020
+ parser.add_argument('--unphrase', type=str, default="", help="unconditional Phrase to render")
1021
  parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
1022
  parser.add_argument('--scale', type=float, default=7.5, help="unconditional guidance scale")
1023
+ parser.add_argument('--model_file', type=str, default="../min-stable-diffusion-pt/mdjrny-v4.pt", help="model weight file")
1024
  parser.add_argument('--img_width', type=int, default=512, help="output image width")
1025
  parser.add_argument('--img_height', type=int, default=512, help="output image height")
1026
  parser.add_argument('--seed', type=int, default=443, help="random seed")
1027
+ parser.add_argument('--device_type', type=str, default="cpu", help="device type, support: cpu;cuda;mps")
1028
+ parser.add_argument('--input_image', type=str, default="", help="input image file")
1029
  args = parser.parse_args()
 
1030
  device = args.device_type
1031
 
1032
+ im = generate2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type, args.input_image, "", 1, args.unphrase)
1033
  print(f"saving {args.out}")
1034
  im.save(args.out)