Spaces:
Runtime error
Fabrice-TIERCELIN committed
Commit 84aa88f
1 Parent(s): 2e5b5b2
Delete clipseg/training.py
clipseg/training.py +0 -266
clipseg/training.py
DELETED
@@ -1,266 +0,0 @@
import torch
import inspect
import json
import yaml
import math
import os
import sys

from general_utils import log

import numpy as np
from functools import partial
from os.path import expanduser, join, isfile, basename

from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LambdaLR
from contextlib import nullcontext
from torch.utils.data import DataLoader

from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args


def cosine_warmup_lr(i, warmup=10, max_iter=90):
    """ Cosine LR with Warmup """
    if i < warmup:
        return (i + 1) / (warmup + 1)
    else:
        return 0.5 + 0.5 * math.cos(math.pi * ((i - warmup) / (max_iter - warmup)))


def validate(model, dataset, config):
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)

    metric_class, use_metric = config.val_metric_class, config.use_val_metric
    loss_fn = get_attribute(config.loss)

    model.eval()
    model.cuda()

    if metric_class is not None:
        metric = get_attribute(metric_class)()

    with torch.no_grad():

        i, losses = 0, []
        for data_x, data_y in data_loader:

            data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
            data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]

            prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
            pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True)

            if metric_class is not None:
                metric.add([pred], data_y)

            # pred = model(data_x[0], prompts)
            # loss = loss_fn(pred[0], data_y[0])
            loss = loss_fn(pred, data_y[0])
            losses += [float(loss)]

            i += 1

            if config.val_max_iterations is not None and i > config.val_max_iterations:
                break

    if use_metric is None:
        return np.mean(losses), {}, False
    else:
        metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
        return np.mean(losses), metric_scores, True


def main():

    config = training_config_from_cli_args()

    val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')

    model_cls = get_attribute(config.model)
    _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
    model = model_cls(**model_args).cuda()

    dataset_cls = get_attribute(config.dataset)
    _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)

    dataset = dataset_cls(**dataset_args)

    log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')

    if val_interval is not None:
        dataset_val_args = {k[4:]: v for k, v in config.items() if k.startswith('val_') and k != 'val_interval'}
        _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
        print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})

        dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})

    # optimizer
    opt_cls = get_attribute(config.optimizer)
    if config.optimize == 'torch.optim.SGD':
        opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
    else:
        opt_args = {}
    opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)

    if config.lr_scheduler == 'cosine':
        assert config.T_max is not None and config.eta_min is not None
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
    elif config.lr_scheduler == 'warmup_cosine':
        lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
    else:
        lr_scheduler = None

    batch_size, max_iterations = config.batch_size, config.max_iterations

    loss_fn = get_attribute(config.loss)

    if config.amp:
        log.info('Using AMP')
        autocast_fn = autocast
        scaler = GradScaler()
    else:
        autocast_fn, scaler = nullcontext, None

    save_only_trainable = True
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)

    # disable config when hyperparam. opt. to avoid writing logs.
    tracker_config = config if not config.hyperparameter_optimization else None

    with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:

        i = 0
        while True:
            for data_x, data_y in data_loader:

                # between caption and output feature.
                # 1. Sample random captions
                # 2. Check alignment with CLIP

                # randomly mix text and visual support conditionals
                if config.mix:

                    assert config.mask.startswith('text_and')

                    with autocast_fn():
                        # data_x[1] = text label
                        prompts = model.sample_prompts(data_x[1])

                        # model.clip_model()

                        text_cond = model.compute_conditional(prompts)
                        if model.__class__.__name__ == 'CLIPDensePredTMasked':
                            # when mask=='separate'
                            visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
                        else:
                            # data_x[2] = visual prompt
                            visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())

                    max_txt = config.mix_text_max if config.mix_text_max is not None else 1
                    batch_size = text_cond.shape[0]

                    # sample weights for each element in batch
                    text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None]
                    text_weights = text_weights.cuda()

                    if dataset.__class__.__name__ == 'PhraseCut':
                        # give full weight to text where support_image is invalid
                        visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
                        text_weights = torch.max(text_weights[:, 0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)

                    cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)

                else:
                    # no mix

                    if model.__class__.__name__ == 'CLIPDensePredTMasked':
                        # compute conditional vector using CLIP masking
                        with autocast_fn():
                            assert config.mask == 'separate'
                            cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
                    else:
                        cond = data_x[1]
                        if isinstance(cond, torch.Tensor):
                            cond = cond.cuda()

                with autocast_fn():
                    visual_q = None

                    pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True)

                    loss = loss_fn(pred, data_y[0].cuda())

                    if torch.isnan(loss) or torch.isinf(loss):
                        # skip if loss is nan
                        log.warning('Training stopped due to inf/nan loss.')
                        sys.exit(-1)

                    extra_loss = 0
                    loss += extra_loss

                opt.zero_grad()

                if scaler is None:
                    loss.backward()
                    opt.step()
                else:
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()

                if lr_scheduler is not None:
                    lr_scheduler.step()
                    if i % 2000 == 0:
                        current_lr = [g['lr'] for g in opt.param_groups][0]
                        log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')

                logger.iter(i=i, loss=loss)
                i += 1

                if i >= max_iterations:

                    if not isfile(join(logger.base_path, 'weights.pth')):
                        # only write if no weights were already written
                        logger.save_weights(only_trainable=save_only_trainable)

                    sys.exit(0)

                if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
                    logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')

                if val_interval is not None and i % val_interval == val_interval - 1:

                    val_loss, val_scores, maximize = validate(model, dataset_val, config)

                    if len(val_scores) > 0:

                        score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())

                        if maximize and val_scores[config.use_val_metric] > best_val_score:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_score = val_scores[config.use_val_metric]

                        elif not maximize and val_scores[config.use_val_metric] < best_val_score:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_score = val_scores[config.use_val_metric]

                    else:
                        score_str = ''
                        # if no score is used, fall back to loss
                        if val_loss < best_val_loss:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_loss = val_loss

                    log.info(f'Validation loss: {val_loss}' + score_str)
                    logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
                    model.train()

            print('epoch complete')


if __name__ == '__main__':
    main()
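For reference, the deleted script's warmup_cosine option wraps cosine_warmup_lr in torch's LambdaLR: a linear ramp over the first warmup iterations, then cosine decay to zero at max_iter. Below is a minimal, self-contained sketch (not part of the commit; the step values and the warmup/max_iter defaults are purely illustrative) of the learning-rate multiplier that schedule produces.

# Illustrative sketch only (not part of the deleted file): reproduces the
# cosine_warmup_lr schedule to show the multiplier LambdaLR applies to the base LR.
import math

def cosine_warmup_lr(i, warmup=10, max_iter=90):
    """Linear warmup for `warmup` steps, then cosine decay to zero at `max_iter`."""
    if i < warmup:
        return (i + 1) / (warmup + 1)
    return 0.5 + 0.5 * math.cos(math.pi * (i - warmup) / (max_iter - warmup))

# Assumed example steps; prints roughly 0.091, 0.545, 1.000, 0.500, 0.000.
for step in (0, 5, 10, 50, 90):
    print(f'step {step:3d}: lr multiplier {cosine_warmup_lr(step):.3f}')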