Fabrice-TIERCELIN commited on
Commit
84aa88f
1 Parent(s): 2e5b5b2

Delete clipseg/training.py

Browse files
Files changed (1) hide show
  1. clipseg/training.py +0 -266
clipseg/training.py DELETED
@@ -1,266 +0,0 @@
1
- import torch
2
- import inspect
3
- import json
4
- import yaml
5
- import math
6
- import os
7
- import sys
8
-
9
- from general_utils import log
10
-
11
- import numpy as np
12
- from functools import partial
13
- from os.path import expanduser, join, isfile, basename
14
-
15
- from torch.cuda.amp import autocast, GradScaler
16
- from torch.optim.lr_scheduler import LambdaLR
17
- from contextlib import nullcontext
18
- from torch.utils.data import DataLoader
19
-
20
- from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args
21
-
22
-
23
- def cosine_warmup_lr(i, warmup=10, max_iter=90):
24
- """ Cosine LR with Warmup """
25
- if i < warmup:
26
- return (i+1)/(warmup+1)
27
- else:
28
- return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup))))
29
-
30
-
31
- def validate(model, dataset, config):
32
- data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
33
-
34
- metric_class, use_metric = config.val_metric_class, config.use_val_metric
35
- loss_fn = get_attribute(config.loss)
36
-
37
- model.eval()
38
- model.cuda()
39
-
40
- if metric_class is not None:
41
- metric = get_attribute(metric_class)()
42
-
43
- with torch.no_grad():
44
-
45
- i, losses = 0, []
46
- for data_x, data_y in data_loader:
47
-
48
- data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
49
- data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
50
-
51
- prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
52
- pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True)
53
-
54
- if metric_class is not None:
55
- metric.add([pred], data_y)
56
-
57
- # pred = model(data_x[0], prompts)
58
- # loss = loss_fn(pred[0], data_y[0])
59
- loss = loss_fn(pred, data_y[0])
60
- losses += [float(loss)]
61
-
62
- i += 1
63
-
64
- if config.val_max_iterations is not None and i > config.val_max_iterations:
65
- break
66
-
67
- if use_metric is None:
68
- return np.mean(losses), {}, False
69
- else:
70
- metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
71
- return np.mean(losses), metric_scores, True
72
-
73
-
74
- def main():
75
-
76
- config = training_config_from_cli_args()
77
-
78
- val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')
79
-
80
- model_cls = get_attribute(config.model)
81
- _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
82
- model = model_cls(**model_args).cuda()
83
-
84
- dataset_cls = get_attribute(config.dataset)
85
- _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
86
-
87
- dataset = dataset_cls(**dataset_args)
88
-
89
- log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')
90
-
91
- if val_interval is not None:
92
- dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'}
93
- _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
94
- print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
95
-
96
- dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
97
-
98
- # optimizer
99
- opt_cls = get_attribute(config.optimizer)
100
- if config.optimize == 'torch.optim.SGD':
101
- opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
102
- else:
103
- opt_args = {}
104
- opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)
105
-
106
- if config.lr_scheduler == 'cosine':
107
- assert config.T_max is not None and config.eta_min is not None
108
- lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
109
- elif config.lr_scheduler == 'warmup_cosine':
110
- lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
111
- else:
112
- lr_scheduler = None
113
-
114
- batch_size, max_iterations = config.batch_size, config.max_iterations
115
-
116
- loss_fn = get_attribute(config.loss)
117
-
118
- if config.amp:
119
- log.info('Using AMP')
120
- autocast_fn = autocast
121
- scaler = GradScaler()
122
- else:
123
- autocast_fn, scaler = nullcontext, None
124
-
125
-
126
- save_only_trainable = True
127
- data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
128
-
129
- # disable config when hyperparam. opt. to avoid writing logs.
130
- tracker_config = config if not config.hyperparameter_optimization else None
131
-
132
- with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:
133
-
134
- i = 0
135
- while True:
136
- for data_x, data_y in data_loader:
137
-
138
- # between caption and output feature.
139
- # 1. Sample random captions
140
- # 2. Check alignment with CLIP
141
-
142
- # randomly mix text and visual support conditionals
143
- if config.mix:
144
-
145
- assert config.mask.startswith('text_and')
146
-
147
- with autocast_fn():
148
- # data_x[1] = text label
149
- prompts = model.sample_prompts(data_x[1])
150
-
151
- # model.clip_model()
152
-
153
- text_cond = model.compute_conditional(prompts)
154
- if model.__class__.__name__ == 'CLIPDensePredTMasked':
155
- # when mask=='separate'
156
- visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
157
- else:
158
- # data_x[2] = visual prompt
159
- visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())
160
-
161
- max_txt = config.mix_text_max if config.mix_text_max is not None else 1
162
- batch_size = text_cond.shape[0]
163
-
164
- # sample weights for each element in batch
165
- text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None]
166
- text_weights = text_weights.cuda()
167
-
168
- if dataset.__class__.__name__ == 'PhraseCut':
169
- # give full weight to text where support_image is invalid
170
- visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
171
- text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)
172
-
173
- cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)
174
-
175
- else:
176
- # no mix
177
-
178
- if model.__class__.__name__ == 'CLIPDensePredTMasked':
179
- # compute conditional vector using CLIP masking
180
- with autocast_fn():
181
- assert config.mask == 'separate'
182
- cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
183
- else:
184
- cond = data_x[1]
185
- if isinstance(cond, torch.Tensor):
186
- cond = cond.cuda()
187
-
188
- with autocast_fn():
189
- visual_q = None
190
-
191
- pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True)
192
-
193
- loss = loss_fn(pred, data_y[0].cuda())
194
-
195
- if torch.isnan(loss) or torch.isinf(loss):
196
- # skip if loss is nan
197
- log.warning('Training stopped due to inf/nan loss.')
198
- sys.exit(-1)
199
-
200
- extra_loss = 0
201
- loss += extra_loss
202
-
203
- opt.zero_grad()
204
-
205
- if scaler is None:
206
- loss.backward()
207
- opt.step()
208
- else:
209
- scaler.scale(loss).backward()
210
- scaler.step(opt)
211
- scaler.update()
212
-
213
- if lr_scheduler is not None:
214
- lr_scheduler.step()
215
- if i % 2000 == 0:
216
- current_lr = [g['lr'] for g in opt.param_groups][0]
217
- log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')
218
-
219
- logger.iter(i=i, loss=loss)
220
- i += 1
221
-
222
- if i >= max_iterations:
223
-
224
- if not isfile(join(logger.base_path, 'weights.pth')):
225
- # only write if no weights were already written
226
- logger.save_weights(only_trainable=save_only_trainable)
227
-
228
- sys.exit(0)
229
-
230
-
231
- if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
232
- logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')
233
-
234
-
235
- if val_interval is not None and i % val_interval == val_interval - 1:
236
-
237
- val_loss, val_scores, maximize = validate(model, dataset_val, config)
238
-
239
- if len(val_scores) > 0:
240
-
241
- score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())
242
-
243
- if maximize and val_scores[config.use_val_metric] > best_val_score:
244
- logger.save_weights(only_trainable=save_only_trainable)
245
- best_val_score = val_scores[config.use_val_metric]
246
-
247
- elif not maximize and val_scores[config.use_val_metric] < best_val_score:
248
- logger.save_weights(only_trainable=save_only_trainable)
249
- best_val_score = val_scores[config.use_val_metric]
250
-
251
- else:
252
- score_str = ''
253
- # if no score is used, fall back to loss
254
- if val_loss < best_val_loss:
255
- logger.save_weights(only_trainable=save_only_trainable)
256
- best_val_loss = val_loss
257
-
258
- log.info(f'Validation loss: {val_loss}' + score_str)
259
- logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
260
- model.train()
261
-
262
- print('epoch complete')
263
-
264
-
265
- if __name__ == '__main__':
266
- main()