shunk031 committed on
Commit
16b0e01
·
verified ·
1 Parent(s): 3f3c4cd

Upload model

Browse files
Files changed (2) hide show
  1. config.json +2 -2
  2. modeling_isnet.py +34 -25
config.json CHANGED
@@ -1,10 +1,10 @@
1
  {
2
  "architectures": [
3
- "ISNet"
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_isnet.ISNetConfig",
7
- "AutoModel": "modeling_isnet.ISNet"
8
  },
9
  "in_channels": 3,
10
  "out_channels": 1,
 
1
  {
2
  "architectures": [
3
+ "ISNetModel"
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_isnet.ISNetConfig",
7
+ "AutoModel": "modeling_isnet.ISNetModel"
8
  },
9
  "in_channels": 3,
10
  "out_channels": 1,
modeling_isnet.py CHANGED
@@ -1,15 +1,34 @@
1
  import logging
2
- from typing import Literal, Optional, Tuple
 
3
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  from transformers import PreTrainedModel
 
8
 
9
  from .configuration_isnet import ISNetConfig
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  bce_loss = nn.BCELoss(size_average=True)
14
 
15
 
@@ -540,7 +559,7 @@ class ISNetGTEncoder(nn.Module):
540
  return activated, hidden_states
541
 
542
 
543
- class ISNet(PreTrainedModel):
544
  config_class = ISNetConfig
545
 
546
  def __init__(self, config: ISNetConfig) -> None:
@@ -582,7 +601,7 @@ class ISNet(PreTrainedModel):
582
 
583
  # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
584
 
585
- def compute_loss_kl(self, preds, targets, dfs, fs, mode="MSE"):
586
  # return muti_loss_fusion(preds,targets)
587
  return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
588
 
@@ -591,25 +610,8 @@ class ISNet(PreTrainedModel):
591
  return muti_loss_fusion(preds, targets)
592
 
593
  def forward(
594
- self, pixel_values: torch.Tensor
595
- ) -> Tuple[
596
- Tuple[
597
- torch.Tensor,
598
- torch.Tensor,
599
- torch.Tensor,
600
- torch.Tensor,
601
- torch.Tensor,
602
- torch.Tensor,
603
- ],
604
- Tuple[
605
- torch.Tensor,
606
- torch.Tensor,
607
- torch.Tensor,
608
- torch.Tensor,
609
- torch.Tensor,
610
- torch.Tensor,
611
- ],
612
- ]:
613
  x = pixel_values
614
  hx = x
615
 
@@ -692,17 +694,24 @@ class ISNet(PreTrainedModel):
692
  hx5d,
693
  hx6,
694
  )
695
- return activated, hidden_states
 
 
 
 
 
 
 
696
 
697
 
698
  def convert_from_checkpoint(
699
  repo_id: str, filename: str, config: Optional[ISNetConfig] = None
700
- ) -> ISNet:
701
  from huggingface_hub import hf_hub_download
702
 
703
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
704
  config = config or ISNetConfig()
705
- model = ISNet(config)
706
 
707
  logger.info(f"Loading checkpoint from {checkpoint_path}")
708
  state_dict = torch.load(checkpoint_path)
 
1
  import logging
2
+ from dataclasses import dataclass
3
+ from typing import Literal, Optional, Tuple, Union
4
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from transformers import PreTrainedModel
9
+ from transformers.utils import ModelOutput
10
 
11
  from .configuration_isnet import ISNetConfig
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
+
16
+ @dataclass
17
+ class ISNetStageOutput(ModelOutput):
18
+ d1: torch.Tensor
19
+ d2: Optional[torch.Tensor] = None
20
+ d3: Optional[torch.Tensor] = None
21
+ d4: Optional[torch.Tensor] = None
22
+ d5: Optional[torch.Tensor] = None
23
+ d6: Optional[torch.Tensor] = None
24
+
25
+
26
+ @dataclass
27
+ class ISNetModelOutput(ModelOutput):
28
+ activated: ISNetStageOutput
29
+ hidden_states: Optional[ISNetStageOutput] = None
30
+
31
+
32
  bce_loss = nn.BCELoss(size_average=True)
33
 
34
 
 
559
  return activated, hidden_states
560
 
561
 
562
+ class ISNetModel(PreTrainedModel):
563
  config_class = ISNetConfig
564
 
565
  def __init__(self, config: ISNetConfig) -> None:
 
601
 
602
  # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
603
 
604
+ def compute_loss_kl(self, preds, targets, dfs, fs, mode: LossMode = "MSE"):
605
  # return muti_loss_fusion(preds,targets)
606
  return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
607
 
 
610
  return muti_loss_fusion(preds, targets)
611
 
612
  def forward(
613
+ self, pixel_values: torch.Tensor, return_dict: Optional[bool] = None
614
+ ) -> Union[Tuple, ISNetModelOutput]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
  x = pixel_values
616
  hx = x
617
 
 
694
  hx5d,
695
  hx6,
696
  )
697
+
698
+ if not return_dict:
699
+ return activated, hidden_states
700
+
701
+ return ISNetModelOutput(
702
+ activated=ISNetStageOutput(*activated),
703
+ hidden_states=ISNetStageOutput(*hidden_states),
704
+ )
705
 
706
 
707
  def convert_from_checkpoint(
708
  repo_id: str, filename: str, config: Optional[ISNetConfig] = None
709
+ ) -> ISNetModel:
710
  from huggingface_hub import hf_hub_download
711
 
712
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
713
  config = config or ISNetConfig()
714
+ model = ISNetModel(config)
715
 
716
  logger.info(f"Loading checkpoint from {checkpoint_path}")
717
  state_dict = torch.load(checkpoint_path)