hafidhsoekma commited on
Commit
d4df789
1 Parent(s): 76dfa20

Update models/deep_learning/backbone_model.py

Browse files
Files changed (1) hide show
  1. models/deep_learning/backbone_model.py +109 -109
models/deep_learning/backbone_model.py CHANGED
@@ -1,109 +1,109 @@
1
- import os
2
- import sys
3
-
4
- sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
5
-
6
- import timm
7
- import torch
8
- import torch.nn as nn
9
- from transformers import CLIPModel as CLIPTransformersModel
10
-
11
- from utils import configs
12
- from utils.functional import check_data_type_variable, get_device
13
-
14
-
15
- class CLIPModel(nn.Module):
16
- def __init__(
17
- self,
18
- model_clip_name: str,
19
- freeze_model: bool,
20
- pretrained_model: bool,
21
- num_classes: int,
22
- ):
23
- super().__init__()
24
- self.model_clip_name = model_clip_name
25
- self.freeze_model = freeze_model
26
- self.pretrained_model = pretrained_model
27
- self.num_classes = num_classes
28
- self.device = get_device()
29
-
30
- self.check_arguments()
31
- self.init_model()
32
-
33
- def check_arguments(self):
34
- check_data_type_variable(self.model_clip_name, str)
35
- check_data_type_variable(self.freeze_model, bool)
36
- check_data_type_variable(self.pretrained_model, bool)
37
- check_data_type_variable(self.num_classes, int)
38
-
39
- if self.model_clip_name != configs.CLIP_NAME_MODEL:
40
- raise ValueError(
41
- f"Model clip name must be {configs.CLIP_NAME_MODEL}, but it is {self.model_clip_name}"
42
- )
43
-
44
- def init_model(self):
45
- clip_model = CLIPTransformersModel.from_pretrained(self.model_clip_name)
46
- for layer in clip_model.children():
47
- if hasattr(layer, "reset_parameters") and not self.pretrained_model:
48
- layer.reset_parameters()
49
- for param in clip_model.parameters():
50
- param.required_grad = False if not self.freeze_model else True
51
- self.vision_model = clip_model.vision_model.to(self.device)
52
- self.visual_projection = clip_model.visual_projection.to(self.device).to(
53
- self.device
54
- )
55
- self.classifier = nn.Linear(
56
- 512, 1 if self.num_classes in (1, 2) else self.num_classes
57
- ).to(self.device)
58
-
59
- def forward(self, x: torch.Tensor) -> torch.Tensor:
60
- x = self.vision_model(x)
61
- x = self.visual_projection(x.pooler_output)
62
- x = self.classifier(x)
63
- return x
64
-
65
-
66
- class TorchModel(nn.Module):
67
- def __init__(
68
- self,
69
- name_model: str,
70
- freeze_model: bool,
71
- pretrained_model: bool,
72
- num_classes: int,
73
- ):
74
- super().__init__()
75
- self.name_model = name_model
76
- self.freeze_model = freeze_model
77
- self.pretrained_model = pretrained_model
78
- self.num_classes = num_classes
79
- self.device = get_device()
80
-
81
- self.check_arguments()
82
- self.init_model()
83
-
84
- def check_arguments(self):
85
- check_data_type_variable(self.name_model, str)
86
- check_data_type_variable(self.freeze_model, bool)
87
- check_data_type_variable(self.pretrained_model, bool)
88
- check_data_type_variable(self.num_classes, int)
89
-
90
- if self.name_model not in tuple(configs.NAME_MODELS.keys()):
91
- raise ValueError(
92
- f"Name model must be in {tuple(configs.NAME_MODELS.keys())}, but it is {self.name_model}"
93
- )
94
-
95
- def init_model(self):
96
- self.model = timm.create_model(
97
- self.name_model, pretrained=self.pretrained_model, num_classes=0
98
- ).to(self.device)
99
- for param in self.model.parameters():
100
- param.required_grad = False if not self.freeze_model else True
101
- self.classifier = nn.Linear(
102
- self.model.num_features,
103
- 1 if self.num_classes in (1, 2) else self.num_classes,
104
- ).to(self.device)
105
-
106
- def forward(self, x: torch.Tensor) -> torch.Tensor:
107
- x = self.model(x)
108
- x = self.classifier(x)
109
- return x
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
5
+
6
+ import timm
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import CLIPModel as CLIPTransformersModel
10
+
11
+ from utils import configs
12
+ from utils.functional import check_data_type_variable, get_device
13
+
14
+
15
+ class CLIPModel(nn.Module):
16
+ def __init__(
17
+ self,
18
+ model_clip_name: str,
19
+ freeze_model: bool,
20
+ pretrained_model: bool,
21
+ num_classes: int,
22
+ ):
23
+ super().__init__()
24
+ self.model_clip_name = model_clip_name
25
+ self.freeze_model = freeze_model
26
+ self.pretrained_model = pretrained_model
27
+ self.num_classes = num_classes
28
+ self.device = get_device()
29
+
30
+ self.check_arguments()
31
+ self.init_model()
32
+
33
+ def check_arguments(self):
34
+ check_data_type_variable(self.model_clip_name, str)
35
+ check_data_type_variable(self.freeze_model, bool)
36
+ check_data_type_variable(self.pretrained_model, bool)
37
+ check_data_type_variable(self.num_classes, int)
38
+
39
+ if self.model_clip_name != configs.CLIP_NAME_MODEL:
40
+ raise ValueError(
41
+ f"Model clip name must be {configs.CLIP_NAME_MODEL}, but it is {self.model_clip_name}"
42
+ )
43
+
44
+ def init_model(self):
45
+ self.clip_model = CLIPTransformersModel.from_pretrained(self.model_clip_name)
46
+ for layer in self.clip_model.children():
47
+ if hasattr(layer, "reset_parameters") and not self.pretrained_model:
48
+ layer.reset_parameters()
49
+ for param in self.clip_model.parameters():
50
+ param.required_grad = False if not self.freeze_model else True
51
+ self.vision_model = self.clip_model.vision_model.to(self.device)
52
+ self.visual_projection = self.clip_model.visual_projection.to(self.device).to(
53
+ self.device
54
+ )
55
+ self.classifier = nn.Linear(
56
+ 512, 1 if self.num_classes in (1, 2) else self.num_classes
57
+ ).to(self.device)
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ x = self.vision_model(x)
61
+ x = self.visual_projection(x.pooler_output)
62
+ x = self.classifier(x)
63
+ return x
64
+
65
+
66
+ class TorchModel(nn.Module):
67
+ def __init__(
68
+ self,
69
+ name_model: str,
70
+ freeze_model: bool,
71
+ pretrained_model: bool,
72
+ num_classes: int,
73
+ ):
74
+ super().__init__()
75
+ self.name_model = name_model
76
+ self.freeze_model = freeze_model
77
+ self.pretrained_model = pretrained_model
78
+ self.num_classes = num_classes
79
+ self.device = get_device()
80
+
81
+ self.check_arguments()
82
+ self.init_model()
83
+
84
+ def check_arguments(self):
85
+ check_data_type_variable(self.name_model, str)
86
+ check_data_type_variable(self.freeze_model, bool)
87
+ check_data_type_variable(self.pretrained_model, bool)
88
+ check_data_type_variable(self.num_classes, int)
89
+
90
+ if self.name_model not in tuple(configs.NAME_MODELS.keys()):
91
+ raise ValueError(
92
+ f"Name model must be in {tuple(configs.NAME_MODELS.keys())}, but it is {self.name_model}"
93
+ )
94
+
95
+ def init_model(self):
96
+ self.model = timm.create_model(
97
+ self.name_model, pretrained=self.pretrained_model, num_classes=0
98
+ ).to(self.device)
99
+ for param in self.model.parameters():
100
+ param.required_grad = False if not self.freeze_model else True
101
+ self.classifier = nn.Linear(
102
+ self.model.num_features,
103
+ 1 if self.num_classes in (1, 2) else self.num_classes,
104
+ ).to(self.device)
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ x = self.model(x)
108
+ x = self.classifier(x)
109
+ return x