Ruining Li commited on
Commit
8cede4e
1 Parent(s): 42a369e

Updated lfs for checkpoints and changes to model

Browse files
.gitignore CHANGED
@@ -1,2 +1 @@
1
- __pycache__/
2
- ckpts/
 
1
+ __pycache__/
 
ckpts/drag-a-part-final.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:162ba2040b59ed949fd9f57c861bb07eec56744d2e738e38ada8724de96d0d32
3
+ size 14265312095
ckpts/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879
ckpts/stable-diffusion-v1-5/unet/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.6.0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "cross_attention_dim": 768,
14
+ "down_block_types": [
15
+ "CrossAttnDownBlock2D",
16
+ "CrossAttnDownBlock2D",
17
+ "CrossAttnDownBlock2D",
18
+ "DownBlock2D"
19
+ ],
20
+ "downsample_padding": 1,
21
+ "flip_sin_to_cos": true,
22
+ "freq_shift": 0,
23
+ "in_channels": 4,
24
+ "layers_per_block": 2,
25
+ "mid_block_scale_factor": 1,
26
+ "norm_eps": 1e-05,
27
+ "norm_num_groups": 32,
28
+ "out_channels": 4,
29
+ "sample_size": 64,
30
+ "up_block_types": [
31
+ "UpBlock2D",
32
+ "CrossAttnUpBlock2D",
33
+ "CrossAttnUpBlock2D",
34
+ "CrossAttnUpBlock2D"
35
+ ]
36
+ }
model.py CHANGED
@@ -1255,20 +1255,6 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
1255
  )
1256
  elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
1257
  raise NotImplementedError
1258
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
1259
- in_channels=block_out_channels[-1],
1260
- temb_channels=blocks_time_embed_dim,
1261
- resnet_eps=norm_eps,
1262
- resnet_act_fn=act_fn,
1263
- output_scale_factor=mid_block_scale_factor,
1264
- cross_attention_dim=cross_attention_dim[-1],
1265
- attention_head_dim=attention_head_dim[-1],
1266
- resnet_groups=norm_num_groups,
1267
- resnet_time_scale_shift=resnet_time_scale_shift,
1268
- skip_time_act=resnet_skip_time_act,
1269
- only_cross_attention=mid_block_only_cross_attention,
1270
- cross_attention_norm=cross_attention_norm,
1271
- )
1272
  elif mid_block_type is None:
1273
  self.mid_block = None
1274
  else:
@@ -1512,11 +1498,6 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
1512
  y1 = y1.unsqueeze(-1).unsqueeze(-1)
1513
  y1 = torch.stack([torch.zeros_like(y1) - 1, torch.zeros_like(y1) - 1, y1, y1], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1514
 
1515
- # assert torch.all(x_src >= 0) and torch.all(x_src <= 1)
1516
- # assert torch.all(y_src >= 0) and torch.all(y_src <= 1)
1517
- # assert torch.all(x_tgt >= 0) and torch.all(x_tgt <= 1)
1518
- # assert torch.all(y_tgt >= 0) and torch.all(y_tgt <= 1)
1519
-
1520
  value_image = torch.stack([x_src, y_src, x_tgt, y_tgt], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1521
  value_image = value_image.expand(bsz, 4 * self.num_drags, current_resolution, current_resolution)
1522
 
@@ -1527,18 +1508,6 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
1527
 
1528
  def forward(
1529
  self,
1530
- # sample: torch.FloatTensor,
1531
- # timestep: Union[torch.Tensor, float, int],
1532
- # encoder_hidden_states: torch.Tensor,
1533
- # class_labels: Optional[torch.Tensor] = None,
1534
- # timestep_cond: Optional[torch.Tensor] = None,
1535
- # attention_mask: Optional[torch.Tensor] = None,
1536
- # cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1537
- # added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1538
- # down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1539
- # mid_block_additional_residual: Optional[torch.Tensor] = None,
1540
- # encoder_attention_mask: Optional[torch.Tensor] = None,
1541
- # return_dict: bool = True,
1542
  x: torch.FloatTensor,
1543
  t: torch.Tensor,
1544
  x_cond: torch.FloatTensor,
@@ -1546,7 +1515,6 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
1546
  force_drop_ids: Optional[torch.Tensor] = None,
1547
  hidden_cls: Optional[torch.Tensor] = None,
1548
  drags: Optional[torch.Tensor] = None,
1549
- save_features: bool = False,
1550
  ) -> torch.Tensor:
1551
  r"""
1552
  The [`UNet2DConditionModel`] forward method.
@@ -1941,11 +1909,10 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
1941
  from diffusers.utils import WEIGHTS_NAME
1942
  one_sided_attn = unet_additional_kwargs.pop("one_sided_attn", True) if unet_additional_kwargs is not None else True
1943
  model = cls.from_config(config, **unet_additional_kwargs) if unet_additional_kwargs is not None else cls.from_config(config)
1944
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1945
- if not os.path.isfile(model_file):
1946
- raise RuntimeError(f"{model_file} does not exist")
1947
-
1948
  if load:
 
 
 
1949
  state_dict = torch.load(model_file, map_location="cpu")
1950
  m, u = model.load_state_dict(state_dict, strict=False)
1951
 
 
1255
  )
1256
  elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
1257
  raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1258
  elif mid_block_type is None:
1259
  self.mid_block = None
1260
  else:
 
1498
  y1 = y1.unsqueeze(-1).unsqueeze(-1)
1499
  y1 = torch.stack([torch.zeros_like(y1) - 1, torch.zeros_like(y1) - 1, y1, y1], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1500
 
 
 
 
 
 
1501
  value_image = torch.stack([x_src, y_src, x_tgt, y_tgt], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1502
  value_image = value_image.expand(bsz, 4 * self.num_drags, current_resolution, current_resolution)
1503
 
 
1508
 
1509
  def forward(
1510
  self,
 
 
 
 
 
 
 
 
 
 
 
 
1511
  x: torch.FloatTensor,
1512
  t: torch.Tensor,
1513
  x_cond: torch.FloatTensor,
 
1515
  force_drop_ids: Optional[torch.Tensor] = None,
1516
  hidden_cls: Optional[torch.Tensor] = None,
1517
  drags: Optional[torch.Tensor] = None,
 
1518
  ) -> torch.Tensor:
1519
  r"""
1520
  The [`UNet2DConditionModel`] forward method.
 
1909
  from diffusers.utils import WEIGHTS_NAME
1910
  one_sided_attn = unet_additional_kwargs.pop("one_sided_attn", True) if unet_additional_kwargs is not None else True
1911
  model = cls.from_config(config, **unet_additional_kwargs) if unet_additional_kwargs is not None else cls.from_config(config)
 
 
 
 
1912
  if load:
1913
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1914
+ if not os.path.isfile(model_file):
1915
+ raise RuntimeError(f"{model_file} does not exist")
1916
  state_dict = torch.load(model_file, map_location="cpu")
1917
  m, u = model.load_state_dict(state_dict, strict=False)
1918