myn0908 committed
Commit e4c85fa
1 Parent(s): 6ce9552

adapter weights for sketch2image

Files changed (3):
  1. S2I/modules/models.py +12 -3
  2. S2I/modules/utils.py +1 -0
  3. app.py +1 -1
S2I/modules/models.py CHANGED

@@ -4,7 +4,7 @@ from diffusers import DDPMScheduler
 from transformers import AutoTokenizer, CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel
 from peft import LoraConfig
-from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path
+from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path, get_s2i_home
 
 
 class RelationShipConvolution(torch.nn.Module):
@@ -50,7 +50,16 @@ class PrimaryModel:
         vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
         vae.decoder.ignore_skip = False
         return vae
-
+    def weights_adapter(self, p_ckpt, model_name):
+        if model_name == '350k-adapter':
+            home = get_s2i_home()
+            sd_sketch = torch.load(os.path.join(home, f"sketch2image_lora_350k.pkl"), map_location="cpu")
+            sd = torch.load(p_ckpt, map_location="cpu")
+            sd.update(sd_sketch)
+            return sd
+        else:
+            sd = torch.load(p_ckpt, map_location="cpu")
+            return sd
     def from_pretrained(self, model_name, r):
         if self.global_tokenizer is None:
             # self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
@@ -72,7 +81,7 @@ class PrimaryModel:
         self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
         p_ckpt_path = download_models()
         p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
-        sd = torch.load(p_ckpt, map_location="cpu")
+        sd = self.weights_adapter(p_ckpt, model_name)
         conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
         self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)
         unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
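
Note on the new weights_adapter method: because the adapter state dict is applied with dict.update() after the base checkpoint is loaded, any key present in both files is taken from the adapter file. A minimal standalone sketch of that merge step (the helper name and paths below are illustrative, not part of the commit; both checkpoints are assumed to be plain state dicts saved with torch.save):

import torch

def merge_adapter(base_ckpt_path, adapter_ckpt_path):
    # Load the base 350k checkpoint first, then overlay the adapter weights.
    sd = torch.load(base_ckpt_path, map_location="cpu")
    sd_adapter = torch.load(adapter_ckpt_path, map_location="cpu")
    sd.update(sd_adapter)  # overlapping keys are overridden by the adapter
    return sd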
S2I/modules/utils.py CHANGED

@@ -84,6 +84,7 @@ def get_s2i_home() -> str:
 
 def download_models():
     urls = {
+        '350k-adapter': 'https://huggingface.co/myn0908/sk2ks/resolve/main/adapter_weights_large_sketch2image_lora.pkl?download=true',
         '350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true',
         '100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
     }
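
For reference, a hedged usage sketch of how the new '350k-adapter' entry is resolved, assuming download_models() returns the local paths of the downloaded checkpoints and get_model_path() selects one by name, as in the models.py hunk above:

from S2I.modules.utils import download_models, get_model_path

model_paths = download_models()  # fetches the checkpoints listed in urls
p_ckpt = get_model_path(model_name='350k-adapter', model_paths=model_paths)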
app.py CHANGED

@@ -263,7 +263,7 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
                 label="Demo Speed",
                 interactive=True)
             model_options = gr.Radio(
-                choices=["100k", "350k"],
+                choices=["100k", "350k", "350k-adapter"],
                 value="350k",
                 label="Type Sketch2Image models",
                 interactive=True)
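
The updated selector in isolation, as a minimal runnable sketch (the css/theme, surrounding layout, and event wiring from app.py are omitted):

import gradio as gr

with gr.Blocks() as demo:
    model_options = gr.Radio(
        choices=["100k", "350k", "350k-adapter"],  # newly exposed adapter option
        value="350k",                              # default stays on the 350k weights
        label="Type Sketch2Image models",
        interactive=True)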