Spaces:
Runtime error
Runtime error
root
commited on
Commit
·
cd07cf0
1
Parent(s):
49bb3b0
initial commit
Browse files- configs/hico_train.sh +4 -4
- hotr/engine/trainer.py +1 -1
configs/hico_train.sh
CHANGED
@@ -14,18 +14,18 @@ python -u main.py \
|
|
14 |
--pretrained_dec \
|
15 |
--use_consis \
|
16 |
--share_dec_param \
|
17 |
-
--epochs
|
18 |
-
--lr_drop
|
19 |
--lr 1e-4 \
|
20 |
--lr_backbone 1e-5 \
|
21 |
-
--ramp_up_epoch
|
22 |
--path_id 0 \
|
23 |
--num_hoi_queries 16 \
|
24 |
--set_cost_idx 20 \
|
25 |
--hoi_idx_loss_coef 1 \
|
26 |
--hoi_act_loss_coef 10 \
|
27 |
--backbone resnet50 \
|
28 |
-
--hoi_consistency_loss_coef 0.
|
29 |
--hoi_idx_consistency_loss_coef 1 \
|
30 |
--hoi_act_consistency_loss_coef 2 \
|
31 |
--hoi_eos_coef 0.1 \
|
|
|
14 |
--pretrained_dec \
|
15 |
--use_consis \
|
16 |
--share_dec_param \
|
17 |
+
--epochs 90 \
|
18 |
+
--lr_drop 60 \
|
19 |
--lr 1e-4 \
|
20 |
--lr_backbone 1e-5 \
|
21 |
+
--ramp_up_epoch 30 \
|
22 |
--path_id 0 \
|
23 |
--num_hoi_queries 16 \
|
24 |
--set_cost_idx 20 \
|
25 |
--hoi_idx_loss_coef 1 \
|
26 |
--hoi_act_loss_coef 10 \
|
27 |
--backbone resnet50 \
|
28 |
+
--hoi_consistency_loss_coef 0.1 \
|
29 |
--hoi_idx_consistency_loss_coef 1 \
|
30 |
--hoi_act_consistency_loss_coef 2 \
|
31 |
--hoi_eos_coef 0.1 \
|
hotr/engine/trainer.py
CHANGED
@@ -29,7 +29,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
|
|
29 |
consis_coef=sigmoid_rampup(epoch,ramp_up_epoch,max_consis_coef)
|
30 |
else:
|
31 |
consis_coef=cosine_rampdown(epoch-rampdown_epoch,max_epoch-rampdown_epoch,max_consis_coef)
|
32 |
-
|
33 |
print(f"\n>>> Epoch #{(epoch+1)}")
|
34 |
for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
|
35 |
samples = samples.to(device)
|
|
|
29 |
consis_coef=sigmoid_rampup(epoch,ramp_up_epoch,max_consis_coef)
|
30 |
else:
|
31 |
consis_coef=cosine_rampdown(epoch-rampdown_epoch,max_epoch-rampdown_epoch,max_consis_coef)
|
32 |
+
|
33 |
print(f"\n>>> Epoch #{(epoch+1)}")
|
34 |
for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
|
35 |
samples = samples.to(device)
|