monet-joe committed
Commit 98945f7
Parent: 613d22e

Upload 3 files

Files changed (2)
  1. app.py +183 -182
  2. model.py +144 -142
app.py CHANGED
@@ -1,182 +1,183 @@
- import os
- import torch
- import random
- import shutil
- import librosa
- import warnings
- import numpy as np
- import gradio as gr
- import librosa.display
- import matplotlib.pyplot as plt
- import torchvision.transforms as transforms
- from utils import get_modelist, find_wav_files
- from collections import Counter
- from model import EvalNet
- from PIL import Image
-
-
- CLASSES = ["m_bel", "f_bel", "m_folk", "f_folk"]
-
-
- def most_common_element(input_list):
-     # Use Counter to count how many times each element occurs
-     counter = Counter(input_list)
-     # Use most_common to get the element with the highest count
-     most_common_element, _ = counter.most_common(1)[0]
-     return most_common_element
-
-
- def wav_to_mel(audio_path: str, width=1.6, topdb=40):
-     os.makedirs("./tmp", exist_ok=True)
-     try:
-         y, sr = librosa.load(audio_path, sr=48000)
-         non_silents = librosa.effects.split(y, top_db=topdb)
-         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
-         mel_spec = librosa.feature.melspectrogram(y=non_silent, sr=sr)
-         log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
-         dur = librosa.get_duration(y=non_silent, sr=sr)
-         total_frames = log_mel_spec.shape[1]
-         step = int(width * total_frames / dur)
-         count = int(total_frames / step)
-         begin = int(0.5 * (total_frames - count * step))
-         end = begin + step * count
-         for i in range(begin, end, step):
-             librosa.display.specshow(log_mel_spec[:, i : i + step])
-             plt.axis("off")
-             plt.savefig(
-                 f"./tmp/mel_{round(dur, 2)}_{i}.jpg",
-                 bbox_inches="tight",
-                 pad_inches=0.0,
-             )
-             plt.close()
-
-     except Exception as e:
-         print(f"Error converting {audio_path} : {e}")
-
-
- def wav_to_cqt(audio_path: str, width=1.6, topdb=40):
-     os.makedirs("./tmp", exist_ok=True)
-     try:
-         y, sr = librosa.load(audio_path, sr=48000)
-         non_silents = librosa.effects.split(y, top_db=topdb)
-         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
-         cqt_spec = librosa.cqt(y=non_silent, sr=sr)
-         log_cqt_spec = librosa.power_to_db(np.abs(cqt_spec) ** 2, ref=np.max)
-         dur = librosa.get_duration(y=non_silent, sr=sr)
-         total_frames = log_cqt_spec.shape[1]
-         step = int(width * total_frames / dur)
-         count = int(total_frames / step)
-         begin = int(0.5 * (total_frames - count * step))
-         end = begin + step * count
-         for i in range(begin, end, step):
-             librosa.display.specshow(log_cqt_spec[:, i : i + step])
-             plt.axis("off")
-             plt.savefig(
-                 f"./tmp/cqt_{round(dur, 2)}_{i}.jpg",
-                 bbox_inches="tight",
-                 pad_inches=0.0,
-             )
-             plt.close()
-
-     except Exception as e:
-         print(f"Error converting {audio_path} : {e}")
-
-
- def wav_to_chroma(audio_path: str, width=1.6, topdb=40):
-     os.makedirs("./tmp", exist_ok=True)
-     try:
-         y, sr = librosa.load(audio_path, sr=48000)
-         non_silents = librosa.effects.split(y, top_db=topdb)
-         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
-         chroma_spec = librosa.feature.chroma_stft(y=non_silent, sr=sr)
-         log_chroma_spec = librosa.power_to_db(np.abs(chroma_spec) ** 2, ref=np.max)
-         dur = librosa.get_duration(y=non_silent, sr=sr)
-         total_frames = log_chroma_spec.shape[1]
-         step = int(width * total_frames / dur)
-         count = int(total_frames / step)
-         begin = int(0.5 * (total_frames - count * step))
-         end = begin + step * count
-         for i in range(begin, end, step):
-             librosa.display.specshow(log_chroma_spec[:, i : i + step])
-             plt.axis("off")
-             plt.savefig(
-                 f"./tmp/chroma_{round(dur, 2)}_{i}.jpg",
-                 bbox_inches="tight",
-                 pad_inches=0.0,
-             )
-             plt.close()
-
-     except Exception as e:
-         print(f"Error converting {audio_path} : {e}")
-
-
- def embed_img(img_path, input_size=224):
-     transform = transforms.Compose(
-         [
-             transforms.Resize([input_size, input_size]),
-             transforms.ToTensor(),
-             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-         ]
-     )
-     img = Image.open(img_path).convert("RGB")
-     return transform(img).unsqueeze(0)
-
-
- def inference(wav_path: str, log_name: str, folder_path="./tmp"):
-     if os.path.exists(folder_path):
-         shutil.rmtree(folder_path)
-
-     if not wav_path:
-         wav_path = "./examples/f_bel.wav"
-
-     model = EvalNet(log_name).model
-     spec = log_name.split("_")[-1]
-     eval("wav_to_%s" % spec)(wav_path)
-     outputs = []
-     all_files = os.listdir(folder_path)
-     for file_name in all_files:
-         if file_name.lower().endswith(".jpg"):
-             file_path = os.path.join(folder_path, file_name)
-             input = embed_img(file_path)
-             output = model(input)
-             pred_id = torch.max(output.data, 1)[1]
-             outputs.append(pred_id)
-
-     max_count_item = most_common_element(outputs)
-     shutil.rmtree(folder_path)
-     return os.path.basename(wav_path), translate[CLASSES[max_count_item]]
-
-
- if __name__ == "__main__":
-     warnings.filterwarnings("ignore")
-
-     models = get_modelist()
-     translate = {
-         "m_bel": "Male bel canto",
-         "m_folk": "Male folk singing",
-         "f_bel": "Female bel canto",
-         "f_folk": "Female folk singing",
-     }
-     examples = []
-     example_wavs = find_wav_files()
-     model_num = len(models)
-     for wav in example_wavs:
-         examples.append([wav, models[random.randint(0, model_num - 1)]])
-
-     with gr.Blocks() as demo:
-         gr.Interface(
-             fn=inference,
-             inputs=[
-                 gr.Audio(label="Uploading a recording", type="filepath"),
-                 gr.Dropdown(choices=models, label="Select a model", value=models[0]),
-             ],
-             outputs=[
-                 gr.Textbox(label="Audio filename", show_copy_button=True),
-                 gr.Textbox(label="Singing style recognition", show_copy_button=True),
-             ],
-             examples=examples,
-             allow_flagging="never",
-             title="It is recommended to keep the recording length around 5s, too long will affect the recognition efficiency.",
-         )
-
-     demo.launch()
 
 
+ import os
+ import torch
+ import random
+ import shutil
+ import librosa
+ import warnings
+ import numpy as np
+ import gradio as gr
+ import librosa.display
+ import matplotlib.pyplot as plt
+ import torchvision.transforms as transforms
+ from utils import get_modelist, find_wav_files
+ from collections import Counter
+ from model import EvalNet
+ from PIL import Image
+
+
+ TRANSLATE = {
+     "m_bel": "男声美声唱法",
+     "f_bel": "女声美声唱法",
+     "m_folk": "男声民族唱法",
+     "f_folk": "女声民族唱法",
+ }
+
+ CLASSES = list(TRANSLATE.keys())
+
+
+ def most_common_element(input_list):
+     # Use Counter to count how many times each element occurs
+     counter = Counter(input_list)
+     # Use most_common to get the element with the highest count
+     most_common_element, _ = counter.most_common(1)[0]
+     return most_common_element
+
+
+ def wav_to_mel(audio_path: str, width=1.6, topdb=40):
+     os.makedirs("./tmp", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
+         mel_spec = librosa.feature.melspectrogram(y=non_silent, sr=sr)
+         log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_mel_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_mel_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./tmp/mel_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def wav_to_cqt(audio_path: str, width=1.6, topdb=40):
+     os.makedirs("./tmp", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
+         cqt_spec = librosa.cqt(y=non_silent, sr=sr)
+         log_cqt_spec = librosa.power_to_db(np.abs(cqt_spec) ** 2, ref=np.max)
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_cqt_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_cqt_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./tmp/cqt_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def wav_to_chroma(audio_path: str, width=1.6, topdb=40):
+     os.makedirs("./tmp", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
+         chroma_spec = librosa.feature.chroma_stft(y=non_silent, sr=sr)
+         log_chroma_spec = librosa.power_to_db(np.abs(chroma_spec) ** 2, ref=np.max)
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_chroma_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_chroma_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./tmp/chroma_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def embed_img(img_path, input_size=224):
+     transform = transforms.Compose(
+         [
+             transforms.Resize([input_size, input_size]),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+         ]
+     )
+     img = Image.open(img_path).convert("RGB")
+     return transform(img).unsqueeze(0)
+
+
+ def inference(wav_path: str, log_name: str, folder_path="./tmp"):
+     if os.path.exists(folder_path):
+         shutil.rmtree(folder_path)
+
+     if not wav_path:
+         wav_path = "./examples/f_bel.wav"
+
+     model = EvalNet(log_name).model
+     spec = log_name.split("_")[-1]
+     eval("wav_to_%s" % spec)(wav_path)
+     outputs = []
+     all_files = os.listdir(folder_path)
+     for file_name in all_files:
+         if file_name.lower().endswith(".jpg"):
+             file_path = os.path.join(folder_path, file_name)
+             input = embed_img(file_path)
+             output = model(input)
+             pred_id = torch.max(output.data, 1)[1]
+             outputs.append(pred_id)
+
+     max_count_item = most_common_element(outputs)
+     shutil.rmtree(folder_path)
+     return os.path.basename(wav_path), TRANSLATE[CLASSES[max_count_item]]
+
+
+ if __name__ == "__main__":
+     warnings.filterwarnings("ignore")
+
+     models = get_modelist()
+     examples = []
+     example_wavs = find_wav_files()
+     model_num = len(models)
+     for wav in example_wavs:
+         examples.append([wav, models[random.randint(0, model_num - 1)]])
+
+     with gr.Blocks() as demo:
+         gr.Interface(
+             fn=inference,
+             inputs=[
+                 gr.Audio(label="上传录音", type="filepath"),
+                 gr.Dropdown(choices=models, label="选择模型", value=models[0]),
+             ],
+             outputs=[
+                 gr.Textbox(label="音频文件名", show_copy_button=True),
+                 gr.Textbox(label="唱法识别", show_copy_button=True),
+             ],
+             examples=examples,
+             allow_flagging="never",
+             title="建议录音时长保持在 5s 左右, 过长会影响识别效率",
+         )
+
+     demo.launch()
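
Note on the inference flow in app.py (both versions above): the uploaded recording is sliced into fixed-width spectrogram images (mel, CQT, or chroma, picked from the selected model's log name), each slice is classified independently, and the per-slice predictions are reduced to one label by majority vote in most_common_element. A minimal sketch of that voting step, using made-up per-slice prediction ids (the real app collects them as argmax outputs of the model over each saved spectrogram slice):

from collections import Counter

# Hypothetical per-slice prediction ids for one recording.
slice_preds = [1, 1, 3, 1, 0]

# Majority vote, mirroring most_common_element() in app.py.
label_id, _ = Counter(slice_preds).most_common(1)[0]
print(label_id)  # 1 -> CLASSES[1] == "f_bel" in both versions of app.py
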
model.py CHANGED
@@ -1,142 +1,144 @@
- import torch
- import torch.nn as nn
- import torchvision.models as models
- from modelscope.msdatasets import MsDataset
- from utils import MODEL_DIR
-
-
- def get_backbone(ver, backbone_list):
-     for bb in backbone_list:
-         if ver == bb["ver"]:
-             return bb
-
-     print("Backbone name not found, using default option - alexnet.")
-     return backbone_list[0]
-
-
- def model_info(m_ver):
-     backbone_list = MsDataset.load("monetjoe/cv_backbones", split="train")
-     backbone = get_backbone(m_ver, backbone_list)
-     m_type = str(backbone["type"])
-     input_size = int(backbone["input_size"])
-     return m_type, input_size
-
-
- def Classifier(cls_num: int, output_size: int, linear_output: bool):
-     q = (1.0 * output_size / cls_num) ** 0.25
-     l1 = int(q * cls_num)
-     l2 = int(q * l1)
-     l3 = int(q * l2)
-
-     if linear_output:
-         return torch.nn.Sequential(
-             nn.Dropout(),
-             nn.Linear(output_size, l3),
-             nn.ReLU(inplace=True),
-             nn.Dropout(),
-             nn.Linear(l3, l2),
-             nn.ReLU(inplace=True),
-             nn.Dropout(),
-             nn.Linear(l2, l1),
-             nn.ReLU(inplace=True),
-             nn.Linear(l1, cls_num),
-         )
-
-     else:
-         return torch.nn.Sequential(
-             nn.Dropout(),
-             nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
-             nn.ReLU(inplace=True),
-             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
-             nn.Flatten(),
-             nn.Linear(l3, l2),
-             nn.ReLU(inplace=True),
-             nn.Dropout(),
-             nn.Linear(l2, l1),
-             nn.ReLU(inplace=True),
-             nn.Linear(l1, cls_num),
-         )
-
-
- class EvalNet:
-     model = None
-     m_type = "squeezenet"
-     input_size = 224
-     output_size = 512
-
-     def __init__(self, log_name: str, cls_num=4):
-         saved_model_path = f"{MODEL_DIR}/{log_name}/save.pt"
-         m_ver = "_".join(log_name.split("_")[:-1])
-         self.m_type, self.input_size = model_info(m_ver)
-
-         if not hasattr(models, m_ver):
-             print("Unsupported model.")
-             exit()
-
-         self.model = eval("models.%s()" % m_ver)
-         linear_output = self._set_outsize()
-         self._set_classifier(cls_num, linear_output)
-         checkpoint = torch.load(saved_model_path, map_location="cpu")
-         if torch.cuda.is_available():
-             checkpoint = torch.load(saved_model_path)
-
-         self.model.load_state_dict(checkpoint, False)
-         self.model.eval()
-
-     def _set_outsize(self, debug_mode=False):
-         for name, module in self.model.named_modules():
-             if (
-                 str(name).__contains__("classifier")
-                 or str(name).__eq__("fc")
-                 or str(name).__contains__("head")
-             ):
-                 if isinstance(module, torch.nn.Linear):
-                     self.output_size = module.in_features
-                     if debug_mode:
-                         print(
-                             f"{name}(Linear): {self.output_size} -> {module.out_features}"
-                         )
-                     return True
-
-                 if isinstance(module, torch.nn.Conv2d):
-                     self.output_size = module.in_channels
-                     if debug_mode:
-                         print(
-                             f"{name}(Conv2d): {self.output_size} -> {module.out_channels}"
-                         )
-                     return False
-
-         return False
-
-     def _set_classifier(self, cls_num, linear_output):
-         if self.m_type == "convnext":
-             del self.model.classifier[2]
-             self.model.classifier = nn.Sequential(
-                 *list(self.model.classifier)
-                 + list(Classifier(cls_num, self.output_size, linear_output))
-             )
-             return
-
-         if hasattr(self.model, "classifier"):
-             self.model.classifier = Classifier(cls_num, self.output_size, linear_output)
-             return
-
-         elif hasattr(self.model, "fc"):
-             self.model.fc = Classifier(cls_num, self.output_size, linear_output)
-             return
-
-         elif hasattr(self.model, "head"):
-             self.model.head = Classifier(cls_num, self.output_size, linear_output)
-             return
-
-         self.model.heads.head = Classifier(cls_num, self.output_size, linear_output)
-
-     def forward(self, x):
-         if torch.cuda.is_available():
-             x = x.cuda()
-             self.model = self.model.cuda()
-
-         if self.m_type == "googlenet" and self.training:
-             return self.model(x)[0]
-         else:
-             return self.model(x)
 
 
 
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ from modelscope.msdatasets import MsDataset
+ from utils import MODEL_DIR
+
+
+ def get_backbone(ver, backbone_list):
+     for bb in backbone_list:
+         if ver == bb["ver"]:
+             return bb
+
+     print("Backbone name not found, using default option - alexnet.")
+     return backbone_list[0]
+
+
+ def model_info(m_ver):
+     backbone_list = MsDataset.load(
+         "monetjoe/cv_backbones", split="train"
+     )
+     backbone = get_backbone(m_ver, backbone_list)
+     m_type = str(backbone["type"])
+     input_size = int(backbone["input_size"])
+     return m_type, input_size
+
+
+ def Classifier(cls_num: int, output_size: int, linear_output: bool):
+     q = (1.0 * output_size / cls_num) ** 0.25
+     l1 = int(q * cls_num)
+     l2 = int(q * l1)
+     l3 = int(q * l2)
+
+     if linear_output:
+         return torch.nn.Sequential(
+             nn.Dropout(),
+             nn.Linear(output_size, l3),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l3, l2),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l2, l1),
+             nn.ReLU(inplace=True),
+             nn.Linear(l1, cls_num),
+         )
+
+     else:
+         return torch.nn.Sequential(
+             nn.Dropout(),
+             nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
+             nn.ReLU(inplace=True),
+             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+             nn.Flatten(),
+             nn.Linear(l3, l2),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l2, l1),
+             nn.ReLU(inplace=True),
+             nn.Linear(l1, cls_num),
+         )
+
+
+ class EvalNet:
+     model = None
+     m_type = "squeezenet"
+     input_size = 224
+     output_size = 512
+
+     def __init__(self, log_name: str, cls_num=4):
+         saved_model_path = f"{MODEL_DIR}/{log_name}/save.pt"
+         m_ver = "_".join(log_name.split("_")[:-1])
+         self.m_type, self.input_size = model_info(m_ver)
+
+         if not hasattr(models, m_ver):
+             print("Unsupported model.")
+             exit()
+
+         self.model = eval("models.%s()" % m_ver)
+         linear_output = self._set_outsize()
+         self._set_classifier(cls_num, linear_output)
+         checkpoint = torch.load(saved_model_path, map_location="cpu")
+         if torch.cuda.is_available():
+             checkpoint = torch.load(saved_model_path)
+
+         self.model.load_state_dict(checkpoint, False)
+         self.model.eval()
+
+     def _set_outsize(self, debug_mode=False):
+         for name, module in self.model.named_modules():
+             if (
+                 str(name).__contains__("classifier")
+                 or str(name).__eq__("fc")
+                 or str(name).__contains__("head")
+             ):
+                 if isinstance(module, torch.nn.Linear):
+                     self.output_size = module.in_features
+                     if debug_mode:
+                         print(
+                             f"{name}(Linear): {self.output_size} -> {module.out_features}"
+                         )
+                     return True
+
+                 if isinstance(module, torch.nn.Conv2d):
+                     self.output_size = module.in_channels
+                     if debug_mode:
+                         print(
+                             f"{name}(Conv2d): {self.output_size} -> {module.out_channels}"
+                         )
+                     return False
+
+         return False
+
+     def _set_classifier(self, cls_num, linear_output):
+         if self.m_type == "convnext":
+             del self.model.classifier[2]
+             self.model.classifier = nn.Sequential(
+                 *list(self.model.classifier)
+                 + list(Classifier(cls_num, self.output_size, linear_output))
+             )
+             return
+
+         if hasattr(self.model, "classifier"):
+             self.model.classifier = Classifier(cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, "fc"):
+             self.model.fc = Classifier(cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, "head"):
+             self.model.head = Classifier(cls_num, self.output_size, linear_output)
+             return
+
+         self.model.heads.head = Classifier(cls_num, self.output_size, linear_output)
+
+     def forward(self, x):
+         if torch.cuda.is_available():
+             x = x.cuda()
+             self.model = self.model.cuda()
+
+         if self.m_type == "googlenet" and self.training:
+             return self.model(x)[0]
+         else:
+             return self.model(x)
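
Note on the classifier head in model.py: Classifier() sizes its hidden layers by geometric interpolation between the backbone's feature width and the class count, via q = (output_size / cls_num) ** 0.25. With EvalNet's class defaults (output_size=512, cls_num=4; the actual width is replaced per backbone by _set_outsize), the head narrows 512 -> 144 -> 43 -> 13 -> 4, as this small check reproduces:

# Reproduces the layer widths computed inside Classifier() in model.py
# for the default feature width (512) and class count (4).
cls_num, output_size = 4, 512
q = (1.0 * output_size / cls_num) ** 0.25  # ~3.36
l1 = int(q * cls_num)  # 13
l2 = int(q * l1)       # 43
l3 = int(q * l2)       # 144
print(f"{output_size} -> {l3} -> {l2} -> {l1} -> {cls_num}")  # 512 -> 144 -> 43 -> 13 -> 4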