Upload model
Browse files- config.json +2 -2
- modeling_isnet.py +34 -25
config.json
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
{
|
2 |
"architectures": [
|
3 |
-
"
|
4 |
],
|
5 |
"auto_map": {
|
6 |
"AutoConfig": "configuration_isnet.ISNetConfig",
|
7 |
-
"AutoModel": "modeling_isnet.
|
8 |
},
|
9 |
"in_channels": 3,
|
10 |
"out_channels": 1,
|
|
|
1 |
{
|
2 |
"architectures": [
|
3 |
+
"ISNetModel"
|
4 |
],
|
5 |
"auto_map": {
|
6 |
"AutoConfig": "configuration_isnet.ISNetConfig",
|
7 |
+
"AutoModel": "modeling_isnet.ISNetModel"
|
8 |
},
|
9 |
"in_channels": 3,
|
10 |
"out_channels": 1,
|
modeling_isnet.py
CHANGED
@@ -1,15 +1,34 @@
|
|
1 |
import logging
|
2 |
-
from
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.nn.functional as F
|
7 |
from transformers import PreTrainedModel
|
|
|
8 |
|
9 |
from .configuration_isnet import ISNetConfig
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
bce_loss = nn.BCELoss(size_average=True)
|
14 |
|
15 |
|
@@ -540,7 +559,7 @@ class ISNetGTEncoder(nn.Module):
|
|
540 |
return activated, hidden_states
|
541 |
|
542 |
|
543 |
-
class
|
544 |
config_class = ISNetConfig
|
545 |
|
546 |
def __init__(self, config: ISNetConfig) -> None:
|
@@ -582,7 +601,7 @@ class ISNet(PreTrainedModel):
|
|
582 |
|
583 |
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
584 |
|
585 |
-
def compute_loss_kl(self, preds, targets, dfs, fs, mode="MSE"):
|
586 |
# return muti_loss_fusion(preds,targets)
|
587 |
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
|
588 |
|
@@ -591,25 +610,8 @@ class ISNet(PreTrainedModel):
|
|
591 |
return muti_loss_fusion(preds, targets)
|
592 |
|
593 |
def forward(
|
594 |
-
self, pixel_values: torch.Tensor
|
595 |
-
) -> Tuple
|
596 |
-
Tuple[
|
597 |
-
torch.Tensor,
|
598 |
-
torch.Tensor,
|
599 |
-
torch.Tensor,
|
600 |
-
torch.Tensor,
|
601 |
-
torch.Tensor,
|
602 |
-
torch.Tensor,
|
603 |
-
],
|
604 |
-
Tuple[
|
605 |
-
torch.Tensor,
|
606 |
-
torch.Tensor,
|
607 |
-
torch.Tensor,
|
608 |
-
torch.Tensor,
|
609 |
-
torch.Tensor,
|
610 |
-
torch.Tensor,
|
611 |
-
],
|
612 |
-
]:
|
613 |
x = pixel_values
|
614 |
hx = x
|
615 |
|
@@ -692,17 +694,24 @@ class ISNet(PreTrainedModel):
|
|
692 |
hx5d,
|
693 |
hx6,
|
694 |
)
|
695 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
696 |
|
697 |
|
698 |
def convert_from_checkpoint(
|
699 |
repo_id: str, filename: str, config: Optional[ISNetConfig] = None
|
700 |
-
) ->
|
701 |
from huggingface_hub import hf_hub_download
|
702 |
|
703 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
704 |
config = config or ISNetConfig()
|
705 |
-
model =
|
706 |
|
707 |
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
708 |
state_dict = torch.load(checkpoint_path)
|
|
|
1 |
import logging
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Literal, Optional, Tuple, Union
|
4 |
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
from transformers import PreTrainedModel
|
9 |
+
from transformers.utils import ModelOutput
|
10 |
|
11 |
from .configuration_isnet import ISNetConfig
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class ISNetStageOutput(ModelOutput):
|
18 |
+
d1: torch.Tensor
|
19 |
+
d2: Optional[torch.Tensor] = None
|
20 |
+
d3: Optional[torch.Tensor] = None
|
21 |
+
d4: Optional[torch.Tensor] = None
|
22 |
+
d5: Optional[torch.Tensor] = None
|
23 |
+
d6: Optional[torch.Tensor] = None
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class ISNetModelOutput(ModelOutput):
|
28 |
+
activated: ISNetStageOutput
|
29 |
+
hidden_states: Optional[ISNetStageOutput] = None
|
30 |
+
|
31 |
+
|
32 |
bce_loss = nn.BCELoss(size_average=True)
|
33 |
|
34 |
|
|
|
559 |
return activated, hidden_states
|
560 |
|
561 |
|
562 |
+
class ISNetModel(PreTrainedModel):
|
563 |
config_class = ISNetConfig
|
564 |
|
565 |
def __init__(self, config: ISNetConfig) -> None:
|
|
|
601 |
|
602 |
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
603 |
|
604 |
+
def compute_loss_kl(self, preds, targets, dfs, fs, mode: LossMode = "MSE"):
|
605 |
# return muti_loss_fusion(preds,targets)
|
606 |
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
|
607 |
|
|
|
610 |
return muti_loss_fusion(preds, targets)
|
611 |
|
612 |
def forward(
|
613 |
+
self, pixel_values: torch.Tensor, return_dict: Optional[bool] = None
|
614 |
+
) -> Union[Tuple, ISNetModelOutput]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
615 |
x = pixel_values
|
616 |
hx = x
|
617 |
|
|
|
694 |
hx5d,
|
695 |
hx6,
|
696 |
)
|
697 |
+
|
698 |
+
if not return_dict:
|
699 |
+
return activated, hidden_states
|
700 |
+
|
701 |
+
return ISNetModelOutput(
|
702 |
+
activated=ISNetStageOutput(*activated),
|
703 |
+
hidden_states=ISNetStageOutput(*hidden_states),
|
704 |
+
)
|
705 |
|
706 |
|
707 |
def convert_from_checkpoint(
|
708 |
repo_id: str, filename: str, config: Optional[ISNetConfig] = None
|
709 |
+
) -> ISNetModel:
|
710 |
from huggingface_hub import hf_hub_download
|
711 |
|
712 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
713 |
config = config or ISNetConfig()
|
714 |
+
model = ISNetModel(config)
|
715 |
|
716 |
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
717 |
state_dict = torch.load(checkpoint_path)
|