litagin commited on
Commit
d485dcb
·
verified ·
1 Parent(s): f8659f7

Upload 11 files

Browse files
Files changed (10) hide show
  1. .gitignore +3 -0
  2. README.md +1 -1
  3. app.py +51 -0
  4. ckpt/config.json +13 -0
  5. ckpt/model_final.pth +3 -0
  6. losses.py +176 -0
  7. models.py +88 -0
  8. requirements.txt +4 -0
  9. train.py +243 -0
  10. utils.py +29 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ flagged/
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Japanese Ero Voice Classifier
3
- emoji: 🌍
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
 
1
  ---
2
  title: Japanese Ero Voice Classifier
3
+ emoji: 🥰
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import gradio as gr
5
+ import torch
6
+
7
+ from models import AudioClassifier
8
+ from utils import logger
9
+
10
+
11
+ ckpt_dir = Path("ckpt/")
12
+ config_path = ckpt_dir / "config.json"
13
+ assert config_path.exists(), f"config.json not found in {ckpt_dir}"
14
+ config = json.loads((ckpt_dir / "config.json").read_text())
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ model = AudioClassifier(device=device, **config["model"]).to(device)
17
+
18
+ # Latest checkpoint
19
+ if (ckpt_dir / "model_final.pth").exists():
20
+ ckpt = ckpt_dir / "model_final.pth"
21
+ else:
22
+ ckpt = sorted(ckpt_dir.glob("*.pth"))[-1]
23
+ logger.info(f"Loading {ckpt}...")
24
+ model.load_state_dict(torch.load(ckpt))
25
+
26
+
27
+ def classify_audio(audio_file: str):
28
+ logger.info(f"Classifying {audio_file}...")
29
+ output = model.infer_from_file(audio_file)
30
+ logger.success(f"Predicted: {output}")
31
+ return output
32
+
33
+
34
+ desc = """
35
+ # NSFW音声分類器
36
+
37
+ 出力は以下の3つのクラスの確率です。
38
+ - usual: 通常の音声
39
+ - aegi: 喘ぎ声
40
+ - chupa: チュパ音(フェラやキス音声)
41
+ """
42
+
43
+
44
+ with gr.Interface(
45
+ fn=classify_audio,
46
+ inputs=gr.Audio(label="Input audio", type="filepath"),
47
+ outputs=gr.Text(label="Classification"),
48
+ description=desc,
49
+ allow_flagging="never",
50
+ ) as iface:
51
+ iface.launch()
ckpt/config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "label2id": {
4
+ "usual": 0,
5
+ "aegi": 1,
6
+ "chupa": 2
7
+ },
8
+ "num_hidden_layers": 2,
9
+ "hidden_dim": 128
10
+ },
11
+ "lr": 0.001,
12
+ "lr_decay": 0.996
13
+ }
ckpt/model_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67ffab6e224d9c7f9acbeab40892cfda200a88c9dc2ee2714621bc90eed7a4d5
3
+ size 279357
losses.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class AsymmetricLoss(nn.Module):
6
+ def __init__(
7
+ self,
8
+ gamma_neg=4,
9
+ gamma_pos=1,
10
+ clip=0.05,
11
+ eps=1e-8,
12
+ disable_torch_grad_focal_loss=True,
13
+ ):
14
+ super(AsymmetricLoss, self).__init__()
15
+
16
+ self.gamma_neg = gamma_neg
17
+ self.gamma_pos = gamma_pos
18
+ self.clip = clip
19
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
20
+ self.eps = eps
21
+
22
+ def forward(self, x, y):
23
+ """ "
24
+ Parameters
25
+ ----------
26
+ x: input logits
27
+ y: targets (multi-label binarized vector)
28
+ """
29
+
30
+ # Calculating Probabilities
31
+ x_sigmoid = torch.sigmoid(x)
32
+ xs_pos = x_sigmoid
33
+ xs_neg = 1 - x_sigmoid
34
+
35
+ # Asymmetric Clipping
36
+ if self.clip is not None and self.clip > 0:
37
+ xs_neg = (xs_neg + self.clip).clamp(max=1)
38
+
39
+ # Basic CE calculation
40
+ los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
41
+ los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
42
+ loss = los_pos + los_neg
43
+
44
+ # Asymmetric Focusing
45
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
46
+ if self.disable_torch_grad_focal_loss:
47
+ torch.set_grad_enabled(False)
48
+ pt0 = xs_pos * y
49
+ pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
50
+ pt = pt0 + pt1
51
+ one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
52
+ one_sided_w = torch.pow(1 - pt, one_sided_gamma)
53
+ if self.disable_torch_grad_focal_loss:
54
+ torch.set_grad_enabled(True)
55
+ loss *= one_sided_w
56
+
57
+ return -loss.sum()
58
+
59
+
60
+ class AsymmetricLossOptimized(nn.Module):
61
+ """Notice - optimized version, minimizes memory allocation and gpu uploading,
62
+ favors inplace operations"""
63
+
64
+ def __init__(
65
+ self,
66
+ gamma_neg=4,
67
+ gamma_pos=1,
68
+ clip=0.05,
69
+ eps=1e-8,
70
+ disable_torch_grad_focal_loss=False,
71
+ ):
72
+ super(AsymmetricLossOptimized, self).__init__()
73
+
74
+ self.gamma_neg = gamma_neg
75
+ self.gamma_pos = gamma_pos
76
+ self.clip = clip
77
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
78
+ self.eps = eps
79
+
80
+ # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
81
+ self.targets = self.anti_targets = self.xs_pos = self.xs_neg = (
82
+ self.asymmetric_w
83
+ ) = self.loss = None
84
+
85
+ def forward(self, x, y):
86
+ """ "
87
+ Parameters
88
+ ----------
89
+ x: input logits
90
+ y: targets (multi-label binarized vector)
91
+ """
92
+
93
+ self.targets = y
94
+ self.anti_targets = 1 - y
95
+
96
+ # Calculating Probabilities
97
+ self.xs_pos = torch.sigmoid(x)
98
+ self.xs_neg = 1.0 - self.xs_pos
99
+
100
+ # Asymmetric Clipping
101
+ if self.clip is not None and self.clip > 0:
102
+ self.xs_neg.add_(self.clip).clamp_(max=1)
103
+
104
+ # Basic CE calculation
105
+ self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
106
+ self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
107
+
108
+ # Asymmetric Focusing
109
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
110
+ if self.disable_torch_grad_focal_loss:
111
+ torch.set_grad_enabled(False)
112
+ self.xs_pos = self.xs_pos * self.targets
113
+ self.xs_neg = self.xs_neg * self.anti_targets
114
+ self.asymmetric_w = torch.pow(
115
+ 1 - self.xs_pos - self.xs_neg,
116
+ self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets,
117
+ )
118
+ if self.disable_torch_grad_focal_loss:
119
+ torch.set_grad_enabled(True)
120
+ self.loss *= self.asymmetric_w
121
+
122
+ return -self.loss.sum()
123
+
124
+
125
+ class ASLSingleLabel(nn.Module):
126
+ """
127
+ This loss is intended for single-label classification problems
128
+ """
129
+
130
+ def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction="mean"):
131
+ super(ASLSingleLabel, self).__init__()
132
+
133
+ self.eps = eps
134
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
135
+ self.targets_classes = []
136
+ self.gamma_pos = gamma_pos
137
+ self.gamma_neg = gamma_neg
138
+ self.reduction = reduction
139
+
140
+ def forward(self, inputs, target):
141
+ """
142
+ "input" dimensions: - (batch_size,number_classes)
143
+ "target" dimensions: - (batch_size)
144
+ """
145
+ num_classes = inputs.size()[-1]
146
+ log_preds = self.logsoftmax(inputs)
147
+ self.targets_classes = torch.zeros_like(inputs).scatter_(
148
+ 1, target.long().unsqueeze(1), 1
149
+ )
150
+
151
+ # ASL weights
152
+ targets = self.targets_classes
153
+ anti_targets = 1 - targets
154
+ xs_pos = torch.exp(log_preds)
155
+ xs_neg = 1 - xs_pos
156
+ xs_pos = xs_pos * targets
157
+ xs_neg = xs_neg * anti_targets
158
+ asymmetric_w = torch.pow(
159
+ 1 - xs_pos - xs_neg,
160
+ self.gamma_pos * targets + self.gamma_neg * anti_targets,
161
+ )
162
+ log_preds = log_preds * asymmetric_w
163
+
164
+ if self.eps > 0: # label smoothing
165
+ self.targets_classes = self.targets_classes.mul(1 - self.eps).add(
166
+ self.eps / num_classes
167
+ )
168
+
169
+ # loss calculation
170
+ loss = -self.targets_classes.mul(log_preds)
171
+
172
+ loss = loss.sum(dim=-1)
173
+ if self.reduction == "mean":
174
+ loss = loss.mean()
175
+
176
+ return loss
models.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ # モデルの定義
6
+ class AudioClassifier(nn.Module):
7
+ def __init__(
8
+ self,
9
+ label2id: dict,
10
+ feature_dim=256,
11
+ hidden_dim=256,
12
+ device="cpu",
13
+ dropout_rate=0.5,
14
+ num_hidden_layers=2,
15
+ ):
16
+ super(AudioClassifier, self).__init__()
17
+ self.num_classes = len(label2id)
18
+ self.device = device
19
+ self.label2id = label2id
20
+ self.id2label = {v: k for k, v in self.label2id.items()}
21
+ # 最初の線形層と活性化層を追加
22
+ self.fc1 = nn.Sequential(
23
+ nn.Linear(feature_dim, hidden_dim),
24
+ nn.BatchNorm1d(hidden_dim),
25
+ nn.Mish(),
26
+ nn.Dropout(dropout_rate),
27
+ )
28
+ # 隠れ層の追加
29
+ self.hidden_layers = nn.ModuleList()
30
+ for _ in range(num_hidden_layers):
31
+ layer = nn.Sequential(
32
+ nn.Linear(hidden_dim, hidden_dim),
33
+ nn.BatchNorm1d(hidden_dim),
34
+ nn.Mish(),
35
+ nn.Dropout(dropout_rate),
36
+ )
37
+ self.hidden_layers.append(layer)
38
+ # 最後の層(クラス分類用)
39
+ self.fc_last = nn.Linear(hidden_dim, self.num_classes)
40
+
41
+ def forward(self, x):
42
+ # 最初の層を通過
43
+ x = self.fc1(x)
44
+
45
+ # 隠れ層を順に通過
46
+ for layer in self.hidden_layers:
47
+ x = layer(x)
48
+
49
+ # 最後の分類層
50
+ x = self.fc_last(x)
51
+ return x
52
+
53
+ def infer_from_features(self, features):
54
+ # 特徴量をテンソルに変換
55
+ features = (
56
+ torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device)
57
+ )
58
+
59
+ # モデルを評価モードに設定
60
+ self.eval()
61
+
62
+ # モデルの出力を取得
63
+ with torch.no_grad():
64
+ output = self.forward(features)
65
+
66
+ # ソフトマックス関数を適用して確率を計算
67
+ probs = torch.softmax(output, dim=1)
68
+
69
+ # ラベルごとの確率を計算して大きい順に並べ替えて返す
70
+ probs, indices = torch.sort(probs, descending=True)
71
+ probs = probs.cpu().numpy().squeeze()
72
+ indices = indices.cpu().numpy().squeeze()
73
+ return [(self.id2label[i], p) for i, p in zip(indices, probs)]
74
+
75
+ def infer_from_file(self, file_path):
76
+ feature = extract_features(file_path, device=self.device)
77
+ return self.infer_from_features(feature)
78
+
79
+
80
+ from pyannote.audio import Inference, Model
81
+
82
+ emb_model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
83
+ inference = Inference(emb_model, window="whole")
84
+
85
+
86
+ def extract_features(file_path, device="cpu"):
87
+ inference.to(torch.device(device))
88
+ return inference(file_path)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ loguru
3
+ pyannote.audio
4
+ torch
train.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+
12
+ # import torch_optimizer as optim
13
+ import transformers
14
+ from sklearn.metrics import (
15
+ accuracy_score,
16
+ classification_report,
17
+ f1_score,
18
+ precision_score,
19
+ recall_score,
20
+ )
21
+ from torch.optim.lr_scheduler import (
22
+ CosineAnnealingLR,
23
+ CosineAnnealingWarmRestarts,
24
+ ExponentialLR,
25
+ )
26
+ from torch.utils.data import DataLoader, Dataset
27
+ from torch.utils.tensorboard import SummaryWriter
28
+ from tqdm import tqdm
29
+
30
+ from models import AudioClassifier, extract_features
31
+ from losses import AsymmetricLoss, ASLSingleLabel
32
+
33
+ torch.manual_seed(42)
34
+
35
+ label2id = {
36
+ "usual": 0,
37
+ "aegi": 1,
38
+ "chupa": 2,
39
+ # "cry": 3,
40
+ # "laugh": 4,
41
+ # "silent": 5,
42
+ # "unusual": 6,
43
+ }
44
+ id2label = {v: k for k, v in label2id.items()}
45
+
46
+
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument("--exp_dir", type=str, default="data")
49
+ parser.add_argument("--ckpt_dir", type=str, required=True)
50
+ parser.add_argument("--device", type=str, default="cuda")
51
+ parser.add_argument("--epochs", type=int, default=1000)
52
+ parser.add_argument("--save_every", type=int, default=100)
53
+
54
+ args = parser.parse_args()
55
+ device = args.device
56
+ if not torch.cuda.is_available():
57
+ print("No GPU detected. Using CPU.")
58
+ device = "cpu"
59
+ print(f"Using {device} for training.")
60
+
61
+
62
+ # データセットの定義
63
+ class AudioDataset(Dataset):
64
+ def __init__(self, file_paths, labels, features):
65
+ self.file_paths = file_paths
66
+ self.labels = labels
67
+ self.features = features
68
+
69
+ def __len__(self):
70
+ return len(self.file_paths)
71
+
72
+ def __getitem__(self, idx):
73
+ return self.features[idx], self.labels[idx]
74
+
75
+
76
+ def prepare_dataset(directory):
77
+ file_paths = list(Path(directory).rglob("*.npy"))
78
+ if len(file_paths) == 0:
79
+ return [], [], []
80
+ # file_paths = [f for f in file_paths if f.parent.name in label2id]
81
+
82
+ def process(file_path: Path):
83
+ npy_feature = np.load(file_path)
84
+ id = int(label2id[file_path.parent.name])
85
+ return (
86
+ file_path,
87
+ torch.tensor(id, dtype=torch.long).to(device),
88
+ torch.tensor(npy_feature, dtype=torch.float32).to(device),
89
+ )
90
+
91
+ with ThreadPoolExecutor(max_workers=10) as executor:
92
+ results = list(tqdm(executor.map(process, file_paths), total=len(file_paths)))
93
+
94
+ file_paths, labels, features = zip(*results)
95
+
96
+ return file_paths, labels, features
97
+
98
+
99
+ print("Preparing dataset...")
100
+
101
+ exp_dir = Path(args.exp_dir)
102
+ train_file_paths, train_labels, train_feats = prepare_dataset(exp_dir / "train")
103
+ val_file_paths, val_labels, val_feats = prepare_dataset(exp_dir / "val")
104
+
105
+ print(f"Train: {len(train_file_paths)}, Val: {len(val_file_paths)}")
106
+
107
+ # データセットとデータローダーの準備
108
+ train_dataset = AudioDataset(train_file_paths, train_labels, train_feats)
109
+ print("Train dataset prepared.")
110
+ train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
111
+ print("Train loader prepared.")
112
+ if len(val_file_paths) == 0:
113
+ val_dataset = None
114
+ val_loader = None
115
+ print("No validation dataset found.")
116
+ else:
117
+ val_dataset = AudioDataset(val_file_paths, val_labels, val_feats)
118
+ print("Val dataset prepared.")
119
+ val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
120
+ print("Val loader prepared.")
121
+
122
+
123
+ # モデル、損失関数、最適化アルゴリズムの設定
124
+ config = {
125
+ "model": {
126
+ "label2id": label2id,
127
+ "num_hidden_layers": 2,
128
+ "hidden_dim": 128,
129
+ },
130
+ "lr": 1e-3,
131
+ "lr_decay": 0.996,
132
+ }
133
+ model = AudioClassifier(device="cuda", **config["model"]).to(device)
134
+ model.to(device)
135
+ # criterion = nn.CrossEntropyLoss()
136
+ criterion = ASLSingleLabel(gamma_pos=1, gamma_neg=4)
137
+ optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-2)
138
+ scheduler = ExponentialLR(optimizer, gamma=config["lr_decay"])
139
+ # scheduler = transformers.optimization.AdafactorSchedule(optimizer)
140
+ num_epochs = args.epochs
141
+ # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
142
+
143
+ print("Start training...")
144
+ current_time = datetime.now().strftime("%b%d_%H-%M-%S")
145
+ ckpt_dir = Path(args.ckpt_dir) / current_time
146
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
147
+ # Save config
148
+ with open(ckpt_dir / "config.json", "w", encoding="utf-8") as f:
149
+ json.dump(config, f, indent=4)
150
+ # 訓練ループ
151
+ save_every = args.save_every
152
+ val_interval = 1
153
+ eval_interval = 1
154
+
155
+ writer = SummaryWriter(ckpt_dir / "logs")
156
+ for epoch in tqdm(range(1, num_epochs + 1)):
157
+ train_loss = 0.0
158
+ model.train() # 訓練モードに設定
159
+ train_labels = []
160
+ train_preds = []
161
+ for inputs, labels in train_loader:
162
+ inputs, labels = inputs.to(device), labels.to(device)
163
+
164
+ # 順伝播、損失の計算、逆伝播、パラメータ更新
165
+ optimizer.zero_grad()
166
+ outputs = model(inputs)
167
+ loss = criterion(outputs.squeeze(), labels)
168
+ loss.backward()
169
+ optimizer.step()
170
+ train_loss += loss.item()
171
+
172
+ # 評価指標の計算
173
+ if epoch % eval_interval == 0:
174
+ with torch.no_grad():
175
+ # 最も高い確率を持つクラスのインデックスを取得
176
+ _, predictions = torch.max(outputs, 1)
177
+
178
+ # 実際のラベルと予測値をリストに追加
179
+ train_labels.extend(labels.cpu().numpy())
180
+ train_preds.extend(predictions.cpu().numpy())
181
+
182
+ scheduler.step()
183
+ if epoch % eval_interval == 0:
184
+ # 訓練データに対する評価指標の計算
185
+ accuracy = accuracy_score(train_labels, train_preds)
186
+ precision = precision_score(train_labels, train_preds, average="macro")
187
+ recall = recall_score(train_labels, train_preds, average="macro")
188
+ f1 = f1_score(train_labels, train_preds, average="macro")
189
+ report = classification_report(
190
+ train_labels, train_preds, target_names=list(label2id.keys())
191
+ )
192
+
193
+ writer.add_scalar("train/Accuracy", accuracy, epoch)
194
+ writer.add_scalar("train/Precision", precision, epoch)
195
+ writer.add_scalar("train/Recall", recall, epoch)
196
+ writer.add_scalar("train/F1", f1, epoch)
197
+
198
+ writer.add_scalar("Loss/train", train_loss / len(train_loader), epoch)
199
+ writer.add_scalar("Learning Rate", optimizer.param_groups[0]["lr"], epoch)
200
+
201
+ if epoch % save_every == 0:
202
+ torch.save(model.state_dict(), ckpt_dir / f"model_{epoch}.pth")
203
+
204
+ if epoch % val_interval != 0 or val_loader is None:
205
+ tqdm.write(f"loss: {train_loss / len(train_loader):4f}\n{report}")
206
+ continue
207
+ model.eval() # 評価モードに設定
208
+ val_labels = []
209
+ val_preds = []
210
+ val_loss = 0.0
211
+ with torch.no_grad():
212
+ for inputs, labels in val_loader:
213
+ inputs, labels = inputs.to(device), labels.to(device)
214
+ outputs = model(inputs)
215
+ # 最も高い確率を持つクラスのインデックスを取得
216
+ _, predictions = torch.max(outputs, 1)
217
+ val_labels.extend(labels.cpu().numpy())
218
+ val_preds.extend(predictions.cpu().numpy())
219
+ loss = criterion(outputs.squeeze(), labels)
220
+ val_loss += loss.item()
221
+
222
+ # 評価指標の計算
223
+ accuracy = accuracy_score(val_labels, val_preds)
224
+ precision = precision_score(val_labels, val_preds, average="macro")
225
+ recall = recall_score(val_labels, val_preds, average="macro")
226
+ f1 = f1_score(val_labels, val_preds, average="macro")
227
+ report = classification_report(
228
+ val_labels, val_preds, target_names=list(label2id.keys())
229
+ )
230
+
231
+ writer.add_scalar("Loss/val", val_loss / len(val_loader), epoch)
232
+ writer.add_scalar("val/Accuracy", accuracy, epoch)
233
+ writer.add_scalar("val/Precision", precision, epoch)
234
+ writer.add_scalar("val/Recall", recall, epoch)
235
+ writer.add_scalar("val/F1", f1, epoch)
236
+
237
+ tqdm.write(
238
+ f"loss: {train_loss / len(train_loader):4f}, val loss: {val_loss / len(val_loader):4f}, "
239
+ f"acc: {accuracy:4f}, f1: {f1:4f}, prec: {precision:4f}, recall: {recall:4f}\n{report}"
240
+ )
241
+ # tqdm.write(report)
242
+ # Save
243
+ torch.save(model.state_dict(), ckpt_dir / "model_final.pth")
utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ import wave
4
+ from pydub import AudioSegment
5
+ import loguru
6
+
7
+
8
+ def is_audio_file(file: Path):
9
+ return file.suffix.lower() in [".wav", ".mp3", ".ogg"]
10
+
11
+
12
+ def get_audio_duration_ms(file_path):
13
+ try:
14
+ with wave.open(str(file_path), "r") as wav_file:
15
+ return wav_file.getnframes() / wav_file.getframerate() * 1000
16
+ except wave.Error as e:
17
+ audio = AudioSegment.from_file(file_path)
18
+ return len(audio)
19
+ except Exception as e:
20
+ raise e
21
+
22
+
23
+ logger = loguru.logger
24
+ logger.remove()
25
+
26
+ log_format = (
27
+ "<g>{time:MM-DD HH:mm:ss}</g> |<lvl>{level:^8}</lvl>| {file}:{line} | {message}"
28
+ )
29
+ logger.add(sys.stdout, format=log_format)