Upload 115 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- PreTrain_MeDSLIP/configs/Pretrain_MeDSLIP.yaml +34 -0
- PreTrain_MeDSLIP/data_file/observation explanation.json +77 -0
- PreTrain_MeDSLIP/data_file/preprocessing/adj_matrix.py +891 -0
- PreTrain_MeDSLIP/data_file/preprocessing/radgraph_itemized.py +139 -0
- PreTrain_MeDSLIP/data_file/preprocessing/radgraph_parsed.py +322 -0
- PreTrain_MeDSLIP/dataset/dataset.py +310 -0
- PreTrain_MeDSLIP/dataset/randaugment.py +346 -0
- PreTrain_MeDSLIP/models/__init__.py +0 -0
- PreTrain_MeDSLIP/models/model_MeDSLIP.py +530 -0
- PreTrain_MeDSLIP/models/tokenization_bert.py +578 -0
- PreTrain_MeDSLIP/models/transformer.py +210 -0
- PreTrain_MeDSLIP/optim/__init__.py +13 -0
- PreTrain_MeDSLIP/optim/adafactor.py +206 -0
- PreTrain_MeDSLIP/optim/adahessian.py +207 -0
- PreTrain_MeDSLIP/optim/adamp.py +133 -0
- PreTrain_MeDSLIP/optim/adamw.py +131 -0
- PreTrain_MeDSLIP/optim/lookahead.py +96 -0
- PreTrain_MeDSLIP/optim/nadam.py +108 -0
- PreTrain_MeDSLIP/optim/novograd.py +90 -0
- PreTrain_MeDSLIP/optim/nvnovograd.py +132 -0
- PreTrain_MeDSLIP/optim/optim_factory.py +138 -0
- PreTrain_MeDSLIP/optim/radam.py +170 -0
- PreTrain_MeDSLIP/optim/rmsprop_tf.py +160 -0
- PreTrain_MeDSLIP/optim/sgdp.py +123 -0
- PreTrain_MeDSLIP/scheduler/__init__.py +5 -0
- PreTrain_MeDSLIP/scheduler/cosine_lr.py +136 -0
- PreTrain_MeDSLIP/scheduler/plateau_lr.py +116 -0
- PreTrain_MeDSLIP/scheduler/scheduler.py +120 -0
- PreTrain_MeDSLIP/scheduler/scheduler_factory.py +87 -0
- PreTrain_MeDSLIP/scheduler/step_lr.py +73 -0
- PreTrain_MeDSLIP/scheduler/tanh_lr.py +141 -0
- PreTrain_MeDSLIP/train_MeDSLIP.py +446 -0
- PreTrain_MeDSLIP/utils.py +277 -0
- README.md +49 -3
- Sample_Finetuning_SIIMACR/I1_classification/configs/Res_train.yaml +17 -0
- Sample_Finetuning_SIIMACR/I1_classification/dataset/dataset_siim_acr.py +124 -0
- Sample_Finetuning_SIIMACR/I1_classification/dataset/randaugment.py +346 -0
- Sample_Finetuning_SIIMACR/I1_classification/models/resnet.py +88 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/__init__.py +13 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/adafactor.py +206 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/adahessian.py +207 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/adamp.py +133 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/adamw.py +131 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/lookahead.py +96 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/nadam.py +108 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/novograd.py +90 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/nvnovograd.py +132 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/optim_factory.py +138 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/radam.py +170 -0
- Sample_Finetuning_SIIMACR/I1_classification/optim/rmsprop_tf.py +160 -0
PreTrain_MeDSLIP/configs/Pretrain_MeDSLIP.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_file: "setting/rad_graph_metric_train_local.json"
|
2 |
+
valid_file: "setting/rad_graph_metric_validate_local.json"
|
3 |
+
test_file: "setting/rad_graph_metric_test_local.json"
|
4 |
+
label_file: "setting/landmark_observation_adj_mtx.npy"
|
5 |
+
pathology_book: "PreTrain_MeDSLIP/data_file/observation explanation.json"
|
6 |
+
|
7 |
+
image_res: 224
|
8 |
+
patch_size: 16
|
9 |
+
num_sentences: 12
|
10 |
+
num_tokens: 32
|
11 |
+
vision_width: 768
|
12 |
+
fea_width: 197
|
13 |
+
embed_dim: 256
|
14 |
+
batch_size: 64
|
15 |
+
test_batch_size: 32
|
16 |
+
temp: 0.07
|
17 |
+
mlm_probability: 0.15
|
18 |
+
queue_size: 8192
|
19 |
+
momentum: 0.995
|
20 |
+
alpha: 0.4
|
21 |
+
d_model: 256
|
22 |
+
res_base_model: "resnet50"
|
23 |
+
num_queries: 75
|
24 |
+
dropout: 0.1
|
25 |
+
attribute_set_size: 2
|
26 |
+
N: 4
|
27 |
+
H: 4
|
28 |
+
no_cl: False
|
29 |
+
|
30 |
+
exclude_class: False
|
31 |
+
text_encoder: "emilyalsentzer/Bio_ClinicalBERT"
|
32 |
+
shuffle_ratio: 0.5
|
33 |
+
optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.02}
|
34 |
+
schedular: {sched: cosine, lr: 1e-4, epochs: 100, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 5, cooldown_epochs: 0}
|
PreTrain_MeDSLIP/data_file/observation explanation.json
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"normal": "It means the absence of diseases and infirmity, indicating the structure is normal.",
|
3 |
+
"clear": "The lungs are clear and normal. No evidence for other diseases on lung.",
|
4 |
+
"sharp": "This means that an anatomical structure's boundary or edge is clear and normal, meaning it is free of diseases.",
|
5 |
+
"sharply": "\u2018Sharply seen\u2019 means that an anatomical structure is clearly visible.",
|
6 |
+
"unremarkable": "This represents some anatomical structures are normal, usually modifying cardiac and mediastinal silhouettes.",
|
7 |
+
"intact": "The bony structure is complete and normal, meaning no fractures.",
|
8 |
+
"stable": "The modified anatomical structures are normal and stable. No evidence for diseases.",
|
9 |
+
"free": "It usually refers to free air and is associated with pneumothorax, atelectasis, pneumoperitoneum and emphysema.",
|
10 |
+
"effusion": "A pleural effusion is accumulation of excessive fluid in the pleural space, the potential space that surrounds each lung. A pleural effusion infiltrates the space between the visceral pleura and the parietal pleura",
|
11 |
+
"opacity": "It is defined as an area of hazy opacification due to air displacement by fluid, airway collapse, fibrosis, or a neoplastic process. Its causes include infections, interstitial lung disease, and pulmonary edema.",
|
12 |
+
"pneumothorax": "A pneumothorax is an abnormal collection of air in the pleural space between the lung and the chest wall. It may be caused by pneumonia or fibrosis and other diseases. ",
|
13 |
+
"edema": "Pulmonary edema, also known as pulmonary congestion, is excessive liquid accumulation in the tissue and air spaces of the lungs. It will show fluid in the alveolar walls",
|
14 |
+
"atelectasis": "It is the collapse or closure of a lung resulting in reduced or absent gas exchange. Findings can include lung opacification and loss of lung volume.",
|
15 |
+
"tube": "it is a surgical drain that is inserted through the chest wall and into the pleural space or the mediastinum to remove undesired substances such as air (pneumothorax), excess fluid (pleural effusion or hydrothorax), blood (hemothorax), chyle (chylothorax) or pus (empyema) from the intrathoracic space.",
|
16 |
+
"consolidation": "it is a region of normally compressible lung tissue that has filled with liquid instead of air. Consolidation must be present to diagnose pneumonia: the signs of lobar pneumonia are characteristic and clinically referred to as consolidation.",
|
17 |
+
"process": "'Acute process' means there is an abnormality in the anatomical structure. ",
|
18 |
+
"abnormality": "It means the existence of diseases and infirmity, indicating the structure is abnormal.",
|
19 |
+
"enlarge": "It usually modifies cardiac silhouette and heart. Cardiomegaly is a medical condition in which the heart is enlarged.",
|
20 |
+
"tip": "It refers to the top head of the tube.",
|
21 |
+
"low": "The presence of low lung volumes may be a sign of a restrictive lung condition such as pulmonary fibrosis or sarcoidosis.",
|
22 |
+
"pneumonia": "Pneumonia is an inflammatory condition of the lung primarily affecting the small air sacs known as alveoli. Pneumonia may present with opacities. Complications such as pleural effusion may also be found increasing the diagnostic accuracy of lung consolidation and pleural effusion",
|
23 |
+
"line": "It refers to venous access lines or PICC lines.",
|
24 |
+
"congestion": "Pulmonary congestion is defined as accumulation of fluid in the lungs, resulting in impaired gas exchange and arterial hypoxemia. ",
|
25 |
+
"catheter": "catheter is a tube placed in the body to drain and collect urine from the bladder",
|
26 |
+
"cardiomegaly": "Cardiomegaly (sometimes megacardia or megalocardia) is a medical condition in which the heart is enlarged. ",
|
27 |
+
"fracture": "fracture is a break in a rib bone.",
|
28 |
+
"air": "It refers to the free air or gas in pleural space, indicating pneumothorax. Air displacement by fluid may lead to opacity.",
|
29 |
+
"tortuous": "the Aorta is slightly tortuous. Sometimes it may refer to varicose veins",
|
30 |
+
"lead": "It refers to the leading head of the tube.",
|
31 |
+
"disease": "It means the existence of diseases and abnormality, indicating the structure is abnormal. ",
|
32 |
+
"calcification": "Pulmonary calcification is a common asymptomatic finding. Pulmonary calcifications are caused mainly by two mechanisms: the dystrophic form and the metastatic form",
|
33 |
+
"prominence": "It means the existence of some observation.",
|
34 |
+
"device": "It refer to some equipments like picc tub, valve catheter, pacemaker hardware, arthroplastmarker icd defib, device support equipment and mediport'",
|
35 |
+
"engorgement": "pulmonary vascular engorgement means obstruction of the normal flux of blood within the blood vessel network of the lung resulting in engorgement of pulmonary vessels",
|
36 |
+
"picc": "A peripherally inserted central catheter (PICC), also called a PICC line, is a long, thin tube that's inserted through a vein in your arm and passed through to the larger veins near your heart. ",
|
37 |
+
"clip": "Surgical clips or vascular clips usually represent the one kind of medical equipments.",
|
38 |
+
"elevation": "If tissues or anatomical structures are elevated, they are raised up higher than the normal location.",
|
39 |
+
"expand": "It means the lungs are normally expanded and clear, indicating the absence of pneumothorax.",
|
40 |
+
"nodule": "A lung nodule or pulmonary nodule is a relatively small focal density in the lung. it may be confused with the projection of a structure of the chest wall or skin, such as a nipple, a healing rib fracture or lung cancer.",
|
41 |
+
"wire": "sternotomis wires means the center line of the chest.",
|
42 |
+
"fluid": "It refers to the water of liquid in the lung and it may indicate edema and other diseases.",
|
43 |
+
"degenerative": "Degenerative disease is the result of a continuous process based on degenerative cell changes",
|
44 |
+
"pacemaker": "pacemaker device usually represents the one kind of medical equipments.",
|
45 |
+
"thicken": "Pleural thickening is an increase in the bulkiness of one or both of the pulmonary pleurae. It may be caused by pulmonary infection, empyema, tuberculosis or lung cancer.",
|
46 |
+
"marking": "It represents interstitial markings or bronchovascular markings",
|
47 |
+
"scar": "A scar (or scar tissue) is an area of fibrous tissue that replaces normal tissues after an injury.",
|
48 |
+
"hyperinflate": "Hyperinflated lungs are larger-than-normal lungs as a result of trapped air.",
|
49 |
+
"blunt": "Blunting of the costophrenic angles is usually caused by a pleural effusion, as already discussed. Other causes of costophrenic angle blunting include lung disease in the region of the costophrenic angle, and lung hyperexpansion.",
|
50 |
+
"loss": "The etiology of lung volume loss can be listed as follows: airway obstruction or compression, obesity, scoliosis, restrictive diseases such as pulmonary fibrosis and interstitial lung disease, tuberculosis, sarcoidosis, pleural effusions, rib injury (fractures or diaphragm paralysis), and heart failure",
|
51 |
+
"widen": "The mediastinum is not widened or enlarged",
|
52 |
+
"collapse": "collapse lung refers to pneumothorax or atelectasis.",
|
53 |
+
"density": "The density (more precisely, the volumetric mass density; also known as specific mass), of a substance is its mass per unit volume. ",
|
54 |
+
"emphysema": "Emphysema, or pulmonary emphysema, is a lower respiratory tract disease, characterized by air-filled spaces (pneumatosis) in the lungs, that can vary in size and may be very large",
|
55 |
+
"aerate": "Aeration (also called aerification or aeriation) is the process by which air is circulated through, mixed with or dissolved in a liquid or other substances that act as a fluid (such as soil).",
|
56 |
+
"mass": "A lung mass is an abnormal growth or area in the lungs and it can also view as lung cancer.",
|
57 |
+
"crowd": "Crowding of the bronchovascular structures is an important direct sign of volume loss. The atelectatic lung enhances densely after contrast administration because of closeness of the pulmonary arteries and arterioles within the collapsed lobe.",
|
58 |
+
"infiltrate": "A pulmonary infiltrate is a substance denser than air, such as pus, blood, or protein, which lingers within the parenchyma of the lungs. Pulmonary infiltrates are associated with pneumonia, tuberculosis and sarcoidosis.",
|
59 |
+
"obscure": "Some anatomical structures are not clear and are difficult to understand or see",
|
60 |
+
"deformity": "It means some body parts are abnormal or injured.",
|
61 |
+
"hernia": "Lung hernia (Sibson hernia) is a protrusion of lung outside of thoracic wall. the hernia is noted after chest trauma, thoracic surgery or certain pulmonary diseases",
|
62 |
+
"drainage": "Tube drainage represents the one kind of medical equipment",
|
63 |
+
"distention": "Distension generally refers to an enlargement, dilation, or ballooning effect. It may refer to: Abdominal distension,",
|
64 |
+
"shift": "The mediastinal shift is the deviation of the mediastinal structures towards one side of the chest cavity, usually seen on chest radiograph. It indicates a severe asymmetry of intrathoracic pressures.",
|
65 |
+
"stent": "tracheal stent represents the one kind of medical equipments",
|
66 |
+
"pressure": "Pulmonary venous pressure is intermediate between mean PAP and LAP over all physiologic pressures",
|
67 |
+
"lesion": "Lung nodules, pulmonary nodules, white spots, lesions\u2014these terms all describe the same phenomenon: an abnormality in the lungs.",
|
68 |
+
"finding": "Some observation on body parts, usually indicating abnormality.",
|
69 |
+
"borderline": "borderline size of the cardiac silhouette means the cardiac silhouette is not enlarged and normal.",
|
70 |
+
"hardware": "It represents the one kind of medical equipments.",
|
71 |
+
"dilation": "the state of being larger or more open than normal",
|
72 |
+
"chf": "Heart failure \u2014 sometimes known as congestive heart failure \u2014 occurs when the heart muscle doesn't pump blood as well as it should. When this happens, blood often backs up and fluid can build up in the lungs, causing shortness of breath.",
|
73 |
+
"redistribution": "If the pulmonary edema is due to heart failure or fluid overload, you may also see cardiomegaly and distension of the pulmonary veins, particularly in the upper lung fields.",
|
74 |
+
"aspiration": "Aspiration pneumonia occurs when food or liquid is breathed into the airways or lungs, instead of being swallowed. ",
|
75 |
+
"tail_abnorm_obs": "Some very rare diseases.",
|
76 |
+
"excluded_obs": "Some observations that seldom appear in the reports."
|
77 |
+
}
|
PreTrain_MeDSLIP/data_file/preprocessing/adj_matrix.py
ADDED
@@ -0,0 +1,891 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code copied from AGXNet:
|
3 |
+
https://github.com/batmanlab/AGXNet
|
4 |
+
"""
|
5 |
+
|
6 |
+
"""Create adjacency matrix for representing the relations between anatomical landmarks and observations."""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
import pickle
|
12 |
+
|
13 |
+
from tqdm import tqdm, trange
|
14 |
+
from torch.utils.data import Dataset, DataLoader
|
15 |
+
|
16 |
+
|
17 |
+
# Command-line interface for the adjacency-matrix builder.
# The description previously read "Create Adjacency matrix Matrix." (duplicated
# word); it is shown verbatim in ``--help`` output, so it is corrected here.
parser = argparse.ArgumentParser(description="Create Adjacency Matrix.")

# Path of the input CSV; argparse exposes it as ``args.input_path``.
# The default is a placeholder ("/PROJECT DIR/...") and must be overridden
# or edited for a real run.
parser.add_argument(
    "--input-path",
    default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-sentence-parsed.csv",
    help="Itemized input data path.",
)
|
24 |
+
|
25 |
+
# List of most common normal observations
|
26 |
+
NORM_OBS = [
|
27 |
+
"normal",
|
28 |
+
"clear",
|
29 |
+
"sharp",
|
30 |
+
"sharply",
|
31 |
+
"unremarkable",
|
32 |
+
"intact",
|
33 |
+
"stable",
|
34 |
+
"free",
|
35 |
+
]
|
36 |
+
|
37 |
+
# exclude
|
38 |
+
EXCLUDED_OBS = [
|
39 |
+
"none",
|
40 |
+
"unchanged",
|
41 |
+
"change",
|
42 |
+
"great",
|
43 |
+
"similar",
|
44 |
+
"large",
|
45 |
+
"small",
|
46 |
+
"moderate",
|
47 |
+
"mild",
|
48 |
+
"median",
|
49 |
+
"decrease",
|
50 |
+
"bad",
|
51 |
+
"more",
|
52 |
+
"constant",
|
53 |
+
"worsen",
|
54 |
+
"new",
|
55 |
+
"improve",
|
56 |
+
"status",
|
57 |
+
"position",
|
58 |
+
"sternotomy",
|
59 |
+
"cabg",
|
60 |
+
"replacement",
|
61 |
+
"postoperative",
|
62 |
+
"assessment",
|
63 |
+
"patient",
|
64 |
+
]
|
65 |
+
|
66 |
+
# top 90% abnormal observations
|
67 |
+
ABNORM_OBS = [
|
68 |
+
"effusion",
|
69 |
+
"opacity",
|
70 |
+
"pneumothorax",
|
71 |
+
"edema",
|
72 |
+
"atelectasis",
|
73 |
+
"tube",
|
74 |
+
"consolidation",
|
75 |
+
"process",
|
76 |
+
"abnormality",
|
77 |
+
"enlarge",
|
78 |
+
"tip",
|
79 |
+
"low",
|
80 |
+
"pneumonia",
|
81 |
+
"line",
|
82 |
+
"congestion",
|
83 |
+
"catheter",
|
84 |
+
"cardiomegaly",
|
85 |
+
"fracture",
|
86 |
+
"air",
|
87 |
+
"tortuous",
|
88 |
+
"lead",
|
89 |
+
"disease",
|
90 |
+
"calcification",
|
91 |
+
"prominence",
|
92 |
+
"device",
|
93 |
+
"engorgement",
|
94 |
+
"picc",
|
95 |
+
"clip",
|
96 |
+
"elevation",
|
97 |
+
"expand",
|
98 |
+
"nodule",
|
99 |
+
"wire",
|
100 |
+
"fluid",
|
101 |
+
"degenerative",
|
102 |
+
"pacemaker",
|
103 |
+
"thicken",
|
104 |
+
"marking",
|
105 |
+
"scar",
|
106 |
+
"hyperinflate",
|
107 |
+
"blunt",
|
108 |
+
"loss",
|
109 |
+
"widen",
|
110 |
+
"collapse",
|
111 |
+
"density",
|
112 |
+
"emphysema",
|
113 |
+
"aerate",
|
114 |
+
"mass",
|
115 |
+
"crowd",
|
116 |
+
"infiltrate",
|
117 |
+
"obscure",
|
118 |
+
"deformity",
|
119 |
+
"hernia",
|
120 |
+
"drainage",
|
121 |
+
"distention",
|
122 |
+
"shift",
|
123 |
+
"stent",
|
124 |
+
"pressure",
|
125 |
+
"lesion",
|
126 |
+
"finding",
|
127 |
+
"borderline",
|
128 |
+
"hardware",
|
129 |
+
"dilation",
|
130 |
+
"chf",
|
131 |
+
"redistribution",
|
132 |
+
"aspiration",
|
133 |
+
]
|
134 |
+
|
135 |
+
# final row and column names in the adjacency matrix
|
136 |
+
LANDMARK_NAME = [
|
137 |
+
"trachea",
|
138 |
+
"left_hilar",
|
139 |
+
"right_hilar",
|
140 |
+
"hilar_unspec",
|
141 |
+
"left_pleural",
|
142 |
+
"right_pleural",
|
143 |
+
"pleural_unspec",
|
144 |
+
"heart_size",
|
145 |
+
"heart_border",
|
146 |
+
"left_diaphragm",
|
147 |
+
"right_diaphragm",
|
148 |
+
"diaphragm_unspec",
|
149 |
+
"retrocardiac",
|
150 |
+
"lower_left_lobe",
|
151 |
+
"upper_left_lobe",
|
152 |
+
"lower_right_lobe",
|
153 |
+
"middle_right_lobe",
|
154 |
+
"upper_right_lobe",
|
155 |
+
"left_lower_lung",
|
156 |
+
"left_mid_lung",
|
157 |
+
"left_upper_lung",
|
158 |
+
"left_apical_lung",
|
159 |
+
"left_lung_unspec",
|
160 |
+
"right_lower_lung",
|
161 |
+
"right_mid_lung",
|
162 |
+
"right_upper_lung",
|
163 |
+
"right_apical_lung",
|
164 |
+
"right_lung_unspec",
|
165 |
+
"lung_apices",
|
166 |
+
"lung_bases",
|
167 |
+
"left_costophrenic",
|
168 |
+
"right_costophrenic",
|
169 |
+
"costophrenic_unspec",
|
170 |
+
"cardiophrenic_sulcus",
|
171 |
+
"mediastinal",
|
172 |
+
"spine",
|
173 |
+
"clavicle",
|
174 |
+
"rib",
|
175 |
+
"stomach",
|
176 |
+
"right_atrium",
|
177 |
+
"right_ventricle",
|
178 |
+
"aorta",
|
179 |
+
"svc",
|
180 |
+
"interstitium",
|
181 |
+
"parenchymal",
|
182 |
+
"cavoatrial_junction",
|
183 |
+
"cardiopulmonary",
|
184 |
+
"pulmonary",
|
185 |
+
"lung_volumes",
|
186 |
+
"unspecified",
|
187 |
+
"other",
|
188 |
+
]
|
189 |
+
|
190 |
+
OBSERVATION_CLASS = [
|
191 |
+
"normal",
|
192 |
+
"clear",
|
193 |
+
"sharp",
|
194 |
+
"sharply",
|
195 |
+
"unremarkable",
|
196 |
+
"intact",
|
197 |
+
"stable",
|
198 |
+
"free",
|
199 |
+
"effusion",
|
200 |
+
"opacity",
|
201 |
+
"pneumothorax",
|
202 |
+
"edema",
|
203 |
+
"atelectasis",
|
204 |
+
"tube",
|
205 |
+
"consolidation",
|
206 |
+
"process",
|
207 |
+
"abnormality",
|
208 |
+
"enlarge",
|
209 |
+
"tip",
|
210 |
+
"low",
|
211 |
+
"pneumonia",
|
212 |
+
"line",
|
213 |
+
"congestion",
|
214 |
+
"catheter",
|
215 |
+
"cardiomegaly",
|
216 |
+
"fracture",
|
217 |
+
"air",
|
218 |
+
"tortuous",
|
219 |
+
"lead",
|
220 |
+
"disease",
|
221 |
+
"calcification",
|
222 |
+
"prominence",
|
223 |
+
"device",
|
224 |
+
"engorgement",
|
225 |
+
"picc",
|
226 |
+
"clip",
|
227 |
+
"elevation",
|
228 |
+
"expand",
|
229 |
+
"nodule",
|
230 |
+
"wire",
|
231 |
+
"fluid",
|
232 |
+
"degenerative",
|
233 |
+
"pacemaker",
|
234 |
+
"thicken",
|
235 |
+
"marking",
|
236 |
+
"scar",
|
237 |
+
"hyperinflate",
|
238 |
+
"blunt",
|
239 |
+
"loss",
|
240 |
+
"widen",
|
241 |
+
"collapse",
|
242 |
+
"density",
|
243 |
+
"emphysema",
|
244 |
+
"aerate",
|
245 |
+
"mass",
|
246 |
+
"crowd",
|
247 |
+
"infiltrate",
|
248 |
+
"obscure",
|
249 |
+
"deformity",
|
250 |
+
"hernia",
|
251 |
+
"drainage",
|
252 |
+
"distention",
|
253 |
+
"shift",
|
254 |
+
"stent",
|
255 |
+
"pressure",
|
256 |
+
"lesion",
|
257 |
+
"finding",
|
258 |
+
"borderline",
|
259 |
+
"hardware",
|
260 |
+
"dilation",
|
261 |
+
"chf",
|
262 |
+
"redistribution",
|
263 |
+
"aspiration",
|
264 |
+
"tail_abnorm_obs",
|
265 |
+
"excluded_obs",
|
266 |
+
]
|
267 |
+
|
268 |
+
DICT_ANATOMICAL_LANDMARKS = {
|
269 |
+
"trachea": {"a": ["trachea", "tracheal"], "m1": [], "m2": [], "sc": [], "t": "m0"},
|
270 |
+
"left_hilar": {
|
271 |
+
"a": ["hilar", "hilum", "perihilar", "infrahilar"],
|
272 |
+
"m1": ["left"],
|
273 |
+
"m2": ["right"],
|
274 |
+
"sc": [],
|
275 |
+
"t": "m1+m2-",
|
276 |
+
},
|
277 |
+
"right_hilar": {
|
278 |
+
"a": ["hilar", "hilum", "perihilar", "infrahilar"],
|
279 |
+
"m1": ["right"],
|
280 |
+
"m2": ["left"],
|
281 |
+
"sc": [],
|
282 |
+
"t": "m1+m2-",
|
283 |
+
},
|
284 |
+
"hilar_unspec": {
|
285 |
+
"a": ["hilar", "hilum", "perihilar", "infrahilar"],
|
286 |
+
"m1": ["left", "right"],
|
287 |
+
"m2": [],
|
288 |
+
"sc": ["hila", "perihilar|right|left", "perihilar|left|right"],
|
289 |
+
"t": "m1-",
|
290 |
+
},
|
291 |
+
"left_pleural": {
|
292 |
+
"a": ["pleural"],
|
293 |
+
"m1": ["left"],
|
294 |
+
"m2": ["right"],
|
295 |
+
"sc": [],
|
296 |
+
"t": "m1+m2-",
|
297 |
+
},
|
298 |
+
"right_pleural": {
|
299 |
+
"a": ["pleural"],
|
300 |
+
"m1": ["right"],
|
301 |
+
"m2": ["left"],
|
302 |
+
"sc": [],
|
303 |
+
"t": "m1+m2-",
|
304 |
+
},
|
305 |
+
"pleural_unspec": {
|
306 |
+
"a": ["pleural"],
|
307 |
+
"m1": ["left", "right"],
|
308 |
+
"m2": [],
|
309 |
+
"sc": [
|
310 |
+
"pleural|left|right",
|
311 |
+
"pleural|right|left",
|
312 |
+
"pleural|bilateral|right|left",
|
313 |
+
"pleural|bilateral|left|right",
|
314 |
+
],
|
315 |
+
"t": "m1-",
|
316 |
+
},
|
317 |
+
"heart_size": {
|
318 |
+
"a": ["heart", "cardiac"],
|
319 |
+
"m1": ["border", "borders"],
|
320 |
+
"m2": [],
|
321 |
+
"sc": [],
|
322 |
+
"t": "m1-",
|
323 |
+
},
|
324 |
+
"heart_border": {
|
325 |
+
"a": ["heart", "cardiac"],
|
326 |
+
"m1": ["border", "borders"],
|
327 |
+
"m2": [],
|
328 |
+
"sc": [],
|
329 |
+
"t": "m1+",
|
330 |
+
},
|
331 |
+
"left_diaphragm": {
|
332 |
+
"a": ["diaphragm", "hemidiaphragm"],
|
333 |
+
"m1": ["left"],
|
334 |
+
"m2": ["right"],
|
335 |
+
"sc": [],
|
336 |
+
"t": "m1+m2-",
|
337 |
+
},
|
338 |
+
"right_diaphragm": {
|
339 |
+
"a": ["diaphragm", "hemidiaphragm"],
|
340 |
+
"m1": ["right"],
|
341 |
+
"m2": ["left"],
|
342 |
+
"sc": [],
|
343 |
+
"t": "m1+m2-",
|
344 |
+
},
|
345 |
+
"diaphragm_unspec": {
|
346 |
+
"a": ["diaphragm", "diaphragms", "hemidiaphragms", "hemidiaphragm"],
|
347 |
+
"m1": ["left", "right"],
|
348 |
+
"m2": [],
|
349 |
+
"sc": ["hemidiaphragm|left|right", "hemidiaphragm|right|left"],
|
350 |
+
"t": "m1-",
|
351 |
+
},
|
352 |
+
"retrocardiac": {"a": ["retrocardiac"], "m1": [], "m2": [], "sc": [], "t": "m0"},
|
353 |
+
"lower_left_lobe": {
|
354 |
+
"a": ["lobe"],
|
355 |
+
"m1": ["left"],
|
356 |
+
"m2": ["lower"],
|
357 |
+
"sc": [],
|
358 |
+
"t": "m1+m2+",
|
359 |
+
},
|
360 |
+
"upper_left_lobe": {
|
361 |
+
"a": ["lobe"],
|
362 |
+
"m1": ["left"],
|
363 |
+
"m2": ["upper"],
|
364 |
+
"sc": ["lingula", "lingular"],
|
365 |
+
"t": "m1+m2+",
|
366 |
+
},
|
367 |
+
"lower_right_lobe": {
|
368 |
+
"a": ["lobe"],
|
369 |
+
"m1": ["right"],
|
370 |
+
"m2": ["lower"],
|
371 |
+
"sc": [],
|
372 |
+
"t": "m1+m2+",
|
373 |
+
},
|
374 |
+
"middle_right_lobe": {
|
375 |
+
"a": ["lobe"],
|
376 |
+
"m1": ["right"],
|
377 |
+
"m2": ["middle"],
|
378 |
+
"sc": [],
|
379 |
+
"t": "m1+m2+",
|
380 |
+
},
|
381 |
+
"upper_right_lobe": {
|
382 |
+
"a": ["lobe"],
|
383 |
+
"m1": ["right"],
|
384 |
+
"m2": ["upper"],
|
385 |
+
"sc": [],
|
386 |
+
"t": "m1+m2+",
|
387 |
+
},
|
388 |
+
"left_lower_lung": {
|
389 |
+
"a": ["lung"],
|
390 |
+
"m1": ["left"],
|
391 |
+
"m2": ["lower", "base", "basilar", "basal", "basis"],
|
392 |
+
"sc": ["base|left", "basilar|left", "basal|left", "lung|left|bases"],
|
393 |
+
"t": "m1+m2+",
|
394 |
+
},
|
395 |
+
"left_mid_lung": {
|
396 |
+
"a": ["lung"],
|
397 |
+
"m1": ["left"],
|
398 |
+
"m2": ["middle", "mid"],
|
399 |
+
"sc": ["midlung|left"],
|
400 |
+
"t": "m1+m2+",
|
401 |
+
},
|
402 |
+
"left_upper_lung": {
|
403 |
+
"a": ["lung"],
|
404 |
+
"m1": ["left"],
|
405 |
+
"m2": ["upper"],
|
406 |
+
"sc": [],
|
407 |
+
"t": "m1+m2+",
|
408 |
+
},
|
409 |
+
"left_apical_lung": {
|
410 |
+
"a": ["apex", "apical", "apical", "apicolateral"],
|
411 |
+
"m1": ["left"],
|
412 |
+
"m2": ["right"],
|
413 |
+
"sc": [],
|
414 |
+
"t": "m1+m2-",
|
415 |
+
},
|
416 |
+
"left_lung_unspec": {
|
417 |
+
"a": ["lung", "hemithorax"],
|
418 |
+
"m1": ["left", "left-sided"],
|
419 |
+
"m2": [
|
420 |
+
"volume",
|
421 |
+
"volumes",
|
422 |
+
"right",
|
423 |
+
"lower",
|
424 |
+
"base",
|
425 |
+
"bases",
|
426 |
+
"basilar",
|
427 |
+
"basilar",
|
428 |
+
"basal",
|
429 |
+
"basis",
|
430 |
+
"middle",
|
431 |
+
"mid",
|
432 |
+
"upper",
|
433 |
+
"apex",
|
434 |
+
"apical",
|
435 |
+
"perihilar",
|
436 |
+
],
|
437 |
+
"sc": ["left", "left side", "thorax|left|hemi"],
|
438 |
+
"t": "m1+m2-",
|
439 |
+
},
|
440 |
+
"right_lower_lung": {
|
441 |
+
"a": ["lung"],
|
442 |
+
"m1": ["right"],
|
443 |
+
"m2": ["lower", "base", "basilar", "basal", "basis"],
|
444 |
+
"sc": ["base|right", "basilar|right", "basal|right", "lung|right|bases"],
|
445 |
+
"t": "m1+m2+",
|
446 |
+
},
|
447 |
+
"right_mid_lung": {
|
448 |
+
"a": ["lung"],
|
449 |
+
"m1": ["right"],
|
450 |
+
"m2": ["middle", "mid"],
|
451 |
+
"sc": [],
|
452 |
+
"t": "m1+m2+",
|
453 |
+
},
|
454 |
+
"right_upper_lung": {
|
455 |
+
"a": ["lung"],
|
456 |
+
"m1": ["right"],
|
457 |
+
"m2": ["upper"],
|
458 |
+
"sc": [],
|
459 |
+
"t": "m1+m2+",
|
460 |
+
},
|
461 |
+
"right_apical_lung": {
|
462 |
+
"a": ["apex", "apical", "apical", "apicolateral"],
|
463 |
+
"m1": ["right"],
|
464 |
+
"m2": ["left"],
|
465 |
+
"sc": [],
|
466 |
+
"t": "m1+m2-",
|
467 |
+
},
|
468 |
+
"right_lung_unspec": {
|
469 |
+
"a": ["lung", "hemithorax"],
|
470 |
+
"m1": ["right", "right-sided"],
|
471 |
+
"m2": [
|
472 |
+
"volume",
|
473 |
+
"volumes",
|
474 |
+
"left",
|
475 |
+
"lower",
|
476 |
+
"base",
|
477 |
+
"bases",
|
478 |
+
"basilar",
|
479 |
+
"basilar",
|
480 |
+
"basal",
|
481 |
+
"basis",
|
482 |
+
"middle",
|
483 |
+
"mid",
|
484 |
+
"upper",
|
485 |
+
"apex",
|
486 |
+
"apical",
|
487 |
+
"perihilar",
|
488 |
+
],
|
489 |
+
"sc": ["right", "right side", "thorax|right|hemi"],
|
490 |
+
"t": "m1+m2-",
|
491 |
+
},
|
492 |
+
"lung_apices": {
|
493 |
+
"a": ["apices", "apical"],
|
494 |
+
"m1": ["left", "right"],
|
495 |
+
"m2": [],
|
496 |
+
"sc": ["biapical", "lungs|upper"],
|
497 |
+
"t": "m1-",
|
498 |
+
},
|
499 |
+
"lung_bases": {
|
500 |
+
"a": ["lung", "lungs"],
|
501 |
+
"m1": ["left", "right"],
|
502 |
+
"m2": ["bibasilar", "basilar", "base", "bases", "bibasal", "basal"],
|
503 |
+
"sc": [
|
504 |
+
"lung|lower",
|
505 |
+
"lungs|lower",
|
506 |
+
"bibasilar",
|
507 |
+
"basilar",
|
508 |
+
"bases",
|
509 |
+
"bibasal",
|
510 |
+
"basal",
|
511 |
+
"basal|bilateral",
|
512 |
+
"lobe|lower",
|
513 |
+
"lobes|lower",
|
514 |
+
"lobe|bilateral|lower",
|
515 |
+
"bases|both",
|
516 |
+
"bibasilar|left|right",
|
517 |
+
"bibasilar|right|left",
|
518 |
+
],
|
519 |
+
"t": "m1-m2+",
|
520 |
+
},
|
521 |
+
"left_costophrenic": {
|
522 |
+
"a": ["costophrenic"],
|
523 |
+
"m1": ["left"],
|
524 |
+
"m2": ["right"],
|
525 |
+
"sc": [],
|
526 |
+
"t": "m1+m2-",
|
527 |
+
},
|
528 |
+
"right_costophrenic": {
|
529 |
+
"a": ["costophrenic"],
|
530 |
+
"m1": ["right"],
|
531 |
+
"m2": ["left"],
|
532 |
+
"sc": [],
|
533 |
+
"t": "m1+m2-",
|
534 |
+
},
|
535 |
+
"costophrenic_unspec": {
|
536 |
+
"a": ["costophrenic"],
|
537 |
+
"m1": ["left", "right"],
|
538 |
+
"m2": [],
|
539 |
+
"sc": [],
|
540 |
+
"t": "m1-",
|
541 |
+
},
|
542 |
+
"cardiophrenic_sulcus": {
|
543 |
+
"a": ["cardiophrenic"],
|
544 |
+
"m1": [],
|
545 |
+
"m2": [],
|
546 |
+
"sc": [],
|
547 |
+
"t": "m0",
|
548 |
+
},
|
549 |
+
"mediastinal": {
|
550 |
+
"a": ["mediastinal", "cardiomediastinal", "mediastinum", "cardiomediastinum"],
|
551 |
+
"m1": [],
|
552 |
+
"m2": [],
|
553 |
+
"sc": [],
|
554 |
+
"t": "m0",
|
555 |
+
},
|
556 |
+
"spine": {"a": ["spine", "spinal"], "m1": [], "m2": [], "sc": [], "t": "m0"},
|
557 |
+
"clavicle": {
|
558 |
+
"a": ["clavicle", "clavicles"],
|
559 |
+
"m1": [],
|
560 |
+
"m2": [],
|
561 |
+
"sc": [],
|
562 |
+
"t": "m0",
|
563 |
+
},
|
564 |
+
"rib": {"a": ["rib", "ribs"], "m1": [], "m2": [], "sc": [], "t": "m0"},
|
565 |
+
"stomach": {
|
566 |
+
"a": ["stomach", "abdomen", "abdominal"],
|
567 |
+
"m1": [],
|
568 |
+
"m2": [],
|
569 |
+
"sc": [],
|
570 |
+
"t": "m0",
|
571 |
+
},
|
572 |
+
"right_atrium": {
|
573 |
+
"a": ["atrium", "atrial"],
|
574 |
+
"m1": ["right"],
|
575 |
+
"m2": ["left"],
|
576 |
+
"sc": [],
|
577 |
+
"t": "m1+m2-",
|
578 |
+
},
|
579 |
+
"right_ventricle": {
|
580 |
+
"a": ["ventricle", "ventricular"],
|
581 |
+
"m1": ["right"],
|
582 |
+
"m2": ["left"],
|
583 |
+
"sc": [],
|
584 |
+
"t": "m1+m2-",
|
585 |
+
},
|
586 |
+
"aorta": {"a": ["aorta", "aortic"], "m1": [], "m2": [], "sc": [], "t": "m0"},
|
587 |
+
"svc": {"a": ["svc"], "m1": [], "m2": [], "sc": [], "t": "m0"},
|
588 |
+
"interstitium": {
|
589 |
+
"a": ["interstitium", "interstitial"],
|
590 |
+
"m1": [],
|
591 |
+
"m2": [],
|
592 |
+
"sc": [],
|
593 |
+
"t": "m0",
|
594 |
+
},
|
595 |
+
"parenchymal": {"a": ["parenchymal"], "m1": [], "m2": [], "sc": [], "t": "m0"},
|
596 |
+
"cavoatrial_junction": {
|
597 |
+
"a": ["cavoatrial junction"],
|
598 |
+
"m1": [],
|
599 |
+
"m2": [],
|
600 |
+
"sc": [],
|
601 |
+
"t": "m0",
|
602 |
+
},
|
603 |
+
"cardiopulmonary": {
|
604 |
+
"a": ["cardiopulmonary"],
|
605 |
+
"m1": [],
|
606 |
+
"m2": [],
|
607 |
+
"sc": [],
|
608 |
+
"t": "m0",
|
609 |
+
},
|
610 |
+
"pulmonary": {"a": ["pulmonary"], "m1": [], "m2": [], "sc": [], "t": "m0"},
|
611 |
+
"lung_volumes": {
|
612 |
+
"a": ["lungs", "lung", "volume", "volumes"],
|
613 |
+
"m1": [
|
614 |
+
"left",
|
615 |
+
"right",
|
616 |
+
"lower",
|
617 |
+
"base",
|
618 |
+
"bases",
|
619 |
+
"basilar",
|
620 |
+
"basal",
|
621 |
+
"basis",
|
622 |
+
"middle",
|
623 |
+
"mid",
|
624 |
+
"upper",
|
625 |
+
"apex",
|
626 |
+
"apical",
|
627 |
+
"apical",
|
628 |
+
],
|
629 |
+
"m2": [],
|
630 |
+
"sc": [],
|
631 |
+
"t": "m1-",
|
632 |
+
},
|
633 |
+
}
|
634 |
+
|
635 |
+
|
636 |
+
class LandmarkObservationAdjacentMatrix(Dataset):
    """Per-study landmark x observation matrix built from parsed RadGraph rows.

    Each item is ``(study_id, matrix)`` where ``matrix`` has one row per entry
    of ``LANDMARK_NAME`` and one column per entry of ``OBSERVATION_CLASS``.
    Cell values:

    * ``-1.0`` -- the (landmark, observation) pair was not mentioned,
    * ``0.0``  -- mentioned only with label ``"OBS-DA"``,
    * ``1.0``  -- mentioned at least once with label ``"OBS-DP"``.

    Args:
        LANDMARK_NAME: list of landmark names (matrix rows).
        OBSERVATION_CLASS: list of observation group names (matrix columns).
        df_anatomy_label: DataFrame with at least the columns ``study_id``,
            ``obs_lemma_grp``, ``landmark_name`` and ``label``.
    """

    def __init__(self, LANDMARK_NAME, OBSERVATION_CLASS, df_anatomy_label):
        self.LANDMARK_NAME = LANDMARK_NAME
        self.OBSERVATION_CLASS = OBSERVATION_CLASS
        self.df_anatomy_label = df_anatomy_label

        # Study ids in order of first appearance; one dataset item per study.
        self.sids = list(self.df_anatomy_label["study_id"].unique())

    def __getitem__(self, idx):
        sid = self.sids[idx]
        df_sid = self.df_anatomy_label[self.df_anatomy_label["study_id"] == sid]

        # Size the matrix from the lists given to the constructor (not from
        # module-level globals) so the class works with any vocabulary.
        adj_mtx = np.full(
            (len(self.LANDMARK_NAME), len(self.OBSERVATION_CLASS)), -1.0
        )

        for _, row in df_sid.iterrows():
            try:
                observation_idx = self.OBSERVATION_CLASS.index(row["obs_lemma_grp"])
                landmark_idx = self.LANDMARK_NAME.index(row["landmark_name"])
            except ValueError:
                # Observation or landmark outside the tracked vocabulary
                # (e.g. a rare observation): skip this row.
                continue

            label = row["label"]
            if label == "OBS-DP":
                # A single "present" mention wins over everything else.
                adj_mtx[landmark_idx, observation_idx] = 1.0
            elif label == "OBS-DA":
                # "Absent" upgrades -1 to 0 but never downgrades a 1.
                adj_mtx[landmark_idx, observation_idx] = np.maximum(
                    adj_mtx[landmark_idx, observation_idx], 0.0
                )
        return sid, adj_mtx

    def __len__(self):
        return len(self.sids)
|
674 |
+
|
675 |
+
|
676 |
+
# Matching rule per landmark type: (level-1 modifier, level-2 modifier).
# True  -> at least one modifier of that level must be present,
# False -> no modifier of that level may be present,
# None  -> that level is ignored.
_LANDMARK_TYPE_RULES = {
    "m1+m2+": (True, True),
    "m1+m2-": (True, False),
    "m1-m2+": (False, True),
    "m1-m2-": (False, False),
    "m1+": (True, None),
    "m2+": (None, True),
    "m1-": (False, None),
    "m2-": (None, False),
    "m0": (None, None),
}


def anatomy_to_landmark(x, a, m1=(), m2=(), sc=(), t="m0"):
    """Decide whether an anatomy string matches a landmark definition.

    Args:
        x: input anatomy, tokens joined by "|", e.g., "lobe|left|lower"
        a: base anatomy set, e.g., ["hilar", "hilum", "perihilar"]
        m1: level 1 modifier, e.g., ["left", "right"]
        m2: level 2 modifier, e.g., ["upper", "middle", "lower"]
        sc: special cases matched against the whole of ``x``, e.g., ["chest"]
        t: type, one of "m1+m2+", "m1+m2-", "m1-m2+", "m1-m2-", "m1+", "m2+",
            "m1-", "m2-", "m0"; "+" requires a modifier of that level,
            "-" forbids it, "m0" only checks the base anatomy
    Return:
        flag: boolean, matched or not matched
    Raises:
        ValueError: if ``t`` is not one of the supported types.
    """
    tokens = set(x.split("|"))

    try:
        need_m1, need_m2 = _LANDMARK_TYPE_RULES[t]
    except KeyError:
        raise ValueError(f"unknown landmark type: {t!r}") from None

    # Every type requires at least one base-anatomy token.
    flag = len(tokens & set(a)) > 0
    if need_m1 is not None:
        flag = flag and ((len(tokens & set(m1)) > 0) == need_m1)
    if need_m2 is not None:
        flag = flag and ((len(tokens & set(m2)) > 0) == need_m2)

    # Special cases match on the full, unsplit anatomy string.
    if sc:
        flag = flag or (x in sc)
    return flag
|
711 |
+
|
712 |
+
|
713 |
+
# Manual lemmatization corrections: (variants, canonical lemma).
# No canonical lemma appears among the variants, so order does not matter.
_OBS_LEMMA_CORRECTIONS = (
    (["enlargement", "increase"], "enlarge"),
    (["engorge"], "engorgement"),
    (["opacification", "opacity-"], "opacity"),
    (["calcify"], "calcification"),
    (["effusion ;"], "effusion"),
    (["atelectatic", "atelectasis ;", "atelectase"], "atelectasis"),
    (["aeration"], "aerate"),
    (["distend", "distension"], "distention"),
    (["wide"], "widen"),
    (["prominent"], "prominence"),
    (["haze"], "haziness"),
    (["masse"], "mass"),
    (["kyphotic"], "kyphosis"),
    (["degenerate"], "degenerative"),
    (["obscuration"], "obscure"),
    (["fibrotic"], "fibrosis"),
    (["nodular", "nodularity"], "nodule"),
    (["ventilate"], "ventilation"),
    (["tortuosity"], "tortuous"),
    (["elongate"], "elongation"),
    (["elevate"], "elevation"),
    (["drain"], "drainage"),
    (["deviate"], "deviation"),
    (["consolidative", "consolidate"], "consolidation"),
    (["dilate", "dilatation"], "dilation"),
    (["hydropneumothorax", "pneumothoraces", "pneumothorace"], "pneumothorax"),
    (["improvement", "improved"], "improve"),
    (
        [
            "can not be assess",
            "can not be evaluate",
            "not well see",
            "not well assess",
            "can not be accurately assess",
            "not well evaluate",
            "not well visualize",
            "difficult to evaluate",
            "poorly see",
        ],
        "difficult to assess",
    ),
    (["pacer"], "pacemaker"),
    (["infection", "infectious", "infectious process"], "pneumonia"),
)


def create_adj_matrix(args):
    """Build per-study landmark/observation matrices from parsed RadGraph data.

    Reads the CSV at ``args.input_path`` (all columns as strings), normalizes
    the ``obs_lemma`` column, maps each ``anatomy`` string to a landmark name,
    groups observations, and computes one matrix per study with
    ``LandmarkObservationAdjacentMatrix``.

    Side effects: writes ``landmark_observation_sids.npy`` and
    ``landmark_observation_adj_mtx.npy`` to the current working directory.

    Args:
        args: namespace with an ``input_path`` attribute.
    """
    # load anatomy label table
    print("Loading parsed RadGraph data...")
    df_anatomy_label = pd.read_csv(args.input_path, dtype=str)

    # manual lemmatization correction
    for variants, canonical in _OBS_LEMMA_CORRECTIONS:
        idx_replace = df_anatomy_label["obs_lemma"].isin(variants)
        df_anatomy_label.loc[idx_replace, "obs_lemma"] = canonical

    # rows without an observation label get an explicit placeholder
    df_anatomy_label.loc[df_anatomy_label["label"].isna(), "label"] = "OBS-NA"

    # step 1: map anatomy name to landmark name (first matching definition wins)
    landmark_name = []
    for _, row in tqdm(df_anatomy_label.iterrows(), total=df_anatomy_label.shape[0]):
        anatomy = row["anatomy"]
        for name, spec in DICT_ANATOMICAL_LANDMARKS.items():
            if anatomy_to_landmark(
                anatomy, spec["a"], spec["m1"], spec["m2"], spec["sc"], spec["t"]
            ):
                landmark_name.append(name)
                break
        else:
            # no landmark definition matched
            landmark_name.append(
                "unspecified" if anatomy == "unspecified" else "other"
            )

    df_anatomy_label["landmark_name"] = landmark_name

    # create a new obs_lemma column to group other abnormal observation classes
    df_anatomy_label["obs_lemma_grp"] = df_anatomy_label["obs_lemma"]

    is_norm = df_anatomy_label["obs_lemma"].isin(NORM_OBS)
    is_abnorm = df_anatomy_label["obs_lemma"].isin(ABNORM_OBS)
    is_excluded = df_anatomy_label["obs_lemma"].isin(EXCLUDED_OBS)

    df_anatomy_label.loc[is_excluded, "obs_lemma_grp"] = "excluded_obs"

    # everything not listed anywhere is an abnormal observation in the tail
    is_tail = (~is_norm) & (~is_abnorm) & (~is_excluded)
    df_anatomy_label.loc[is_tail, "obs_lemma_grp"] = "tail_abnorm_obs"

    # step 2: get landmark - observation adjacent matrix
    dataset = LandmarkObservationAdjacentMatrix(
        LANDMARK_NAME, OBSERVATION_CLASS, df_anatomy_label
    )
    loader = DataLoader(
        dataset, batch_size=32, shuffle=False, num_workers=8, drop_last=False
    )

    sid_lst = []
    adj_mtx_lst = []
    for sid, landmark_observation_adj_mtx in tqdm(loader, total=len(loader)):
        sid_lst.append(sid)
        adj_mtx_lst.append(landmark_observation_adj_mtx)

    # step 3: stack the batches and save study ids and matrices as .npy files
    full_sids = np.concatenate(sid_lst, axis=0)
    full_adj_mtx = np.concatenate(adj_mtx_lst, axis=0)

    np.save("landmark_observation_sids.npy", full_sids)
    print("landmark_observation_sids.npy has been saved!")
    np.save("landmark_observation_adj_mtx.npy", full_adj_mtx)
    print("landmark_observation_adj_mtx.npy has been saved!")
|
887 |
+
|
888 |
+
|
889 |
+
if __name__ == "__main__":
|
890 |
+
args = parser.parse_args()
|
891 |
+
create_adj_matrix(args)
|
PreTrain_MeDSLIP/data_file/preprocessing/radgraph_itemized.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code copied from AGXNet:
|
3 |
+
https://github.com/batmanlab/AGXNet
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import pandas as pd
|
8 |
+
import json
|
9 |
+
from tqdm import tqdm
|
10 |
+
import nltk
|
11 |
+
|
12 |
+
|
13 |
+
# Command-line interface: where to read the RadGraph JSON from and where to
# write the itemized CSV.
parser = argparse.ArgumentParser(description="Itemize RadGraph Dataset.")

for _flag, _default, _help in (
    (
        "--data-path",
        "/PATH TO RADGRAPH DATA/RadGraph/physionet.org/files/radgraph/1.0.0/MIMIC-CXR_graphs.json",
        "RadGraph data path.",
    ),
    (
        "--output-path",
        "/PROJECT DIR/preprocessing/mimic-cxr-radgraph-itemized.csv",
        "Output path for itemized RadGraph data.",
    ),
):
    parser.add_argument(_flag, default=_default, help=_help)
|
25 |
+
|
26 |
+
|
27 |
+
def get_ids(key):
    """Split a RadGraph key of the form "<partition>/p<pid>/s<sid>.txt".

    Returns ``(partition, pid, sid)``; the first character of the patient and
    study components is dropped, as is everything from the first "." of the
    study component onwards.
    """
    parts = key.split("/")
    partition, patient_part, study_part = parts[0], parts[1], parts[2]
    study_stem = study_part.partition(".")[0]  # drop the file extension
    return partition, patient_part[1:], study_stem[1:]
|
34 |
+
|
35 |
+
|
36 |
+
def get_sen_from_token_ix(text, ix):
    """Return ``(sentence_index, sentence)`` for the token at position ``ix``.

    ``text`` is split into sentences and each sentence into words with nltk;
    words are numbered consecutively across sentences starting at 0. Raises
    ``KeyError`` if ``ix`` is not a valid word position.
    """
    sentences = nltk.sent_tokenize(text)

    # running word position -> index of the sentence that contains it
    word_to_sentence = {}
    for sentence_no, sentence in enumerate(sentences):
        for _ in nltk.word_tokenize(sentence):
            word_to_sentence[len(word_to_sentence)] = sentence_no

    owner = word_to_sentence[ix]
    return owner, sentences[owner]
|
49 |
+
|
50 |
+
|
51 |
+
def get_entity_relation(value):
    """Flatten one RadGraph report into one row per (entity, relation).

    ``value`` must provide ``"text"`` and ``"entities"``; every entity provides
    ``"start_ix"``, ``"tokens"``, ``"label"`` and ``"relations"``. An entity
    with no outgoing relation (empty list, or first element ``None``) yields a
    single row whose ``relation`` and ``target`` are ``None``.

    Returns a DataFrame with columns source, token, token_ix, label, relation,
    target, sentence_ix, sentence.
    """
    report_text = value["text"]
    columns = {
        "source": [],
        "token": [],
        "token_ix": [],
        "label": [],
        "relation": [],
        "target": [],
        "sentence_ix": [],
        "sentence": [],
    }

    for entity_key, entity in value["entities"].items():
        sentence_ix, sentence = get_sen_from_token_ix(report_text, entity["start_ix"])
        relations = entity["relations"]

        # source node has no outgoing edge -> emit one placeholder edge
        if len(relations) == 0 or relations[0] is None:
            edges = [(None, None)]
        else:
            edges = relations

        for edge in edges:
            columns["source"].append(entity_key)
            columns["token"].append(entity["tokens"])
            columns["token_ix"].append(entity["start_ix"])
            columns["label"].append(entity["label"])
            columns["relation"].append(edge[0])
            columns["target"].append(edge[1])
            columns["sentence_ix"].append(sentence_ix)
            columns["sentence"].append(sentence)

    return pd.DataFrame(columns)
|
103 |
+
|
104 |
+
|
105 |
+
def radgraph_itemize(args):
    """Convert nested RadGraph data to itemized examples.

    Loads the JSON file at ``args.data_path``, turns every report into an
    entity/relation table via ``get_entity_relation``, tags each table with
    the ``subject_id`` and ``study_id`` parsed from the report key, and writes
    the concatenation to ``args.output_path`` as CSV (without the index).

    Args:
        args: namespace with ``data_path`` and ``output_path`` attributes.
    """
    print("Loading RadGraph data...")
    # context manager so the file handle is closed as soon as parsing is done
    with open(args.data_path) as f:
        data = json.load(f)
    print("RadGraph data is loaded.")

    # create itemized RadGraph data
    df_lst = []
    print("Itemizing RadGraph data...")
    for key, value in tqdm(data.items()):
        _, pid, sid = get_ids(key)
        df = get_entity_relation(value)
        df["subject_id"] = pid
        df["study_id"] = sid
        df_lst.append(df)

    # entity level dataframe
    df_itemized = pd.concat(df_lst)

    # save dataframes to a .csv file
    df_itemized.to_csv(args.output_path, index=False)
    print("Outputs have been saved!")
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
args = parser.parse_args()
|
139 |
+
radgraph_itemize(args)
|
PreTrain_MeDSLIP/data_file/preprocessing/radgraph_parsed.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code copied from AGXNet:
|
3 |
+
https://github.com/batmanlab/AGXNet
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import pandas as pd
|
8 |
+
from tqdm import tqdm
|
9 |
+
import spacy
|
10 |
+
|
11 |
+
sp = spacy.load("en_core_web_sm")
|
12 |
+
|
13 |
+
# Command-line interface: itemized CSV in, sentence-level parsed CSV out.
parser = argparse.ArgumentParser(description="Parse RadGraph Relations.")

parser.add_argument(
    "--input-path",
    default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-itemized.csv",
    help="Itemized input data path.",
)
parser.add_argument(
    "--output-path",
    default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-sentence-parsed.csv",
    help="Output path for parsed relations.",
)
|
25 |
+
|
26 |
+
|
27 |
+
def obs_lemmatization(x):
    """
    Lemmatize an observation.

    ``x`` is converted with ``str`` and run through the module-level spaCy
    pipeline ``sp``; the lemma of every resulting token is joined with a
    single space.
    Args:
        x: an observation token
    Return:
        normalized observation
    """
    return " ".join(token.lemma_ for token in sp(str(x)))
|
39 |
+
|
40 |
+
|
41 |
+
# Entity labels that denote an observation (absent / present / uncertain).
_OBS_LABELS = ["OBS-DA", "OBS-DP", "OBS-U"]


def _present_related_tokens(df_sen, source_key, relation):
    """Join the lower-cased OBS-DP tokens pointing at `source_key` via `relation`; None if there are none."""
    mask = (df_sen["target"] == source_key) & (df_sen["relation"] == relation)
    tokens = [
        row["token"].lower()
        for _, row in df_sen[mask].iterrows()
        if row["label"] == "OBS-DP"  # keep only modifications that are present
    ]
    return "|".join(tokens) if tokens else None


def _describe_observations(df_sen, df_obs):
    """Return parallel lists (token, label, obs_modify, obs_suggestive) for the rows of `df_obs`."""
    obs_lst, label_lst, obs_modify_lst, obs_suggestive_lst = [], [], [], []
    for _, row in df_obs.iterrows():
        obs_lst.append(row["token"].lower())
        label_lst.append(row["label"])
        obs_modify_lst.append(_present_related_tokens(df_sen, row["source"], "modify"))
        obs_suggestive_lst.append(
            _present_related_tokens(df_sen, row["source"], "suggestive_of")
        )
    return obs_lst, label_lst, obs_modify_lst, obs_suggestive_lst


def radgraph_parse(args):
    """Parse RadGraph relations into sentence-level (anatomy, observation) rows.

    Reads the itemized CSV at ``args.input_path`` and, for every sentence of
    every study, emits tuples of
    (study_id, sen_id, sentence, anatomy, observation, label, obs_modify,
    obs_suggestive). The lemmatized observation is added as ``obs_lemma`` and
    the table is written to ``args.output_path`` as CSV (without the index).

    Args:
        args: namespace with ``input_path`` and ``output_path`` attributes.
    """
    print("Loading itemized RadGraph data...")
    df_itemized = pd.read_csv(args.input_path)

    # get all study_id
    sid_lst = list(df_itemized["study_id"].unique())

    tuple_lst = []
    print("Preprocessing sentences...")
    for sid in tqdm(sid_lst):
        df_sid = df_itemized[df_itemized["study_id"] == sid]

        # unique sentence index
        for si in list(df_sid["sentence_ix"].unique()):
            df_sen = df_sid[df_sid["sentence_ix"] == si]
            sen = df_sen["sentence"].iloc[0]
            is_obs = df_sen["label"].isin(_OBS_LABELS)

            # step 1, select all target anatomy entities (e.g., lobe) with
            # label = ANAT-DP and target = NaN
            idx_a = (df_sen["label"] == "ANAT-DP") & (df_sen["target"].isnull())

            if idx_a.any():
                # observations not attached to any anatomy; they do not depend
                # on the anatomy row, so compute them once per sentence
                idx_oo = (
                    is_obs & df_sen["target"].isna() & df_sen["relation"].isna()
                )
                free_obs, free_label, free_modify, free_suggestive = (
                    _describe_observations(df_sen, df_sen[idx_oo])
                )

                for _, row_a in df_sen[idx_a].iterrows():
                    sen = row_a["sentence"]
                    source_key = row_a["source"]

                    # step 2, get detailed target anatomy (e.g., lower left lobe)
                    token_a = [row_a["token"].lower()]
                    # keys of all anatomy tokens, i.e., lower, left, lobe
                    anatomy_source_keys = [source_key]
                    idx_t = (df_sen["label"] == "ANAT-DP") & (
                        df_sen["target"] == source_key
                    )
                    for _, row in df_sen[idx_t].iterrows():
                        token_a.append(row["token"].lower())
                        anatomy_source_keys.append(row["source"])
                    anatomy = "|".join(token_a)

                    # step 3: get observations associated with the target
                    # anatomy (e.g., normal, effusion)
                    idx_o = (
                        is_obs
                        & df_sen["target"].isin(anatomy_source_keys)
                        & (df_sen["relation"] == "located_at")
                    )
                    if idx_o.any():
                        obs_lst, label_lst, obs_modify_lst, obs_suggestive_lst = (
                            _describe_observations(df_sen, df_sen[idx_o])
                        )
                    else:
                        # keep the anatomy even when nothing is said about it
                        obs_lst = [None]
                        label_lst = [None]
                        obs_modify_lst = [None]
                        obs_suggestive_lst = [None]
                    anatomy_lst = [anatomy] * len(obs_lst)

                    # step 4: add observations that are not associated with
                    # the target anatomy
                    anatomy_lst = anatomy_lst + ["unspecified"] * len(free_obs)
                    obs_lst = obs_lst + free_obs
                    label_lst = label_lst + free_label
                    obs_modify_lst = obs_modify_lst + free_modify
                    obs_suggestive_lst = obs_suggestive_lst + free_suggestive

                    # step 5: create tuples of 8 values (sid, sentence_id,
                    # sentence, anatomy, obs, label, obs_modify, obs_suggestive)
                    t_lst = [
                        (sid, si, sen, anat, obs, label, modify, suggestive)
                        for anat, obs, label, modify, suggestive in zip(
                            anatomy_lst,
                            obs_lst,
                            label_lst,
                            obs_modify_lst,
                            obs_suggestive_lst,
                        )
                    ]

                    # remove duplicates caused by 1 obs "located_at" multiple
                    # anatomies; dict.fromkeys keeps first-seen order so the
                    # output is reproducible between runs
                    tuple_lst.append(list(dict.fromkeys(t_lst)))

            # if the sentence does not have any ANATOMY token
            else:
                idx_o = is_obs & df_sen["target"].isnull()
                if idx_o.any():
                    obs_lst, label_lst, obs_modify_lst, obs_suggestive_lst = (
                        _describe_observations(df_sen, df_sen[idx_o])
                    )
                else:
                    obs_lst = [None]
                    label_lst = [None]
                    obs_modify_lst = [None]
                    obs_suggestive_lst = [None]

                t_lst = [
                    (sid, si, sen, "unspecified", obs, label, modify, suggestive)
                    for obs, label, modify, suggestive in zip(
                        obs_lst, label_lst, obs_modify_lst, obs_suggestive_lst
                    )
                ]

                # remove duplicates if existing (order-preserving)
                tuple_lst.append(list(dict.fromkeys(t_lst)))

    # flatten nested list
    df_lst = [item for sublist in tuple_lst for item in sublist]
    df_anatomy_label = pd.DataFrame(
        df_lst,
        columns=[
            "study_id",
            "sen_id",
            "sentence",
            "anatomy",
            "observation",
            "label",
            "obs_modify",
            "obs_suggestive",
        ],
    )

    # lemmatize observation tokens (e.g., normalize opacities to opacity);
    # the same token recurs many times, so lemmatize each distinct value once
    lemma_cache = {}
    obs_lemma_lst = []
    print("Lemmatizing observation tokens...")
    for t in tqdm(df_lst):
        obs = t[4]
        if obs not in lemma_cache:
            lemma_cache[obs] = obs_lemmatization(obs)
        obs_lemma_lst.append(lemma_cache[obs])

    # save preprocessed sentence level data
    df_anatomy_label["obs_lemma"] = obs_lemma_lst
    df_anatomy_label.to_csv(args.output_path, index=False)
    print("Output file has been saved!")
|
318 |
+
|
319 |
+
|
320 |
+
if __name__ == "__main__":
|
321 |
+
args = parser.parse_args()
|
322 |
+
radgraph_parse(args)
|
PreTrain_MeDSLIP/dataset/dataset.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
import PIL
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
from torchvision import transforms
|
8 |
+
from PIL import Image
|
9 |
+
import random
|
10 |
+
from dataset.randaugment import RandomAugment
|
11 |
+
|
12 |
+
|
13 |
+
class MeDSLIP_Dataset(Dataset):
    """Image dataset with per-image anatomy/pathology labels.

    The annotation file maps image paths to entries whose ``"labels_id"``
    selects one 2-D slice of the array loaded from ``np_path``. Within a
    slice the rows are iterated as anatomies and the columns as pathologies;
    a cell equal to ``1`` marks a positive pair, ``0`` an explicit negative,
    and any other value is treated as "not mentioned".
    """

    def __init__(self, csv_path, np_path, mode="train", num_neg_samples=7):
        """Load annotations and labels and build the image transform.

        Args:
            csv_path: Path of the annotation file. Despite the name it is
                parsed as JSON (a mapping keyed by image path).
            np_path: Path of the label array, loaded with ``np.load``.
            mode: ``"train"`` builds the augmenting transform, ``"test"`` the
                plain resize transform. Any other value leaves
                ``self.transform`` unset, as before.
            num_neg_samples: Number of negative indices sampled per entity.
        """
        self.num_neg_samples = num_neg_samples
        # Close the annotation file deterministically instead of leaking it.
        with open(csv_path, "r") as ann_file:
            self.ann = json.load(ann_file)
        self.img_path_list = list(self.ann)
        # The attribute name keeps its original (misspelled) form so existing
        # references to ``anaomy_list`` keep working.
        self.anaomy_list = [
            "trachea", "left_hilar", "right_hilar", "hilar_unspec",
            "left_pleural", "right_pleural", "pleural_unspec",
            "heart_size", "heart_border",
            "left_diaphragm", "right_diaphragm", "diaphragm_unspec",
            "retrocardiac",
            "lower_left_lobe", "upper_left_lobe",
            "lower_right_lobe", "middle_right_lobe", "upper_right_lobe",
            "left_lower_lung", "left_mid_lung", "left_upper_lung",
            "left_apical_lung", "left_lung_unspec",
            "right_lower_lung", "right_mid_lung", "right_upper_lung",
            "right_apical_lung", "right_lung_unspec",
            "lung_apices", "lung_bases",
            "left_costophrenic", "right_costophrenic", "costophrenic_unspec",
            "cardiophrenic_sulcus", "mediastinal", "spine", "clavicle", "rib",
            "stomach", "right_atrium", "right_ventricle", "aorta", "svc",
            "interstitium", "parenchymal", "cavoatrial_junction",
            "cardiopulmonary", "pulmonary", "lung_volumes",
            "unspecified", "other",
        ]
        self.obs_list = [
            "normal", "clear", "sharp", "sharply", "unremarkable", "intact",
            "stable", "free", "effusion", "opacity", "pneumothorax", "edema",
            "atelectasis", "tube", "consolidation", "process", "abnormality",
            "enlarge", "tip", "low", "pneumonia", "line", "congestion",
            "catheter", "cardiomegaly", "fracture", "air", "tortuous", "lead",
            "disease", "calcification", "prominence", "device", "engorgement",
            "picc", "clip", "elevation", "expand", "nodule", "wire", "fluid",
            "degenerative", "pacemaker", "thicken", "marking", "scar",
            "hyperinflate", "blunt", "loss", "widen", "collapse", "density",
            "emphysema", "aerate", "mass", "crowd", "infiltrate", "obscure",
            "deformity", "hernia", "drainage", "distention", "shift", "stent",
            "pressure", "lesion", "finding", "borderline", "hardware",
            "dilation", "chf", "redistribution", "aspiration",
            "tail_abnorm_obs", "excluded_obs",
        ]
        self.rad_graph_results = np.load(np_path)
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        if mode == "train":
            self.transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        224, scale=(0.2, 1.0), interpolation=Image.BICUBIC
                    ),
                    transforms.RandomHorizontalFlip(),
                    RandomAugment(
                        2,
                        7,
                        isPIL=True,
                        augs=[
                            "Identity",
                            "AutoContrast",
                            "Equalize",
                            "Brightness",
                            "Sharpness",
                            "ShearX",
                            "ShearY",
                            "TranslateX",
                            "TranslateY",
                            "Rotate",
                        ],
                    ),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        elif mode == "test":
            self.transform = transforms.Compose(
                [
                    transforms.Resize([224, 224]),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

    def __getitem__(self, index):
        """Return the transformed image and its label structures.

        Returns:
            A dict with keys ``"image"``, ``"label_pathology"``,
            ``"index_pathology"``, ``"label_anatomy"``, ``"index_anatomy"``
            and ``"matrix"`` (the raw 2-D label slice).
        """
        img_path = self.img_path_list[index]
        class_label = self.rad_graph_results[
            self.ann[img_path]["labels_id"], :, :
        ]
        labels_pathology, index_list_pathology = self.triplet_extraction_pathology(
            class_label
        )
        labels_anatomy, index_list_anatomy = self.triplet_extraction_anatomy(
            class_label
        )
        index_list_pathology = np.array(index_list_pathology)
        index_list_anatomy = np.array(index_list_anatomy)

        # ``convert`` returns an independent image, so the file handle can be
        # released as soon as the pixels are decoded.
        with PIL.Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        image = self.transform(img)

        return {
            "image": image,
            "label_pathology": labels_pathology,
            "index_pathology": index_list_pathology,
            "label_anatomy": labels_anatomy,
            "index_anatomy": index_list_anatomy,
            "matrix": class_label,
        }

    def _extract_triplets(self, class_label, axis):
        """Shared worker of the two ``triplet_extraction_*`` methods.

        Args:
            class_label: 2-D label slice.
            axis: ``0`` to treat each row as an entity, ``1`` for each column.

        Returns:
            ``(exist_labels, sample_lists)``. ``exist_labels[i]`` is ``1`` if
            entity ``i`` has any ``1`` cell, else ``0`` if it has any ``0``
            cell, else ``-1``. ``sample_lists[i]`` is ``[-1]`` followed by
            ``num_neg_samples`` sampled indices whose cell is not ``1`` when
            the entity exists, otherwise ``num_neg_samples + 1`` such indices.
        """
        n_entities = class_label.shape[axis]
        exist_labels = np.zeros(n_entities) - 1
        sample_lists = []
        for i in range(n_entities):
            values = class_label[i, :] if axis == 0 else class_label[:, i]
            # Indices along the other axis that are not positive for entity i.
            candidates = np.where(values != 1)[0].tolist()
            sampled = []
            if 0 in values:
                exist_labels[i] = 0
            if 1 in values:
                exist_labels[i] = 1
                # The contrastive loss is only computed for existing entities;
                # -1 is the placeholder slot in front of the negatives.
                sampled.append(-1)
                try:
                    sampled = sampled + random.sample(
                        candidates, self.num_neg_samples
                    )
                except ValueError:
                    # Fewer candidates than requested negatives. Behaviour is
                    # kept (report and continue); note the row then has a
                    # different length from the others.
                    print("fatal error")
            if not sampled:
                sampled = random.sample(candidates, self.num_neg_samples + 1)
            sample_lists.append(sampled)
        return exist_labels, sample_lists

    def triplet_extraction_pathology(self, class_label):
        """Per-pathology (column) existence labels and sampled anatomy indices.

        This is for ProtoCL: anatomies are extracted for use in the pathology
        stream.
        """
        return self._extract_triplets(class_label, axis=1)

    def triplet_extraction_anatomy(self, class_label):
        """Per-anatomy (row) existence labels and sampled pathology indices.

        This is for ProtoCL: pathology labels are extracted for use in the
        anatomy stream.
        """
        return self._extract_triplets(class_label, axis=0)

    def __len__(self):
        """Number of annotated images."""
        return len(self.ann)
|
286 |
+
|
287 |
+
|
288 |
+
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one ``DataLoader`` per dataset from parallel setting sequences.

    The six arguments are walked in lockstep (``zip``), so the result has as
    many loaders as the shortest argument. Training loaders drop the last
    incomplete batch and shuffle only when no sampler is supplied; evaluation
    loaders never shuffle and keep every batch. Memory pinning is always
    requested.

    Returns:
        list of ``DataLoader`` objects, in input order.
    """
    settings = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    built = []
    for data, sampler, bs, n_worker, is_train, collate_fn in settings:
        training = bool(is_train)
        built.append(
            DataLoader(
                data,
                batch_size=bs,
                num_workers=n_worker,
                pin_memory=True,
                sampler=sampler,
                # A sampler already fixes the order, so shuffling is left off.
                shuffle=training and sampler is None,
                collate_fn=collate_fn,
                drop_last=training,
            )
        )
    return built
|
PreTrain_MeDSLIP/dataset/randaugment.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
## aug functions
|
6 |
+
def identity_func(img):
    """No-op augmentation: hand back the very same object that was passed in."""
    unchanged = img
    return unchanged
|
8 |
+
|
9 |
+
|
10 |
+
def autocontrast_func(img, cutoff=0):
    """Stretch every channel so its darkest/brightest kept level maps to 0/255.

    Intended to give the same output as ``PIL.ImageOps.autocontrast``.

    Args:
        img: Image array whose values index a 256-entry lookup table.
        cutoff: Percentage of pixels to ignore at each end of the histogram.

    Returns:
        The remapped image, channels merged back with ``cv2.merge``.
    """
    n_bins = 256

    def tune_channel(ch):
        cut = cutoff * ch.size // 100
        if cut == 0:
            # Plain Python ints on purpose: ``ch.min()`` is an unsigned NumPy
            # scalar for uint8 images, and negating it below would wrap around
            # (e.g. -10 -> 246), saturating the whole lookup table.
            low, high = int(ch.min()), int(ch.max())
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            from_dark = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if from_dark.shape[0] == 0 else int(from_dark[0, 0])
            from_bright = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = (
                n_bins - 1
                if from_bright.shape[0] == 0
                else n_bins - 1 - int(from_bright[0, 0])
            )
        if high <= low:
            # Flat (or fully cut) channel: identity mapping.
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    return cv2.merge(channels)
|
41 |
+
|
42 |
+
|
43 |
+
def equalize_func(img):
    """Histogram-equalize each channel through a 256-entry lookup table.

    Written to mirror ``PIL.ImageOps.equalize``; the original author notes
    that PIL's scheme differs from OpenCV's own equalization.
    """
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        populated = hist[hist != 0].reshape(-1)
        # Step is derived from every populated bin except the last one.
        step = np.sum(populated[:-1]) // (n_bins - 1)
        if step == 0:
            return ch
        shifted = np.empty_like(hist)
        shifted[0] = step // 2
        shifted[1:] = hist[:-1]
        lut = (np.cumsum(shifted) // step).clip(0, 255).astype(np.uint8)
        return lut[ch]

    equalized = [tune_channel(ch) for ch in cv2.split(img)]
    return cv2.merge(equalized)
|
65 |
+
|
66 |
+
|
67 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
    """Rotate ``img`` about its centre by ``degree`` degrees (not radians).

    The output keeps the input's width and height; uncovered border pixels
    are filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    rotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
    return cv2.warpAffine(img, rotation, (width, height), borderValue=fill)
|
76 |
+
|
77 |
+
|
78 |
+
def solarize_func(img, thresh=128):
    """Invert every pixel value at or above ``thresh``.

    Intended to give the same output as ``PIL.ImageOps.solarize``.

    Args:
        img: Integer image array with values in ``0..255``.
        thresh: Values ``>= thresh`` become ``255 - value``; lower values
            are kept.

    Returns:
        ``uint8`` array of the same shape as ``img``.
    """
    levels = np.arange(256)
    table = np.where(levels < thresh, levels, 255 - levels)
    table = table.clip(0, 255).astype(np.uint8)
    return table[img]
|
86 |
+
|
87 |
+
|
88 |
+
def color_func(img, factor):
    """Blend the image with a grey-level version of itself.

    Written to match ``PIL.ImageEnhance.Color``, using a single 3x3 matrix
    product instead of the slower explicit blend. ``factor`` scales the
    chroma part; the luma column is always added.

    NOTE(review): the luma weights are listed as 0.114 / 0.587 / 0.299 for
    channels 0 / 1 / 2 — confirm this matches the channel order of the
    images passed in.
    """
    chroma = np.float32(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
    )
    luma = np.float32([[0.114], [0.587], [0.299]])
    mix = chroma * factor + luma
    blended = np.matmul(img, mix)
    return blended.clip(0, 255).astype(np.uint8)
|
104 |
+
|
105 |
+
|
106 |
+
def contrast_func(img, factor):
    """Scale every level's distance from the image's weighted mean by ``factor``.

    Written to match ``PIL.ImageEnhance.Contrast``. The mean is the per-channel
    mean combined with the weights 0.114 / 0.587 / 0.299; the result is applied
    through a 256-entry lookup table.
    """
    weights = np.array([0.114, 0.587, 0.299])
    mean = np.sum(np.mean(img, axis=(0, 1)) * weights)
    levels = np.arange(256)
    table = ((levels - mean) * factor + mean).clip(0, 255).astype(np.uint8)
    return table[img]
|
118 |
+
|
119 |
+
|
120 |
+
def brightness_func(img, factor):
    """Multiply every pixel value by ``factor``, saturating at 0 and 255.

    Intended to give the same output as ``PIL.ImageEnhance.Brightness``.

    Args:
        img: Integer image array with values in ``0..255``.
        factor: Brightness multiplier (``1.0`` leaves the image unchanged).

    Returns:
        ``uint8`` array of the same shape as ``img``.
    """
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    return table[img]
|
127 |
+
|
128 |
+
|
129 |
+
def sharpness_func(img, factor):
    """Blend the image with a smoothed copy of itself.

    ``factor == 0`` returns the smoothed image, ``factor == 1`` the input
    itself, and other values interpolate or extrapolate between the two. Per
    the original author, the result differs from PIL only on the four image
    borders; the interior matches. Border pixels are left unblended here.

    Args:
        img: Image array of shape ``(H, W, C)``.
        factor: Blend weight of the original image.

    Returns:
        The blended image; ``uint8`` on the blending path.
    """
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        # Extrapolation (factor > 1) can leave 0..255; clip before the cast so
        # values saturate instead of wrapping around.
        out = out.clip(0, 255).astype(np.uint8)
    return out
|
148 |
+
|
149 |
+
|
150 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """Shear the image horizontally by ``factor``; size is kept, border gets ``fill``."""
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
|
157 |
+
|
158 |
+
|
159 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """Translate the image along x using a ``-offset`` affine shift.

    Written to match ``PIL.Image.transform``; size is kept and uncovered
    pixels are filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, -offset], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
|
169 |
+
|
170 |
+
|
171 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """Translate the image along y using a ``-offset`` affine shift.

    Written to match ``PIL.Image.transform``; size is kept and uncovered
    pixels are filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, 0], [0, 1, -offset]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
|
181 |
+
|
182 |
+
|
183 |
+
def posterize_func(img, bits):
    """Keep only the ``bits`` most significant bits of every pixel value.

    Intended to give the same output as ``PIL.ImageOps.posterize``.

    Args:
        img: ``uint8`` image array.
        bits: Number of high bits to keep, ``0..8``.

    Returns:
        Array of the same shape with the low ``8 - bits`` bits cleared.
    """
    # Reduce the shifted mask to 8 bits *before* building the uint8 scalar:
    # 255 << k does not fit in uint8 for k > 0, which NumPy 2 rejects with
    # OverflowError (older releases wrapped with a deprecation warning).
    mask = np.uint8((0xFF << (8 - bits)) & 0xFF)
    return np.bitwise_and(img, mask)
|
189 |
+
|
190 |
+
|
191 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """Shear the image vertically by ``factor``; size is kept, border gets ``fill``."""
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
|
198 |
+
|
199 |
+
|
200 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """Overwrite a randomly centred patch of a copy of ``img`` with ``replace``.

    The patch extends ``pad_size // 2`` pixels either side of a uniformly drawn
    centre and is clipped to the image bounds. The input is not modified.
    """
    fill = np.array(replace, dtype=np.uint8)
    height, width = img.shape[0], img.shape[1]
    frac_row, frac_col = np.random.random(2)
    half = pad_size // 2
    centre_row = int(frac_row * height)
    centre_col = int(frac_col * width)
    top = max(centre_row - half, 0)
    bottom = min(centre_row + half, height)
    left = max(centre_col - half, 0)
    right = min(centre_col + half, width)
    patched = img.copy()
    patched[top:bottom, left:right, :] = fill
    return patched
|
211 |
+
|
212 |
+
|
213 |
+
### level to args
|
214 |
+
def enhance_level_to_args(MAX_LEVEL):
    """Return a converter mapping a level to a one-element enhancement-factor tuple.

    The factor grows linearly from 0.1 (level 0) by 1.8 per ``MAX_LEVEL``.
    """

    def level_to_args(level):
        ratio = level / MAX_LEVEL
        factor = ratio * 1.8 + 0.1
        return (factor,)

    return level_to_args
|
219 |
+
|
220 |
+
|
221 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
    """Return a converter mapping a level to ``(shear_factor, replace_value)``.

    The magnitude is ``level / MAX_LEVEL * 0.3``; it is negated when a fresh
    ``np.random.random()`` draw exceeds 0.5.
    """

    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        flip = np.random.random() > 0.5
        return (-magnitude if flip else magnitude, replace_value)

    return level_to_args
|
229 |
+
|
230 |
+
|
231 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Return a converter mapping a level to ``(offset, replace_value)``.

    The magnitude is ``level / MAX_LEVEL * translate_const``; it is negated
    when a fresh ``np.random.random()`` draw exceeds 0.5.
    """

    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * float(translate_const)
        flip = np.random.random() > 0.5
        return (-magnitude if flip else magnitude, replace_value)

    return level_to_args
|
239 |
+
|
240 |
+
|
241 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Return a converter mapping a level to ``(pad_size, replace_value)``.

    ``pad_size`` is ``level / MAX_LEVEL * cutout_const`` truncated to an int.
    """

    def level_to_args(level):
        pad_size = int((level / MAX_LEVEL) * cutout_const)
        return (pad_size, replace_value)

    return level_to_args
|
247 |
+
|
248 |
+
|
249 |
+
def solarize_level_to_args(MAX_LEVEL):
    """Return a converter mapping a level to a one-element threshold tuple.

    The threshold is ``level / MAX_LEVEL * 256`` truncated to an int.
    """

    def level_to_args(level):
        threshold = int((level / MAX_LEVEL) * 256)
        return (threshold,)

    return level_to_args
|
255 |
+
|
256 |
+
|
257 |
+
def none_level_to_args(level):
    """Converter for augmentations that take no extra arguments; ``level`` is ignored."""
    return tuple()
|
259 |
+
|
260 |
+
|
261 |
+
def posterize_level_to_args(MAX_LEVEL):
    """Return a converter mapping a level to a one-element bit-count tuple.

    The bit count is ``level / MAX_LEVEL * 4`` truncated to an int.
    """

    def level_to_args(level):
        bits = int((level / MAX_LEVEL) * 4)
        return (bits,)

    return level_to_args
|
267 |
+
|
268 |
+
|
269 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """Return a converter mapping a level to ``(degrees, replace_value)``.

    The magnitude is ``level / MAX_LEVEL * 30``; it is negated when a fresh
    ``np.random.random()`` draw is below 0.5.
    """

    def level_to_args(level):
        degrees = (level / MAX_LEVEL) * 30
        flip = np.random.random() < 0.5
        return (-degrees if flip else degrees, replace_value)

    return level_to_args
|
277 |
+
|
278 |
+
|
279 |
+
# Maps an augmentation name to the function that applies it to an image array.
# NOTE(review): cutout_func is defined in this module but has no entry here.
func_dict = {
    "Identity": identity_func,
    "AutoContrast": autocontrast_func,
    "Equalize": equalize_func,
    "Rotate": rotate_func,
    "Solarize": solarize_func,
    "Color": color_func,
    "Contrast": contrast_func,
    "Brightness": brightness_func,
    "Sharpness": sharpness_func,
    "ShearX": shear_x_func,
    "TranslateX": translate_x_func,
    "TranslateY": translate_y_func,
    "Posterize": posterize_func,
    "ShearY": shear_y_func,
}

# Translation magnitude reached at MAX_LEVEL (fed to translate_level_to_args).
translate_const = 10
# Level treated as full strength by the level-to-args converters below.
MAX_LEVEL = 10
# Border fill value passed on to the rotate / shear / translate functions.
replace_value = (128, 128, 128)
# Maps an augmentation name to the converter that turns a level into the extra
# positional arguments of the matching function in func_dict.
arg_dict = {
    "Identity": none_level_to_args,
    "AutoContrast": none_level_to_args,
    "Equalize": none_level_to_args,
    "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
    "Solarize": solarize_level_to_args(MAX_LEVEL),
    "Color": enhance_level_to_args(MAX_LEVEL),
    "Contrast": enhance_level_to_args(MAX_LEVEL),
    "Brightness": enhance_level_to_args(MAX_LEVEL),
    "Sharpness": enhance_level_to_args(MAX_LEVEL),
    "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
    "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "Posterize": posterize_level_to_args(MAX_LEVEL),
    "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
}
|
315 |
+
|
316 |
+
|
317 |
+
class RandomAugment(object):
    """Apply ``N`` randomly chosen augmentations, each with probability 0.5.

    Args:
        N: Number of operations drawn (with replacement) per call.
        M: Level handed to every operation's level-to-args converter.
        isPIL: When true, the input is converted with ``np.array`` first.
        augs: Names of the operations to draw from; when empty or ``None``
            every key of ``arg_dict`` is used.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # Copy the names so later changes to the caller's list do not leak in;
        # ``None`` replaces the former shared mutable default.
        self.augs = list(augs) if augs else list(arg_dict.keys())

    def get_random_ops(self):
        """Draw ``N`` operations and return ``(name, probability, level)`` tuples."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        """Return ``img`` after the drawn operations; the result is an array."""
        if self.isPIL:
            img = np.array(img)
        for name, prob, level in self.get_random_ops():
            # Each drawn operation is skipped when the draw exceeds ``prob``.
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
|
341 |
+
|
342 |
+
|
343 |
+
if __name__ == "__main__":
    # Smoke test: run the default augmentation set on a random image. The
    # image must be uint8 because several operations index 256-entry lookup
    # tables with the pixel values, which a float image cannot do.
    a = RandomAugment()
    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    a(img)
|
PreTrain_MeDSLIP/models/__init__.py
ADDED
File without changes
|
PreTrain_MeDSLIP/models/model_MeDSLIP.py
ADDED
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
|
2 |
+
from sklearn.metrics import log_loss
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch
|
5 |
+
import math
|
6 |
+
import numpy as np
|
7 |
+
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from .transformer import *
|
10 |
+
import torchvision.models as models
|
11 |
+
from einops import rearrange
|
12 |
+
from transformers import AutoModel
|
13 |
+
|
14 |
+
"""
|
15 |
+
args.N
|
16 |
+
args.d_model
|
17 |
+
args.res_base_model
|
18 |
+
args.H
|
19 |
+
args.num_queries
|
20 |
+
args.dropout
|
21 |
+
args.attribute_set_size
|
22 |
+
"""
|
23 |
+
|
24 |
+
|
25 |
+
class MeDSLIP(nn.Module):
|
26 |
+
def __init__(
    self, config, anatomy_book, pathology_book, mode="train",
):
    """Build the text "books", the image backbone, both decoders and the heads.

    Args:
        config: Mapping read for the keys "d_model", "text_encoder",
            "res_base_model", "H", "N", "dropout" and "attribute_set_size".
        anatomy_book: Mapping with "input_ids" and "attention_mask" entries
            that is fed to the text encoder.
        pathology_book: Same structure as ``anatomy_book``.
        mode: Stored on ``self.mode``; not otherwise used in this method.
    """
    super(MeDSLIP, self).__init__()
    self.mode = mode
    self.d_model = config["d_model"]
    # """ book embedding"""
    # Both books are encoded once, without gradient tracking, and only the
    # hidden state at sequence position 0 of each entry is kept.
    with torch.no_grad():
        bert_model = self._get_bert_basemodel(
            config["text_encoder"], freeze_layers=None
        ).to(anatomy_book["input_ids"].device)
        self.anatomy_book = bert_model(
            input_ids=anatomy_book["input_ids"],
            attention_mask=anatomy_book["attention_mask"],
        )  # (**encoded_inputs)
        self.anatomy_book = self.anatomy_book.last_hidden_state[:, 0, :]
        self.pathology_book = bert_model(
            input_ids=pathology_book["input_ids"],
            attention_mask=pathology_book["attention_mask"],
        )  # (**encoded_inputs)
        self.pathology_book = self.pathology_book.last_hidden_state[:, 0, :]
    self.pathology_embedding_layer = nn.Linear(768, 256)
    self.cl_fc_pathology = nn.Linear(256, 768)

    # NOTE(review): "pathology" and "coll_eapse" below sit where the dataset's
    # observation list has "disease" and "collapse" — confirm which spelling
    # is intended. The names are only used for the exclusion filter below.
    self.pathology_name = [
        "normal",
        "clear",
        "sharp",
        "sharply",
        "unremarkable",
        "intact",
        "stable",
        "free",
        "effusion",
        "opacity",
        "pneumothorax",
        "edema",
        "atelectasis",
        "tube",
        "consolidation",
        "process",
        "abnormality",
        "enlarge",
        "tip",
        "low",
        "pneumonia",
        "line",
        "congestion",
        "catheter",
        "cardiomegaly",
        "fracture",
        "air",
        "tortuous",
        "lead",
        "pathology",
        "calcification",
        "prominence",
        "device",
        "engorgement",
        "picc",
        "clip",
        "elevation",
        "expand",
        "nodule",
        "wire",
        "fluid",
        "degenerative",
        "pacemaker",
        "thicken",
        "marking",
        "scar",
        "hyperinflate",
        "blunt",
        "loss",
        "widen",
        "coll_eapse",
        "density",
        "emphysema",
        "aerate",
        "mass",
        "crowd",
        "infiltrate",
        "obscure",
        "deformity",
        "hernia",
        "drainage",
        "distention",
        "shift",
        "stent",
        "pressure",
        "lesion",
        "finding",
        "borderline",
        "hardware",
        "dilation",
        "chf",
        "redistribution",
        "aspiration",
        "tail_abnorm_obs",
        "excluded_obs",
    ]

    # "fibrosis" does not occur in pathology_name, so it excludes nothing.
    self.excluded_pathology = [
        "pneumonia",
        "infiltrate",
        "mass",
        "nodule",
        "emphysema",
        "fibrosis",
        "thicken",
        "hernia",
    ]

    # Positions in pathology_name of every name that is not excluded.
    self.keep_class_dim_pathology = [
        self.pathology_name.index(i)
        for i in self.pathology_name
        if i not in self.excluded_pathology
    ]
    """ visual backbone"""
    # Both candidate backbones are instantiated here; one is then selected
    # by the configured name.
    self.resnet_dict = {
        "resnet18": models.resnet18(pretrained=False),
        "resnet50": models.resnet50(pretrained=False),
    }
    resnet = self._get_res_basemodel(config["res_base_model"])
    # Feature width used by the projection layers: half of the backbone's
    # final fc input width.
    num_ftrs = int(resnet.fc.in_features / 2)
    # Keep every child module of the backbone except the last three.
    self.res_features = nn.Sequential(*list(resnet.children())[:-3])

    self.res_l1_pathology = nn.Linear(num_ftrs, num_ftrs)
    self.res_l2_pathology = nn.Linear(num_ftrs, self.d_model)

    self.cl_fc_anatomy = nn.Linear(256, 768)
    self.res_l1_anatomy = nn.Linear(num_ftrs, num_ftrs)
    self.res_l2_anatomy = nn.Linear(num_ftrs, self.d_model)

    self.mask_generator = nn.Linear(num_ftrs, num_ftrs)

    ###################################
    """ Query Decoder"""
    ###################################

    self.H = config["H"]
    decoder_layer = TransformerDecoderLayer(
        self.d_model, config["H"], 1024, 0.1, "relu", normalize_before=True
    )
    decoder_norm = nn.LayerNorm(self.d_model)
    # NOTE(review): the same decoder_layer and decoder_norm objects are passed
    # to both decoders; whether any weights end up shared depends on
    # TransformerDecoder (imported from .transformer) — confirm.
    self.decoder_anatomy = TransformerDecoder(
        decoder_layer, config["N"], decoder_norm, return_intermediate=False
    )
    self.decoder_pathology = TransformerDecoder(
        decoder_layer, config["N"], decoder_norm, return_intermediate=False
    )

    # Learnable Queries
    self.dropout_feas_anatomy = nn.Dropout(config["dropout"])
    self.dropout_feas_pathology = nn.Dropout(config["dropout"])

    # Attribute classifier
    self.classifier_anatomy = nn.Linear(self.d_model, config["attribute_set_size"])
    self.classifier_pathology = nn.Linear(
        self.d_model, config["attribute_set_size"]
    )

    # Run the class's weight initialiser over every registered submodule.
    self.apply(self._init_weights)
|
189 |
+
|
190 |
+
def _get_res_basemodel(self, res_model_name):
    """Return the backbone registered in ``self.resnet_dict`` under the given name.

    Args:
        res_model_name: Key into ``self.resnet_dict``.

    Returns:
        The stored model object.

    Raises:
        ValueError: If the name is not a key of ``self.resnet_dict``.
    """
    try:
        res_model = self.resnet_dict[res_model_name]
    except KeyError as err:
        # A real exception type is required here: raising a bare string is
        # itself a TypeError and would hide the actual problem.
        raise ValueError(
            "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
        ) from err
    print("Image feature extractor:", res_model_name)
    return res_model
|
199 |
+
|
200 |
+
def _get_bert_basemodel(self, bert_model_name, freeze_layers):
    """Load the text encoder and optionally freeze some of its layers.

    Args:
        bert_model_name: Name or path passed to ``AutoModel.from_pretrained``.
        freeze_layers: Iterable of indices into ``model.encoder.layer`` whose
            parameters get ``requires_grad = False``, or ``None`` to freeze
            nothing.

    Returns:
        The loaded model.

    Raises:
        ValueError: If the model cannot be loaded; the underlying error is
            attached as the cause.
    """
    try:
        model = AutoModel.from_pretrained(bert_model_name)
    except Exception as err:
        # Raise a real exception (a bare string cannot be raised) and keep
        # the loader's own error chained for diagnosis.
        raise ValueError(
            "Invalid model name. Check the config file and pass a BERT model from transformers library"
        ) from err
    print("text feature extractor:", bert_model_name)

    if freeze_layers is not None:
        for layer_idx in freeze_layers:
            for param in model.encoder.layer[layer_idx].parameters():
                param.requires_grad = False
    return model
|
214 |
+
|
215 |
+
def image_encoder(self, xis):
    """Encode images into per-location pathology and anatomy embeddings.

    The backbone map is flattened to one row per spatial location, split into
    two streams by an element-wise gate (``gate * x`` and ``(1 - gate) * x``),
    and each stream goes through its own Linear-ReLU-Linear projection.

    The original author's shape trace for one run reads:
    [16, 1024, 14, 14] -> [16, 196, 1024] -> [3136, 1024] -> [16, 196, 256].

    Returns:
        ``(out_emb_pathology, out_emb_anatomy)``, each arranged as
        (batch, locations, features).
    """
    n_images = xis.shape[0]
    patch_grid = self.res_features(xis)  # batch_size,feature_size,patch_num,patch_num
    patch_seq = rearrange(patch_grid, "b d n1 n2 -> b (n1 n2) d")
    flat = rearrange(patch_seq, "b n d -> (b n) d")

    # The gate is the raw output of a linear layer (no squashing is applied).
    gate = self.mask_generator(flat)
    pathology_in = gate * flat
    anatomy_in = (1 - gate) * flat

    pathology_feat = self.res_l2_pathology(F.relu(self.res_l1_pathology(pathology_in)))
    anatomy_feat = self.res_l2_anatomy(F.relu(self.res_l1_anatomy(anatomy_in)))

    out_emb_pathology = rearrange(pathology_feat, "(b n) d -> b n d", b=n_images)
    out_emb_anatomy = rearrange(anatomy_feat, "(b n) d -> b n d", b=n_images)
    return out_emb_pathology, out_emb_anatomy
|
243 |
+
|
244 |
+
def forward(
    self,
    images,
    labels_pathology=None,
    labels_anatomy=None,
    matrix=None,
    sample_index_pathology=None,
    sample_index_anatomy=None,
    is_train=True,
    text_gen=False,
    no_cl=False,
    exclude_class=False,
):
    """Run the model on an image batch and, when training, compute the losses.

    Loss terms built here:

    * ``loss_ap``: binary cross-entropy between the pathology/anatomy query
      similarity logits and ``matrix``, over entries where ``matrix >= 0``.
    * ``loss_ce_*``: cross-entropy of the per-query classifier outputs,
      over queries whose label is neither ``-1`` nor ``2``.
    * ``loss_cl_*``: cross-entropy over the scores of each projected query
      feature against its candidate embeddings, with target index 0, over
      queries whose label is ``1``. Skipped when ``no_cl`` is true.

    Args:
        images: Image batch; ``images.shape[0]`` is the batch size ``B``.
        labels_pathology: Per-query pathology labels. Needed in every mode
            because they are reshaped unconditionally.
        labels_anatomy: Per-query anatomy labels. Needed in every mode.
        matrix: Pathology/anatomy relation targets, needed in every mode.
            Assumed to be (B, num_anatomy, num_pathology) - TODO confirm
            against the dataset. This method does not modify it.
        sample_index_pathology: Indices into ``self.anatomy_book``; ``-1``
            marks an unused slot. Read only when training with ``no_cl``
            false.
        sample_index_anatomy: Indices into ``self.pathology_book``; same
            conventions as ``sample_index_pathology``.
        is_train: Compute losses and return the training tuple.
        text_gen: When training, return predictions and the raw
            pathology/anatomy logits instead of the loss breakdown.
        no_cl: Disable the contrastive terms.
        exclude_class: Restrict pathology queries to
            ``self.keep_class_dim_pathology``.

    Returns:
        * training, ``text_gen`` false: ``(loss, loss_ce_pathology,
          loss_cl_pathology, loss_ce_anatomy, loss_cl_anatomy, loss_ap)``.
          Both ``loss_cl_*`` entries are zero tensors when ``no_cl`` is true.
        * training, ``text_gen`` true: ``(loss, x_pathology, ws_pathology,
          x_anatomy, ws_anatomy, output_logits)``.
        * evaluation: ``(0, x_pathology, ws_pathology, x_anatomy,
          ws_anatomy)``.
    """
    B = images.shape[0]
    device = images.device

    # Visual backbone
    x_pathology, x_anatomy = self.image_encoder(images)  # batch_size,patch_num,dim

    features_pathology = x_pathology.transpose(0, 1)  # patch_num b dim
    features_anatomy = x_anatomy.transpose(0, 1)  # patch_num b dim

    query_embed_pathology = self.pathology_embedding_layer(self.pathology_book)
    # NOTE(review): the anatomy book goes through `pathology_embedding_layer`
    # too - confirm the shared projection is intentional.
    query_embed_anatomy = self.pathology_embedding_layer(self.anatomy_book)
    query_embed_pathology = query_embed_pathology.unsqueeze(1).repeat(1, B, 1)
    query_embed_anatomy = query_embed_anatomy.unsqueeze(1).repeat(1, B, 1)

    features_pathology, ws_pathology = self.decoder_pathology(
        query_embed_pathology,
        features_pathology,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    )
    features_anatomy, ws_anatomy = self.decoder_anatomy(
        query_embed_anatomy,
        features_anatomy,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    )

    ap_pathology = features_pathology
    ap_anatomy = features_anatomy

    # Pairwise pathology/anatomy query similarity, transposed so that it can
    # be indexed with a mask derived from `matrix`.
    ap_logits = torch.bmm(
        ap_pathology.transpose(0, 1), ap_anatomy.transpose(0, 1).transpose(1, 2)
    ).transpose(1, 2)
    if text_gen:
        output_logits = ap_logits
    matrix_zero = matrix

    # Negative entries of `matrix` are left out of the BCE term.
    masks = matrix_zero >= 0
    ap_logits = ap_logits[masks]
    matrix_zero = matrix_zero[masks]

    loss_ap = F.binary_cross_entropy_with_logits(
        ap_logits.float(), matrix_zero.float()
    )

    out_pathology = self.dropout_feas_pathology(features_pathology)
    out_anatomy = self.dropout_feas_anatomy(features_anatomy)

    if is_train == True and no_cl == False:

        # Candidate embeddings per query; slots indexed with -1 are zeroed.
        # NOTE(review): 768 here and 8/768 in the reshape further down are
        # hard-coded - presumably the book embedding width and the number of
        # candidate slots; confirm before changing either.
        anatomytomy_query = self.anatomy_book[sample_index_pathology, :] * (
            sample_index_pathology != -1
        ).int().unsqueeze(-1).repeat(
            1, 1, 1, 768
        )  # batch, Q , position_num ,dim
        entity_query = self.pathology_book[sample_index_anatomy, :] * (
            sample_index_anatomy != -1
        ).int().unsqueeze(-1).repeat(1, 1, 1, 768)

        # Work on a copy: the thresholding below is in place and must not
        # rewrite the caller's `matrix`.
        matrix_zero_pathology = matrix.clone()
        matrix_zero_pathology[matrix_zero_pathology < 1] = 0
        matrix_zero_anatomy = matrix_zero_pathology.transpose(1, 2)
        matrix_zero_pathology = matrix_zero_pathology.unsqueeze(3).repeat(
            1, 1, 1, anatomytomy_query.shape[-1]
        )
        matrix_zero_anatomy = matrix_zero_anatomy.unsqueeze(3).repeat(
            1, 1, 1, entity_query.shape[-1]
        )

        anatomy_temp = self.anatomy_book
        pathology_temp = self.pathology_book
        anatomy_temp = anatomy_temp.unsqueeze(0).repeat(
            anatomytomy_query.shape[0], 1, 1
        )
        pathology_temp = pathology_temp.unsqueeze(0).repeat(
            entity_query.shape[0], 1, 1
        )
        anatomy_temp = anatomy_temp.unsqueeze(2).repeat(
            1, 1, anatomytomy_query.shape[1], 1
        )
        pathology_temp = pathology_temp.unsqueeze(2).repeat(
            1, 1, entity_query.shape[1], 1
        )

        posi_matrix_pathology = (matrix_zero_pathology * anatomy_temp).transpose(
            1, 2
        )
        posi_matrix_anatomy = (matrix_zero_anatomy * pathology_temp).transpose(1, 2)

        # Slot 0 of every query becomes the mean of its positive embeddings.
        for i in range(anatomytomy_query.shape[0]):
            for j in range(anatomytomy_query.shape[1]):
                if (posi_matrix_pathology[i, j] != 0).sum() > 0:
                    num_posi = (
                        torch.nonzero(posi_matrix_pathology[i, j], as_tuple=True)[0]
                        .unique()
                        .shape[0]
                    )
                    # Slot 0 is expected to be empty before it is filled.
                    assert anatomytomy_query[i, j, 0, :].sum() == 0
                    anatomytomy_query[i, j, 0, :] = (
                        posi_matrix_pathology[i, j, :, :].sum(dim=0) / num_posi
                    )

        for i in range(entity_query.shape[0]):
            for j in range(entity_query.shape[1]):
                if (posi_matrix_anatomy[i, j] != 0).sum() > 0:
                    num_posi = (
                        torch.nonzero(posi_matrix_anatomy[i, j], as_tuple=True)[0]
                        .unique()
                        .shape[0]
                    )
                    assert entity_query[i, j, 0, :].sum() == 0
                    entity_query[i, j, 0, :] = (
                        posi_matrix_anatomy[i, j, :, :].sum(dim=0) / num_posi
                    )

        # [Q,B,A]
        ll_pathology = out_pathology.transpose(0, 1)  # B Q A
        ll_anatomy = out_anatomy.transpose(0, 1)  # B Q A

        Q_pathology = ll_pathology.shape[1]
        Q_anatomy = ll_anatomy.shape[1]

        ll_pathology = ll_pathology.reshape(
            ll_pathology.shape[0] * ll_pathology.shape[1], -1
        )
        ll_anatomy = ll_anatomy.reshape(
            ll_anatomy.shape[0] * ll_anatomy.shape[1], -1
        )

        ll_pathology = self.cl_fc_pathology(ll_pathology)
        ll_anatomy = self.cl_fc_anatomy(ll_anatomy)

        ll_pathology = ll_pathology.unsqueeze(dim=-1)
        ll_anatomy = ll_anatomy.unsqueeze(dim=-1)

        anatomytomy_query = anatomytomy_query.reshape(B * Q_pathology, 8, 768)
        entity_query = entity_query.reshape(B * Q_anatomy, 8, 768)

        # Score of every candidate slot against its query feature.
        ll_pathology = torch.bmm(
            anatomytomy_query, ll_pathology
        ).squeeze()  # B Q position_num
        ll_anatomy = torch.bmm(
            entity_query, ll_anatomy
        ).squeeze()  # B Q position_num

        # The contrastive target is always slot 0.
        cl_labels_pathology = torch.zeros((ll_pathology.shape[0])).to(device)
        cl_labels_anatomy = torch.zeros((ll_anatomy.shape[0])).to(device)

        if exclude_class == True:
            cl_labels_pathology = cl_labels_pathology.reshape(B, Q_pathology)
            cl_labels_anatomy = cl_labels_anatomy.reshape(B, Q_anatomy)

            cl_labels_pathology = cl_labels_pathology[
                :, self.keep_class_dim_pathology
            ]
            # NOTE(review): the anatomy labels are filtered with the
            # *pathology* keep list while `ll_anatomy` below is not filtered
            # at all - the two can end up with different lengths. Confirm.
            cl_labels_anatomy = cl_labels_anatomy[:, self.keep_class_dim_pathology]

            cl_labels_pathology = cl_labels_pathology.reshape(-1)
            cl_labels_anatomy = cl_labels_anatomy.reshape(-1)

            ll_pathology = ll_pathology.reshape(B, Q_pathology, -1)
            ll_anatomy = ll_anatomy.reshape(B, Q_anatomy, -1)

            ll_pathology = ll_pathology[:, self.keep_class_dim_pathology, :]
            ll_pathology = ll_pathology.reshape(
                B * (len(self.keep_class_dim_pathology)), -1
            )
            ll_anatomy = ll_anatomy.reshape(B * Q_anatomy, -1)

    x_pathology = self.classifier_pathology(out_pathology).transpose(0, 1)
    x_anatomy = self.classifier_anatomy(out_anatomy).transpose(
        0, 1
    )  # B query Atributes

    if exclude_class == True:
        labels_pathology = labels_pathology[:, self.keep_class_dim_pathology]
        x_pathology = x_pathology[:, self.keep_class_dim_pathology, :]

    labels_pathology = labels_pathology.reshape(-1, 1)
    labels_anatomy = labels_anatomy.reshape(-1, 1)
    logits_pathology = x_pathology.reshape(-1, x_pathology.shape[-1])
    logits_anatomy = x_anatomy.reshape(-1, x_anatomy.shape[-1])
    # Labels -1 and 2 are excluded from the classification loss.
    Mask_pathology = ((labels_pathology != -1) & (labels_pathology != 2)).squeeze()
    Mask_anatomy = ((labels_anatomy != -1) & (labels_anatomy != 2)).squeeze()

    # Only queries labelled 1 take part in the contrastive loss.
    cl_mask_pathology = (labels_pathology == 1).squeeze()
    cl_mask_anatomy = (labels_anatomy == 1).squeeze()
    if is_train == True:
        labels_pathology = labels_pathology[Mask_pathology].long()
        labels_anatomy = labels_anatomy[Mask_anatomy].long()
        logits_pathology = logits_pathology[Mask_pathology]
        logits_anatomy = logits_anatomy[Mask_anatomy]
        loss_ce_pathology = F.cross_entropy(
            logits_pathology, labels_pathology[:, 0]
        )
        loss_ce_anatomy = F.cross_entropy(logits_anatomy, labels_anatomy[:, 0])
        if no_cl == False:
            cl_labels_pathology = cl_labels_pathology[cl_mask_pathology].long()
            cl_labels_anatomy = cl_labels_anatomy[cl_mask_anatomy].long()
            ll_pathology = ll_pathology[cl_mask_pathology]
            ll_anatomy = ll_anatomy[cl_mask_anatomy]
            loss_cl_pathology = F.cross_entropy(ll_pathology, cl_labels_pathology)
            loss_cl_anatomy = F.cross_entropy(ll_anatomy, cl_labels_anatomy)
            loss_ce = loss_ce_pathology + loss_ce_anatomy
            loss_cl = loss_cl_pathology + loss_cl_anatomy
            loss = loss_ce + loss_cl + loss_ap
        else:
            # Both names are part of the training return tuple; without
            # these assignments that return raised UnboundLocalError
            # whenever the contrastive terms were disabled.
            loss_cl_pathology = torch.tensor(0, device=device)
            loss_cl_anatomy = torch.tensor(0, device=device)
            loss = loss_ce_pathology + loss_ce_anatomy + loss_ap
    else:
        loss = 0

    if is_train == True:
        if text_gen:
            return (
                loss,
                x_pathology,
                ws_pathology,
                x_anatomy,
                ws_anatomy,
                output_logits,
            )
        else:
            return (
                loss,
                loss_ce_pathology,
                loss_cl_pathology,
                loss_ce_anatomy,
                loss_cl_anatomy,
                loss_ap,
            )
    else:
        return loss, x_pathology, ws_pathology, x_anatomy, ws_anatomy
|
515 |
+
|
516 |
+
@staticmethod
def _init_weights(module):
    r"""Re-draw a module's weights from N(0.0, 0.02), as BERT does.

    Handles ``nn.Linear``, ``nn.MultiheadAttention`` (input and output
    projection weights) and ``nn.Embedding`` (whose padding row, if any, is
    zeroed again afterwards). Bias tensors and all other module types are
    left exactly as they are.
    """
    padding_row = None
    if isinstance(module, nn.Linear):
        weights = (module.weight,)
    elif isinstance(module, nn.MultiheadAttention):
        weights = (module.in_proj_weight, module.out_proj.weight)
    elif isinstance(module, nn.Embedding):
        weights = (module.weight,)
        padding_row = module.padding_idx
    else:
        return

    for weight in weights:
        weight.data.normal_(mean=0.0, std=0.02)

    if padding_row is not None:
        module.weight.data[padding_row].zero_()
|
PreTrain_MeDSLIP/models/tokenization_bert.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes for Bert."""
|
16 |
+
|
17 |
+
|
18 |
+
import collections
|
19 |
+
import os
|
20 |
+
import unicodedata
|
21 |
+
from typing import List, Optional, Tuple
|
22 |
+
|
23 |
+
from transformers.tokenization_utils import (
|
24 |
+
PreTrainedTokenizer,
|
25 |
+
_is_control,
|
26 |
+
_is_punctuation,
|
27 |
+
_is_whitespace,
|
28 |
+
)
|
29 |
+
from transformers.utils import logging
|
30 |
+
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__)
|
33 |
+
|
34 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# One row per published BERT checkpoint: (hub identifier, lower-cases input).
# The three lookup tables below are derived from this single list, so their
# keys and key order always agree.
_BERT_CHECKPOINTS = (
    ("bert-base-uncased", True),
    ("bert-large-uncased", True),
    ("bert-base-cased", False),
    ("bert-large-cased", False),
    ("bert-base-multilingual-uncased", True),
    ("bert-base-multilingual-cased", False),
    ("bert-base-chinese", False),
    ("bert-base-german-cased", False),
    ("bert-large-uncased-whole-word-masking", True),
    ("bert-large-cased-whole-word-masking", False),
    ("bert-large-uncased-whole-word-masking-finetuned-squad", True),
    ("bert-large-cased-whole-word-masking-finetuned-squad", False),
    ("bert-base-cased-finetuned-mrpc", False),
    ("bert-base-german-dbmdz-cased", False),
    ("bert-base-german-dbmdz-uncased", True),
    ("TurkuNLP/bert-base-finnish-cased-v1", False),
    ("TurkuNLP/bert-base-finnish-uncased-v1", True),
    ("wietsedv/bert-base-dutch-cased", False),
)

# Download location of each checkpoint's vocabulary file.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        name: "https://huggingface.co/" + name + "/resolve/main/vocab.txt"
        for name, _ in _BERT_CHECKPOINTS
    }
}

# Every listed checkpoint uses 512 positions.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {name: 512 for name, _ in _BERT_CHECKPOINTS}

# Default tokenizer arguments per checkpoint.
PRETRAINED_INIT_CONFIGURATION = {
    name: {"do_lower_case": lowercase} for name, lowercase in _BERT_CHECKPOINTS
}
|
100 |
+
|
101 |
+
|
102 |
+
def load_vocab(vocab_file):
    """Read a vocabulary file and map each token to its zero-based line number.

    Only the trailing newline is removed from a line, so any other
    whitespace stays part of the token. A token that appears more than once
    keeps the index of its last occurrence.
    """
    token_to_index = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle.readlines()):
            token_to_index[raw_line.rstrip("\n")] = line_number
    return token_to_index
|
111 |
+
|
112 |
+
|
113 |
+
def whitespace_tokenize(text):
    """Trim *text* and split it on runs of whitespace; blank input gives ``[]``."""
    trimmed = text.strip()
    return trimmed.split() if trimmed else []
|
120 |
+
|
121 |
+
|
122 |
+
class BertTokenizer(PreTrainedTokenizer):
    r"""
    Construct a BERT tokenizer. Based on WordPiece.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.

    Unlike the stock BERT tokenizer, a single sequence is encoded as ``[CLS] X`` with no trailing ``[SEP]``
    (see :meth:`build_inputs_with_special_tokens`); the special-tokens mask and the token type ids follow the
    same layout.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).

    Raises:
        ValueError: If ``vocab_file`` is not an existing file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    vocab_file
                )
            )

        # The vocabulary and sub-tokenizers are set up *before* the base
        # constructor runs: newer transformers releases query the vocabulary
        # from inside `PreTrainedTokenizer.__init__`.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()]
        )
        # Kept so `do_lower_case` still answers when no BasicTokenizer exists.
        self._do_lower_case = do_lower_case
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(
            vocab=self.vocab, unk_token=str(unk_token)
        )

        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

    @property
    def do_lower_case(self):
        """Whether input is lower-cased while tokenizing."""
        basic_tokenizer = getattr(self, "basic_tokenizer", None)
        if basic_tokenizer is None:
            # do_basic_tokenize=False: fall back to the constructor argument
            # instead of failing on the missing attribute.
            return self._do_lower_case
        return basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (added tokens excluded)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with the added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split *text* into WordPiece tokens, after basic tokenization if enabled."""
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(
                text, never_split=self.all_special_tokens
            ):

                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: ``[CLS] X `` (no trailing ``[SEP]``)
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.

        The mask has the same length as the output of :meth:`build_inputs_with_special_tokens` for the same
        arguments.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

        Raises:
            ValueError: If ``already_has_special_tokens`` is true and a second sequence is supplied.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            special_ids = (self.sep_token_id, self.cls_token_id)
            return [1 if token_id in special_ids else 0 for token_id in token_ids_0]

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        # Single sequence is `[CLS] X`: there is no trailing [SEP] to mark.
        return [1] + ([0] * len(token_ids_0))

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ::

            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |

        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s), one entry
        per token of ``[CLS] X``.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            # Matches `[CLS] X` from build_inputs_with_special_tokens.
            return len(cls + token_ids_0) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Write the vocabulary, one token per line in index order.

        Args:
            save_directory (:obj:`str`):
                Existing directory to write ``vocab.txt`` into. If it is not a directory, the value itself is used
                as the output file path.
            filename_prefix (:obj:`str`, `optional`):
                Prefix joined to the file name with ``-``.

        Returns:
            :obj:`Tuple[str]`: One-element tuple holding the path that was written.
        """
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory,
                (filename_prefix + "-" if filename_prefix else "")
                + VOCAB_FILES_NAMES["vocab_file"],
            )
        else:
            vocab_file = (
                filename_prefix + "-" if filename_prefix else ""
            ) + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # A gap in the indices means reloading the file would
                    # assign different ids; warn and resynchronise.
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(
                            vocab_file
                        )
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
|
369 |
+
|
370 |
+
|
371 |
+
class BasicTokenizer(object):
    """
    Runs the basic tokenization pass: text cleanup, optional CJK spacing, lower casing, accent stripping
    and punctuation splitting.

    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which are never lower-cased, accent-stripped or split on punctuation.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to surround CJK ideographs with spaces so each one becomes its own token.
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip accents. When left as :obj:`None`, accents are stripped exactly when
            :obj:`do_lower_case` is true.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
    ):
        self.do_lower_case = do_lower_case
        # Stored as a set so membership tests in `tokenize` are cheap.
        self.never_split = set([] if never_split is None else never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Split `text` into basic tokens (no sub-word splitting; see WordpieceTokenizer for that).

        Args:
            **never_split**: (`optional`) list of str
                Extra tokens to protect for this call, merged with the ones given at construction time.

        Returns:
            The list of tokens after cleanup, casing/accent handling and punctuation splitting.
        """
        if never_split:
            protected = self.never_split | set(never_split)
        else:
            protected = self.never_split

        cleaned = self._clean_text(text)
        if self.tokenize_chinese_chars:
            cleaned = self._tokenize_chinese_chars(cleaned)

        pieces = []
        for word in whitespace_tokenize(cleaned):
            if word not in protected:
                if self.do_lower_case:
                    word = word.lower()
                    # With lower casing on, accents are stripped unless explicitly disabled.
                    wants_strip = self.strip_accents is not False
                else:
                    # Without lower casing, accents are stripped only on explicit request.
                    wants_strip = bool(self.strip_accents)
                if wants_strip:
                    word = self._run_strip_accents(word)
            pieces.extend(self._run_split_on_punc(word, protected))

        # Re-split on whitespace so the result contains no empty or padded tokens.
        return whitespace_tokenize(" ".join(pieces))

    def _run_strip_accents(self, text):
        """Return `text` NFD-normalized with every combining mark (category ``Mn``) removed."""
        decomposed = unicodedata.normalize("NFD", text)
        return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

    def _run_split_on_punc(self, text, never_split=None):
        """Split `text` so that each punctuation character becomes a token of its own."""
        if never_split is not None and text in never_split:
            return [text]

        groups = []
        open_new_group = True
        for ch in text:
            if _is_punctuation(ch):
                groups.append([ch])
                open_new_group = True
                continue
            if open_new_group:
                groups.append([])
                open_new_group = False
            groups[-1].append(ch)

        return ["".join(group) for group in groups]

    def _tokenize_chinese_chars(self, text):
        """Return `text` with a space inserted on both sides of every CJK ideograph."""
        return "".join(
            " " + ch + " " if self._is_chinese_char(ord(ch)) else ch for ch in text
        )

    def _is_chinese_char(self, cp):
        """Return True when codepoint `cp` lies in one of the CJK ideograph ranges listed below."""
        # Only the CJK ideograph blocks are listed: Hangul, Hiragana and Katakana live in other
        # blocks, are written with spaces between words, and are therefore handled like any
        # other script.
        cjk_ranges = (
            (0x4E00, 0x9FFF),
            (0x3400, 0x4DBF),
            (0x20000, 0x2A6DF),
            (0x2A700, 0x2B73F),
            (0x2B740, 0x2B81F),
            (0x2B820, 0x2CEAF),
            (0xF900, 0xFAFF),
            (0x2F800, 0x2FA1F),
        )
        for low, high in cjk_ranges:
            if low <= cp <= high:
                return True
        return False

    def _clean_text(self, text):
        """Drop NUL, U+FFFD and control characters from `text` and turn every whitespace char into a space."""
        kept = []
        for ch in text:
            codepoint = ord(ch)
            if codepoint == 0 or codepoint == 0xFFFD or _is_control(ch):
                continue
            kept.append(" " if _is_whitespace(ch) else ch)
        return "".join(kept)
|
525 |
+
|
526 |
+
|
527 |
+
class WordpieceTokenizer(object):
    """Splits words into WordPiece sub-tokens using a fixed vocabulary."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Break every whitespace-separated word of `text` into word pieces with a greedy
        longest-match-first search over the vocabulary.

        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.

        Args:
          text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
          A list of wordpiece tokens. A word that is longer than `max_input_chars_per_word`, or that
          cannot be fully covered by vocabulary entries, is replaced by `unk_token`.
        """
        pieces = []
        for word in whitespace_tokenize(text):
            length = len(word)
            if length > self.max_input_chars_per_word:
                pieces.append(self.unk_token)
                continue

            word_pieces = []
            cursor = 0
            while cursor < length:
                matched = None
                stop = length
                # Try the longest remaining substring first and shrink it from the right.
                while cursor < stop:
                    candidate = word[cursor:stop]
                    if cursor > 0:
                        # Continuation pieces are stored in the vocabulary with a "##" prefix.
                        candidate = "##" + candidate
                    if candidate in self.vocab:
                        matched = candidate
                        break
                    stop -= 1
                if matched is None:
                    # No vocabulary entry starts at this position: the whole word is unknown.
                    word_pieces = None
                    break
                word_pieces.append(matched)
                cursor = stop

            if word_pieces is None:
                pieces.append(self.unk_token)
            else:
                pieces.extend(word_pieces)
        return pieces
|
PreTrain_MeDSLIP/models/transformer.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code modified from DETR transformer:
|
3 |
+
https://github.com/facebookresearch/detr
|
4 |
+
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
5 |
+
"""
|
6 |
+
|
7 |
+
import copy
|
8 |
+
from typing import Optional, List
|
9 |
+
import pickle as cp
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn, Tensor
|
14 |
+
|
15 |
+
|
16 |
+
class TransformerDecoder(nn.Module):
    """Stack of decoder layers applied one after another.

    Args:
        decoder_layer: Layer module; it is copied ``num_layers`` times via ``_get_clones``.
        num_layers: Number of stacked layers.
        norm: Optional module applied to the final output (and to every per-layer output when
            ``return_intermediate`` is set).
        return_intermediate: When true, ``forward`` returns the stacked per-layer outputs instead
            of the ``(output, attention_weights)`` pair.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Run ``tgt`` through every layer, each one attending to ``memory``.

        All mask and positional arguments are forwarded unchanged to every layer.

        Returns:
            ``torch.stack`` of the per-layer outputs when ``return_intermediate`` is true;
            otherwise a tuple ``(output, atten_layers)`` where ``atten_layers`` is the list of the
            second value returned by each layer, in layer order.
        """
        output = tgt
        intermediate = []
        atten_layers = []
        for layer in self.layers:
            output, ws = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
                residual=True,
            )
            atten_layers.append(ws)
            if self.return_intermediate:
                # `norm` is optional; without this guard a missing norm raised
                # "'NoneType' object is not callable" as soon as intermediates were requested.
                intermediate.append(
                    output if self.norm is None else self.norm(output)
                )

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Swap the last entry for the normalized final output.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)
        return output, atten_layers
|
65 |
+
|
66 |
+
|
67 |
+
class TransformerDecoderLayer(nn.Module):
    """Single decoder layer: self-attention, cross-attention over ``memory`` and a feed-forward block.

    Args:
        d_model: Feature size of the inputs and outputs.
        nhead: Number of attention heads for both attention modules.
        dim_feedforward: Hidden size of the feed-forward block.
        dropout: Dropout probability used in the attention modules and all dropout layers.
        activation: Name of the feed-forward activation, resolved by ``_get_activation_fn``.
        normalize_before: Selects ``forward_pre`` (true) or ``forward_post`` (false) in ``forward``.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        residual=True,
    ):
        """Post-norm variant. Returns ``(output, cross_attention_weights)``."""
        self_query = self.with_pos_embed(tgt, query_pos)
        # NOTE(review): the self-attention result is computed and then thrown away, and `residual`
        # is never read, so `dropout1` and the self-attention output never reach the result.
        # Kept as-is because adding the residual would change model outputs and existing
        # checkpoints; confirm whether this is intentional before changing it.
        self.self_attn(
            self_query,
            self_query,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        hidden = self.norm1(tgt)

        attended, ws = self.multihead_attn(
            query=self.with_pos_embed(hidden, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        hidden = self.norm2(hidden + self.dropout2(attended))

        ff_out = self.linear2(self.dropout(self.activation(self.linear1(hidden))))
        hidden = self.norm3(hidden + self.dropout3(ff_out))
        return hidden, ws

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Pre-norm variant. Returns ``(output, cross_attention_weights)``."""
        normed = self.norm1(tgt)
        self_query = self.with_pos_embed(normed, query_pos)
        self_out, _ = self.self_attn(
            self_query,
            self_query,
            value=normed,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        tgt = tgt + self.dropout1(self_out)

        normed = self.norm2(tgt)
        cross_out, attn_weights = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        tgt = tgt + self.dropout2(cross_out)

        normed = self.norm3(tgt)
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(normed))))
        tgt = tgt + self.dropout3(ff_out)
        return tgt, attn_weights

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        residual=True,
    ):
        """Dispatch to ``forward_pre`` or ``forward_post`` depending on ``normalize_before``.

        ``residual`` is only passed on to ``forward_post``.
        """
        if not self.normalize_before:
            return self.forward_post(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
                residual,
            )
        return self.forward_pre(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )
|
196 |
+
|
197 |
+
|
198 |
+
def _get_clones(module, N):
|
199 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
200 |
+
|
201 |
+
|
202 |
+
def _get_activation_fn(activation):
|
203 |
+
"""Return an activation function given a string"""
|
204 |
+
if activation == "relu":
|
205 |
+
return F.relu
|
206 |
+
if activation == "gelu":
|
207 |
+
return F.gelu
|
208 |
+
if activation == "glu":
|
209 |
+
return F.glu
|
210 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
PreTrain_MeDSLIP/optim/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .adamp import AdamP
|
2 |
+
from .adamw import AdamW
|
3 |
+
from .adafactor import Adafactor
|
4 |
+
from .adahessian import Adahessian
|
5 |
+
from .lookahead import Lookahead
|
6 |
+
from .nadam import Nadam
|
7 |
+
from .novograd import NovoGrad
|
8 |
+
from .nvnovograd import NvNovoGrad
|
9 |
+
from .radam import RAdam
|
10 |
+
from .rmsprop_tf import RMSpropTF
|
11 |
+
from .sgdp import SGDP
|
12 |
+
|
13 |
+
from .optim_factory import create_optimizer
|
PreTrain_MeDSLIP/optim/adafactor.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Adafactor Optimizer
|
2 |
+
|
3 |
+
Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
4 |
+
|
5 |
+
Original header/copyright below.
|
6 |
+
|
7 |
+
"""
|
8 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
9 |
+
#
|
10 |
+
# This source code is licensed under the MIT license found in the
|
11 |
+
# LICENSE file in the root directory of this source tree.
|
12 |
+
import torch
|
13 |
+
import math
|
14 |
+
|
15 |
+
|
16 |
+
class Adafactor(torch.optim.Optimizer):
    """Implements Adafactor algorithm.
    This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
    (see https://arxiv.org/abs/1804.04235)

    Note that this optimizer internally adjusts the learning rate depending on the
    *lr*, *scale_parameter* and *warmup_init* options.

    Passing an explicit `lr` switches the optimizer to that external learning rate; leaving
    `lr=None` enables the time-dependent ("relative step") learning rate.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): external learning rate; when None the relative-step schedule
            is used instead (default: None)
        eps (float): regularization constant added to the squared gradient (default: 1e-30)
        eps_scale (float): lower bound on the parameter RMS used when scaling the
            relative-step learning rate (default: 1e-3)
        clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
        decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
        betas (tuple, optional): only ``betas[0]`` is used, as the coefficient for the running
            average of the update; None disables the first moment (default: None)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_parameter (bool): if True, the relative-step learning rate is scaled by root mean
            square of parameter (default: True)
        warmup_init (bool): time-dependent learning rate computation depends on
            whether warm-up initialization is being used; requires ``lr=None`` (default: False)
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=1e-30,
        eps_scale=1e-3,
        clip_threshold=1.0,
        decay_rate=-0.8,
        betas=None,
        weight_decay=0.0,
        scale_parameter=True,
        warmup_init=False,
    ):
        # The relative-step schedule is used exactly when no external lr is supplied.
        relative_step = lr is None
        if warmup_init and not relative_step:
            raise ValueError("warmup_init requires relative_step=True")

        # Accept the standard `betas` argument but only use its first entry.
        beta1 = None if betas is None else betas[0]
        defaults = dict(
            lr=lr,
            eps=eps,
            eps_scale=eps_scale,
            clip_threshold=clip_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step=relative_step,
            warmup_init=warmup_init,
        )
        super(Adafactor, self).__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        """Return the learning rate for this step, updating ``param_group["lr"]`` in relative-step mode."""
        if param_group["relative_step"]:
            min_step = (
                1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            )
            lr_t = min(min_step, 1.0 / math.sqrt(param_state["step"]))
            param_scale = 1.0
            if param_group["scale_parameter"]:
                param_scale = max(param_group["eps_scale"], param_state["RMS"])
            param_group["lr"] = lr_t * param_scale
        return param_group["lr"]

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return ``(factored, use_first_moment)`` for a parameter of the given shape."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Return the root mean square of ``tensor``."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Combine the row and column statistics into an inverse-square-root scaling factor."""
        r_factor = (
            (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
            .rsqrt_()
            .unsqueeze(-1)
        )
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    # Do the arithmetic in fp32 for half-precision gradients.
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of the update values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # One statistic per row and one per column instead of a full matrix.
                        state["exp_avg_sq_row"] = grad.new_zeros(grad_shape[:-1])
                        state["exp_avg_sq_col"] = grad.new_zeros(
                            grad_shape[:-2] + grad_shape[-1:]
                        )
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Keep existing state on the same device/dtype as the gradient.
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr_t = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = grad ** 2 + group["eps"]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    # Keyword `alpha=` replaces the deprecated positional add_(scalar, tensor) form.
                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=1.0 - beta2t
                    )
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=1.0 - beta2t
                    )

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Scale the update down when its RMS exceeds the clipping threshold.
                update.div_(
                    (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
                )
                update.mul_(lr_t)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * lr_t
                    )

                p_data_fp32.add_(-update)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    # Write the fp32 working copy back into the half-precision parameter.
                    p.data.copy_(p_data_fp32)

        return loss
|
PreTrain_MeDSLIP/optim/adahessian.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" AdaHessian Optimizer
|
2 |
+
|
3 |
+
Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
|
4 |
+
Originally licensed MIT, Copyright 2020, David Samuel
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class Adahessian(torch.optim.Optimizer):
    """
    Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"

    Gradients must be computed with ``loss.backward(create_graph=True)`` so that ``step`` can
    differentiate them a second time.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 0.1)
        betas ((float, float), optional): coefficients used for computing running averages of gradient and the
            squared hessian trace (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay factor (default: 0.0)
        hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
        update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
            (to save time) (default: 1)
        n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
        avg_conv_kernel (bool, optional): if True, the absolute hessian estimate of 4-D parameters is
            averaged over dimensions 2 and 3 (default: False)
    """

    def __init__(
        self,
        params,
        lr=0.1,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        hessian_power=1.0,
        update_each=1,
        n_samples=1,
        avg_conv_kernel=False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(f"Invalid Hessian power value: {hessian_power}")

        self.n_samples = n_samples
        self.update_each = update_each
        self.avg_conv_kernel = avg_conv_kernel

        # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
        self.seed = 2147483647
        self.generator = torch.Generator().manual_seed(self.seed)

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            hessian_power=hessian_power,
        )
        super(Adahessian, self).__init__(params, defaults)

        for p in self.get_params():
            # A float 0.0 marks "no hessian accumulated yet"; it becomes a tensor in set_hessian.
            p.hess = 0.0
            self.state[p]["hessian step"] = 0

    @property
    def is_second_order(self):
        """Always True: this optimizer needs second-order gradient information."""
        return True

    def get_params(self):
        """
        Gets all parameters in all param_groups that require gradients
        """

        return (
            p for group in self.param_groups for p in group["params"] if p.requires_grad
        )

    def zero_hessian(self):
        """
        Zeros out the accumulated hessian traces.
        """

        for p in self.get_params():
            if (
                not isinstance(p.hess, float)
                and self.state[p]["hessian step"] % self.update_each == 0
            ):
                p.hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """
        Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
        """

        params = []
        for p in filter(lambda p: p.grad is not None, self.get_params()):
            if (
                self.state[p]["hessian step"] % self.update_each == 0
            ):  # compute the trace only each `update_each` step
                params.append(p)
            self.state[p]["hessian step"] += 1

        if len(params) == 0:
            return

        if (
            self.generator.device != params[0].device
        ):  # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(self.seed)

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            # Rademacher distribution {-1.0, 1.0}
            zs = [
                torch.randint(0, 2, p.size(), generator=self.generator, device=p.device)
                * 2.0
                - 1.0
                for p in params
            ]
            h_zs = torch.autograd.grad(
                grads,
                params,
                grad_outputs=zs,
                only_inputs=True,
                # The graph is only needed again if another sample follows.
                retain_graph=i < self.n_samples - 1,
            )
            for h_z, z, p in zip(h_zs, zs, params):
                p.hess += (
                    h_z * z / self.n_samples
                )  # approximate the expected values of z*(H@z)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        Arguments:
            closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """

        loss = None
        if closure is not None:
            # `step` runs under no_grad; re-enable autograd so the closure can build a graph
            # and call backward(create_graph=True).
            with torch.enable_grad():
                loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or p.hess is None:
                    continue

                if self.avg_conv_kernel and p.dim() == 4:
                    p.hess = (
                        torch.abs(p.hess)
                        .mean(dim=[2, 3], keepdim=True)
                        .expand_as(p.hess)
                        .clone()
                    )

                # Perform correct stepweight decay as in AdamW
                p.mul_(1 - group["lr"] * group["weight_decay"])

                state = self.state[p]

                # State initialization ("hessian step" is already present, hence the 1)
                if len(state) == 1:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)
                    # Exponential moving average of Hessian diagonal square values
                    state["exp_hessian_diag_sq"] = torch.zeros_like(p)

                exp_avg, exp_hessian_diag_sq = (
                    state["exp_avg"],
                    state["exp_hessian_diag_sq"],
                )
                beta1, beta2 = group["betas"]
                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(
                    p.hess, p.hess, value=1 - beta2
                )

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                k = group["hessian_power"]
                denom = (
                    (exp_hessian_diag_sq / bias_correction2)
                    .pow_(k / 2)
                    .add_(group["eps"])
                )

                # make update
                step_size = group["lr"] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
|
PreTrain_MeDSLIP/optim/adamp.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
|
3 |
+
|
4 |
+
Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
|
5 |
+
Code: https://github.com/clovaai/AdamP
|
6 |
+
|
7 |
+
Copyright (c) 2020-present NAVER Corp.
|
8 |
+
MIT license
|
9 |
+
"""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.optim.optimizer import Optimizer, required
|
14 |
+
import math
|
15 |
+
|
16 |
+
|
17 |
+
class AdamP(Optimizer):
    """AdamP optimizer: Adam with a projection step for multi-dimensional weights.

    For every parameter with more than one dimension, the Adam update is
    checked against the weight itself. When the gradient is nearly orthogonal
    to the weight (see ``_projection``), the component of the update that is
    parallel to the weight is removed and the weight decay is scaled by
    ``wd_ratio``.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to denominators to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay coefficient,
            applied only when it is > 0 (default: 0)
        delta (float, optional): threshold factor of the cosine-similarity
            test that decides whether the projection is applied (default: 0.1)
        wd_ratio (float, optional): multiplier applied to the weight decay
            when the projection was applied (default: 0.1)
        nesterov (bool, optional): use a Nesterov-style combination of the
            first moment and the current gradient (default: False)
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        delta=0.1,
        wd_ratio=0.1,
        nesterov=False,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            delta=delta,
            wd_ratio=wd_ratio,
            nesterov=nesterov,
        )
        super(AdamP, self).__init__(params, defaults)

    def _channel_view(self, x):
        """Return ``x`` flattened to 2-D, keeping the first dimension."""
        # reshape (not view) so non-contiguous tensors are accepted as well.
        return x.reshape(x.size(0), -1)

    def _layer_view(self, x):
        """Return ``x`` flattened to a single row of shape ``(1, numel)``."""
        return x.reshape(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        """Return the absolute row-wise cosine similarity of ``x`` and ``y``.

        Both tensors are first flattened with ``view_func``; ``eps`` is added
        to each norm to avoid division by zero.
        """
        x = view_func(x)
        y = view_func(y)

        x_norm = x.norm(dim=1).add_(eps)
        y_norm = y.norm(dim=1).add_(eps)
        dot = (x * y).sum(dim=1)

        return dot.abs() / x_norm / y_norm

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        """Project ``perturb`` away from the weight direction when applicable.

        The test is tried channel-wise first, then layer-wise. The first view
        whose maximum cosine similarity between ``grad`` and the weight is
        below ``delta / sqrt(row_length)`` triggers the projection.

        Returns:
            tuple: ``(perturb, wd)`` where ``perturb`` is the (possibly
            in-place modified) update and ``wd`` is ``wd_ratio`` if a
            projection was applied, otherwise ``1``.
        """
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:
            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                # Normalised weight, then remove the update component along it.
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                beta1, beta2 = group["betas"]
                nesterov = group["nesterov"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                # Adam
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Keyword form: the positional (scalar, tensor) overloads are deprecated.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                    group["eps"]
                )
                step_size = group["lr"] / bias_correction1

                if nesterov:
                    perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
                else:
                    perturb = exp_avg / denom

                # Projection (only for parameters with more than one dimension)
                wd_ratio = 1
                if len(p.shape) > 1:
                    perturb, wd_ratio = self._projection(
                        p,
                        grad,
                        perturb,
                        group["delta"],
                        group["wd_ratio"],
                        group["eps"],
                    )

                # Weight decay (decoupled, scaled by wd_ratio after a projection)
                if group["weight_decay"] > 0:
                    p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio)

                # Step
                p.data.add_(perturb, alpha=-step_size)

        return loss
|
PreTrain_MeDSLIP/optim/adamw.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" AdamW Optimizer
|
2 |
+
Impl copied from PyTorch master
|
3 |
+
"""
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
from torch.optim.optimizer import Optimizer
|
7 |
+
|
8 |
+
|
9 |
+
class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or a beta is outside ``[0, 1)``.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
        )
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state, defaulting ``amsgrad`` to False where missing."""
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform stepweight decay (decoupled from the gradient update)
                p.data.mul_(1 - group["lr"] * group["weight_decay"])

                # Perform optimization step
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsgrad = group["amsgrad"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient.
                # Keyword form: the positional (scalar, tensor) overloads are deprecated.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                        group["eps"]
                    )
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                        group["eps"]
                    )

                step_size = group["lr"] / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
|
PreTrain_MeDSLIP/optim/lookahead.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Lookahead Optimizer Wrapper.
|
2 |
+
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
|
3 |
+
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
|
4 |
+
|
5 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
6 |
+
"""
|
7 |
+
import torch
|
8 |
+
from torch.optim.optimizer import Optimizer
|
9 |
+
from collections import defaultdict
|
10 |
+
|
11 |
+
|
12 |
+
class Lookahead(Optimizer):
    """Lookahead wrapper around another optimizer.

    The wrapped ``base_optimizer`` updates the (fast) parameters on every
    step. Every ``k`` steps each parameter's slow copy is moved towards the
    fast value by a factor ``alpha`` and the fast parameter is reset to it.

    Arguments:
        base_optimizer (Optimizer): the optimizer performing the fast updates;
            its ``param_groups`` are shared with this wrapper
        alpha (float, optional): slow update rate, in ``[0, 1]`` (default: 0.5)
        k (int, optional): number of fast steps between slow updates,
            at least 1 (default: 6)

    Raises:
        ValueError: if ``alpha`` or ``k`` is out of range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")
        if not 1 <= k:
            raise ValueError(f"Invalid lookahead steps: {k}")
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        # Share (not copy) the groups so both optimizers see the same params.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        """Update the slow weights of ``group`` and copy them into the fast weights.

        Parameters without a gradient are skipped. The slow buffer of a
        parameter is created on first use as a copy of its current value.
        """
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            # Keyword form: the positional (scalar, tensor) overload is deprecated.
            slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Run a slow update on every parameter group immediately."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Run one step of the base optimizer, then a slow update every ``k`` steps.

        Returns:
            Whatever ``base_optimizer.step(closure)`` returns.
        """
        # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group["lookahead_step"] += 1
            if group["lookahead_step"] % group["lookahead_k"] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer's state dict plus a ``"slow_state"`` entry.

        Tensor keys of the slow state are replaced by their ``id``.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        """Load the fast state into the base optimizer and restore the slow state.

        A ``state_dict`` without ``"slow_state"`` (saved without Lookahead) is
        accepted; the slow state then starts empty and the lookahead defaults
        are re-applied to the param groups. ``state_dict`` itself is not modified.
        """
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = "slow_state" not in state_dict
        if slow_state_new:
            print("Loading state_dict from optimizer without Lookahead applied.")
            # Kept local so the caller's dict is left untouched.
            slow_state = defaultdict(dict)
        else:
            slow_state = state_dict["slow_state"]
        slow_state_dict = {
            "state": slow_state,
            "param_groups": state_dict[
                "param_groups"
            ],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = (
            self.base_optimizer.param_groups
        )  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)
|
PreTrain_MeDSLIP/optim/nadam.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.optim import Optimizer
|
3 |
+
|
4 |
+
|
5 |
+
class Nadam(Optimizer):
    """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).

    It has been proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf

        Originally taken from: https://github.com/pytorch/pytorch/pull/1408
        NOTE: Has potential issues but does work well on some problems.
    """

    def __init__(
        self,
        params,
        lr=2e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        schedule_decay=4e-3,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            schedule_decay=schedule_decay,
        )
        super(Nadam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["m_schedule"] = 1.0
                    state["exp_avg"] = torch.zeros_like(grad)
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                # Warming momentum schedule
                m_schedule = state["m_schedule"]
                schedule_decay = group["schedule_decay"]
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                eps = group["eps"]
                state["step"] += 1
                t = state["step"]

                if group["weight_decay"] != 0:
                    # Out-of-place add: p.grad itself is left untouched.
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                momentum_cache_t = beta1 * (1.0 - 0.5 * (0.96 ** (t * schedule_decay)))
                momentum_cache_t_1 = beta1 * (
                    1.0 - 0.5 * (0.96 ** ((t + 1) * schedule_decay))
                )
                m_schedule_new = m_schedule * momentum_cache_t
                m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
                state["m_schedule"] = m_schedule_new

                # Decay the first and second moment running average coefficient.
                # Keyword form: the positional (scalar, tensor) overloads are deprecated.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                exp_avg_sq_prime = exp_avg_sq / (1.0 - beta2 ** t)
                denom = exp_avg_sq_prime.sqrt_().add_(eps)

                # Contribution of the current gradient ...
                p.data.addcdiv_(
                    grad,
                    denom,
                    value=-group["lr"] * (1.0 - momentum_cache_t) / (1.0 - m_schedule_new),
                )
                # ... and of the running first moment.
                p.data.addcdiv_(
                    exp_avg,
                    denom,
                    value=-group["lr"] * momentum_cache_t_1 / (1.0 - m_schedule_next),
                )

        return loss
|
PreTrain_MeDSLIP/optim/novograd.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""NovoGrad Optimizer.
|
2 |
+
Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
|
3 |
+
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
|
4 |
+
- https://arxiv.org/abs/1905.11286
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.optim.optimizer import Optimizer
|
9 |
+
import math
|
10 |
+
|
11 |
+
|
12 |
+
class NovoGrad(Optimizer):
    """NovoGrad optimizer with layer-wise (per-parameter) second moments.

    Only ``lr`` is read from the parameter groups at step time; ``betas``,
    ``eps``, ``weight_decay`` and ``grad_averaging`` are taken from the values
    passed to the constructor.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        grad_averaging (bool, optional): scale the normalised gradient by
            ``1 - beta1`` (default: False)
        lr (float, optional): learning rate (default: 0.1)
        betas (Tuple[float, float], optional): coefficients for the running
            averages of the update and of the squared gradient norm
            (default: (0.95, 0.98))
        eps (float, optional): term added to denominators to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 0)
    """

    def __init__(
        self,
        params,
        grad_averaging=False,
        lr=0.1,
        betas=(0.95, 0.98),
        eps=1e-8,
        weight_decay=0,
    ):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(NovoGrad, self).__init__(params, defaults)
        self._lr = lr
        self._beta1 = betas[0]
        self._beta2 = betas[1]
        self._eps = eps
        self._wd = weight_decay
        self._grad_averaging = grad_averaging

        self._momentum_initialized = False

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter's first gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                grad = p.grad.data

                # Per-parameter lazy initialisation: a parameter whose first
                # gradient shows up after the first step is still handled, and
                # state restored through load_state_dict is not overwritten.
                if len(state) == 0:
                    if grad.is_sparse:
                        raise RuntimeError("NovoGrad does not support sparse gradients")

                    v = torch.norm(grad) ** 2
                    m = grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
                    state["step"] = 0
                    state["v"] = v
                    state["m"] = m
                    state["grad_ema"] = None

                state["step"] += 1

                step, v, m = state["step"], state["v"], state["m"]
                grad_ema = state["grad_ema"]

                g2 = torch.norm(grad) ** 2
                grad_ema = (
                    g2
                    if grad_ema is None
                    else grad_ema * self._beta2 + g2 * (1.0 - self._beta2)
                )
                # In-place: p.grad itself is rescaled here.
                grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)

                if self._grad_averaging:
                    grad *= 1.0 - self._beta1

                g2 = torch.norm(grad) ** 2
                v = self._beta2 * v + (1.0 - self._beta2) * g2
                m = self._beta1 * m + (
                    grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
                )
                bias_correction1 = 1 - self._beta1 ** step
                bias_correction2 = 1 - self._beta2 ** step
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

                state["v"], state["m"] = v, m
                state["grad_ema"] = grad_ema
                # Keyword form: the positional (scalar, tensor) overload is deprecated.
                p.data.add_(m, alpha=-step_size)

        self._momentum_initialized = True
        return loss
|
PreTrain_MeDSLIP/optim/nvnovograd.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Nvidia NovoGrad Optimizer.
|
2 |
+
Original impl by Nvidia from Jasper example:
|
3 |
+
- https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
|
4 |
+
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
|
5 |
+
- https://arxiv.org/abs/1905.11286
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.optim.optimizer import Optimizer
|
10 |
+
import math
|
11 |
+
|
12 |
+
|
13 |
+
class NvNovoGrad(Optimizer):
    """
    Implements Novograd algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging: gradient averaging; when true the normalised gradient
            is scaled by ``1 - beta1`` before entering the running average
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or a beta is outside ``[0, 1)``.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.95, 0.98),
        eps=1e-8,
        weight_decay=0,
        grad_averaging=False,
        amsgrad=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            grad_averaging=grad_averaging,
            amsgrad=amsgrad,
        )

        super(NvNovoGrad, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state, defaulting ``amsgrad`` to False where missing."""
        super(NvNovoGrad, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Sparse gradients are not supported.")
                amsgrad = group["amsgrad"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    # (a single scalar per parameter tensor)
                    state["exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros([]).to(
                            state["exp_avg"].device
                        )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                norm = torch.sum(torch.pow(grad, 2))

                # A zero average is seeded with the current squared norm
                # instead of being decayed.
                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    # Keyword form: the positional (scalar, tensor) overload is deprecated.
                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                # In-place: p.grad itself is normalised (and decayed) here.
                grad.div_(denom)
                if group["weight_decay"] != 0:
                    grad.add_(p.data, alpha=group["weight_decay"])
                if group["grad_averaging"]:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.data.add_(exp_avg, alpha=-group["lr"])

        return loss
|
PreTrain_MeDSLIP/optim/optim_factory.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Optimizer Factory w/ Custom Weight Decay
|
2 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
from torch import optim as optim
|
6 |
+
|
7 |
+
from .adafactor import Adafactor
|
8 |
+
from .adahessian import Adahessian
|
9 |
+
from .adamp import AdamP
|
10 |
+
from .lookahead import Lookahead
|
11 |
+
from .nadam import Nadam
|
12 |
+
from .novograd import NovoGrad
|
13 |
+
from .nvnovograd import NvNovoGrad
|
14 |
+
from .radam import RAdam
|
15 |
+
from .rmsprop_tf import RMSpropTF
|
16 |
+
from .sgdp import SGDP
|
17 |
+
|
18 |
+
try:
|
19 |
+
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
20 |
+
|
21 |
+
has_apex = True
|
22 |
+
except ImportError:
|
23 |
+
has_apex = False
|
24 |
+
|
25 |
+
|
26 |
+
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split a model's trainable parameters into two optimizer param groups.

    Parameters with ``requires_grad`` false are left out. A parameter is
    exempt from weight decay when it is one-dimensional, when its name ends
    with ``".bias"``, or when its name is listed in ``skip_list``.

    Arguments:
        model: object exposing ``named_parameters()``
        weight_decay (float, optional): weight decay for the decayed group
            (default: 1e-5)
        skip_list (container, optional): parameter names that never decay

    Returns:
        list: two dicts; first the exempt group (``weight_decay`` 0.0), then
        the decayed group (``weight_decay`` as given).
    """
    exempt_params, decayed_params = [], []
    for param_name, tensor in model.named_parameters():
        if tensor.requires_grad:
            is_exempt = (
                len(tensor.shape) == 1
                or param_name.endswith(".bias")
                or param_name in skip_list
            )
            target = exempt_params if is_exempt else decayed_params
            target.append(tensor)

    exempt_group = {"params": exempt_params, "weight_decay": 0.0}
    decayed_group = {"params": decayed_params, "weight_decay": weight_decay}
    return [exempt_group, decayed_group]
|
40 |
+
|
41 |
+
|
42 |
+
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build the optimizer selected by ``args.opt`` for ``model``.

    ``args.opt`` is lower-cased and split on ``"_"``; the last piece selects
    the optimizer and a leading ``"lookahead"`` piece wraps the result in
    ``Lookahead``.

    Arguments:
        args: namespace providing ``opt``, ``lr`` and ``weight_decay``;
            ``momentum`` is read for the SGD / RMSprop family, and the
            optional ``opt_eps``, ``opt_betas`` and ``opt_args`` are forwarded
            when present and not None
        model: the model whose parameters are optimized
        filter_bias_and_bn (bool, optional): when true and ``weight_decay`` is
            non-zero, parameters are split with ``add_weight_decay`` (honouring
            ``model.no_weight_decay()`` if defined) and the optimizer-level
            weight decay is set to 0 (default: True)

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if the optimizer name is not recognised.
        AssertionError: if a fused optimizer is requested without APEX and CUDA.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, "no_weight_decay"):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        # Decay is now carried by the param groups themselves.
        weight_decay = 0.0
    else:
        parameters = filter(
            lambda p: p.requires_grad, model.parameters()
        )  # model.parameters()

    if "fused" in opt_lower:
        assert (
            has_apex and torch.cuda.is_available()
        ), "APEX and CUDA required for fused optimizers"

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, "opt_eps") and args.opt_eps is not None:
        opt_args["eps"] = args.opt_eps
    if hasattr(args, "opt_betas") and args.opt_betas is not None:
        opt_args["betas"] = args.opt_betas
    if hasattr(args, "opt_args") and args.opt_args is not None:
        opt_args.update(args.opt_args)

    opt_split = opt_lower.split("_")
    opt_lower = opt_split[-1]
    if opt_lower == "sgd" or opt_lower == "nesterov":
        opt_args.pop("eps", None)
        optimizer = optim.SGD(
            parameters, momentum=args.momentum, nesterov=True, **opt_args
        )
    elif opt_lower == "momentum":
        opt_args.pop("eps", None)
        optimizer = optim.SGD(
            parameters, momentum=args.momentum, nesterov=False, **opt_args
        )
    elif opt_lower == "adam":
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == "adamw":
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == "nadam":
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == "radam":
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == "adamp":
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == "sgdp":
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == "adadelta":
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == "adafactor":
        if not args.lr:
            opt_args["lr"] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == "adahessian":
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == "rmsprop":
        optimizer = optim.RMSprop(
            parameters, alpha=0.9, momentum=args.momentum, **opt_args
        )
    elif opt_lower == "rmsproptf":
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == "novograd":
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == "nvnovograd":
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == "fusedsgd":
        opt_args.pop("eps", None)
        optimizer = FusedSGD(
            parameters, momentum=args.momentum, nesterov=True, **opt_args
        )
    elif opt_lower == "fusedmomentum":
        opt_args.pop("eps", None)
        optimizer = FusedSGD(
            parameters, momentum=args.momentum, nesterov=False, **opt_args
        )
    elif opt_lower == "fusedadam":
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == "fusedadamw":
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == "fusedlamb":
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == "fusednovograd":
        opt_args.setdefault("betas", (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # Always a ValueError with the offending name; the previous
        # `assert False and "..."` dropped the message and vanished under -O.
        raise ValueError("Invalid optimizer: {}".format(args.opt))

    if len(opt_split) > 1:
        if opt_split[0] == "lookahead":
            optimizer = Lookahead(optimizer)

    return optimizer
|
PreTrain_MeDSLIP/optim/radam.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""RAdam Optimizer.
|
2 |
+
Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
|
3 |
+
Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
|
4 |
+
"""
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from torch.optim.optimizer import Optimizer, required
|
8 |
+
|
9 |
+
|
10 |
+
class RAdam(Optimizer):
    """RAdam optimizer (rectified Adam).

    Keeps an exponential moving average of the gradient (``exp_avg``) and of
    the squared gradient (``exp_avg_sq``) for every parameter. While the
    approximated SMA length ``N_sma`` is below 5 the update is a
    bias-corrected momentum step; from 5 onwards the adaptive step scaled by
    the rectification term is used.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): decay coefficients for the
            running averages of the gradient and its square
            (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator of the adaptive
            step (default: 1e-8)
        weight_decay (float, optional): decay factor, applied as
            ``p -= weight_decay * lr * p`` before the gradient step (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Retained only for backward compatibility with code that inspects the
        # attribute; ``step`` no longer reads it (see the per-group cache there).
        self.buffer = [[None, None, None] for ind in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    @staticmethod
    def _rectified_step_size(step, lr, beta1, beta2):
        """Return ``(N_sma, step_size)`` for one step count and one set of hyper-parameters."""
        beta2_t = beta2 ** step
        N_sma_max = 2 / (1 - beta2) - 1
        N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)

        # more conservative since it's an approximated value
        if N_sma >= 5:
            step_size = (
                lr
                * math.sqrt(
                    (1 - beta2_t)
                    * (N_sma - 4)
                    / (N_sma_max - 4)
                    * (N_sma - 2)
                    / N_sma
                    * N_sma_max
                    / (N_sma_max - 2)
                )
                / (1 - beta1 ** step)
            )
        else:
            step_size = lr / (1 - beta1 ** step)
        return N_sma, step_size

    def step(self, closure=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): if given, it is called once and its
                return value is returned by this method.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            # The step size depends on this group's lr and betas, so it is
            # cached per group (and only for the duration of this call).
            # A cache shared across groups would hand one group's step size
            # to another group with a different learning rate.
            step_cache = {}

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                # All arithmetic is done on an fp32 copy, copied back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                step = state["step"]
                if step not in step_cache:
                    step_cache[step] = self._rectified_step_size(step, lr, beta1, beta2)
                N_sma, step_size = step_cache[step]

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * lr)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
|
95 |
+
|
96 |
+
|
97 |
+
class PlainRAdam(Optimizer):
    """RAdam without any caching of the rectification term.

    The SMA length ``N_sma`` and the step size are recomputed for every
    parameter on every step. While ``N_sma`` is below 5 the update is a
    bias-corrected momentum step; from 5 onwards the adaptive step scaled by
    the rectification term is used.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): decay coefficients for the
            running averages of the gradient and its square
            (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator of the adaptive
            step (default: 1e-8)
        weight_decay (float, optional): decay factor, applied as
            ``p -= weight_decay * lr * p`` before the gradient step (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): if given, it is called once and its
                return value is returned by this method.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                # All arithmetic is done on an fp32 copy, copied back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                # Keyword forms (value=/alpha=) replace the deprecated
                # positional-scalar overloads; the arithmetic is the same.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                beta2_t = beta2 ** state["step"]
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    step_size = (
                        group["lr"]
                        * math.sqrt(
                            (1 - beta2_t)
                            * (N_sma - 4)
                            / (N_sma_max - 4)
                            * (N_sma - 2)
                            / N_sma
                            * N_sma_max
                            / (N_sma_max - 2)
                        )
                        / (1 - beta1 ** state["step"])
                    )
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    step_size = group["lr"] / (1 - beta1 ** state["step"])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
|
PreTrain_MeDSLIP/optim/rmsprop_tf.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" RMSProp modified to behave like Tensorflow impl
|
2 |
+
|
3 |
+
Originally cut & paste from PyTorch RMSProp
|
4 |
+
https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
|
5 |
+
Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
|
6 |
+
|
7 |
+
Modifications Copyright 2020 Ross Wightman
|
8 |
+
"""
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.optim import Optimizer
|
12 |
+
|
13 |
+
|
14 |
+
class RMSpropTF(Optimizer):
    """Implements RMSprop algorithm (TensorFlow style epsilon)

    NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
    and a few other modifications to closer match Tensorflow for matching hyper-params.

    Noteworthy changes include:
    1. Epsilon applied inside square-root
    2. square_avg initialized to ones
    3. LR scaling of update accumulated in momentum buffer

    Proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.

    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing (decay) constant (default: 0.9)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)
        centered (bool, optional) : if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
            update as per defaults in Tensorflow

    Raises:
        ValueError: if ``lr``, ``eps``, ``momentum``, ``weight_decay`` or
            ``alpha`` is negative.
    """

    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.9,
        eps=1e-10,
        weight_decay=0,
        momentum=0.0,
        centered=False,
        decoupled_decay=False,
        lr_in_momentum=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            alpha=alpha,
            eps=eps,
            centered=centered,
            weight_decay=weight_decay,
            decoupled_decay=decoupled_decay,
            lr_in_momentum=lr_in_momentum,
        )
        super(RMSpropTF, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RMSpropTF, self).__setstate__(state)
        # Fill in keys that may be absent from the restored param groups.
        for group in self.param_groups:
            group.setdefault("momentum", 0)
            group.setdefault("centered", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("RMSprop does not support sparse gradients")
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["square_avg"] = torch.ones_like(
                        p.data
                    )  # PyTorch inits to zero
                    if group["momentum"] > 0:
                        state["momentum_buffer"] = torch.zeros_like(p.data)
                    if group["centered"]:
                        state["grad_avg"] = torch.zeros_like(p.data)

                square_avg = state["square_avg"]
                one_minus_alpha = 1.0 - group["alpha"]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    if group.get("decoupled_decay", False):
                        # NOTE: the decoupled decay is applied as-is; it is not
                        # multiplied by the learning rate.
                        p.data.add_(p.data, alpha=-group["weight_decay"])
                    else:
                        # L2 penalty folded into the gradient (out of place, so
                        # p.grad itself is left untouched).
                        grad = grad.add(p.data, alpha=group["weight_decay"])

                # Tensorflow order of ops for updating squared avg
                square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
                # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)  # PyTorch original

                if group["centered"]:
                    grad_avg = state["grad_avg"]
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    # grad_avg.mul_(alpha).add_(1 - alpha, grad)  # PyTorch original
                    avg = (
                        square_avg.addcmul(grad_avg, grad_avg, value=-1)
                        .add(group["eps"])
                        .sqrt_()
                    )  # eps moved in sqrt
                else:
                    avg = square_avg.add(group["eps"]).sqrt_()  # eps moved in sqrt

                if group["momentum"] > 0:
                    buf = state["momentum_buffer"]
                    # Tensorflow accumulates the LR scaling in the momentum buffer
                    if group.get("lr_in_momentum", False):
                        buf.mul_(group["momentum"]).addcdiv_(
                            grad, avg, value=group["lr"]
                        )
                        p.data.add_(-buf)
                    else:
                        # PyTorch scales the param update by LR
                        buf.mul_(group["momentum"]).addcdiv_(grad, avg)
                        p.data.add_(buf, alpha=-group["lr"])
                else:
                    p.data.addcdiv_(grad, avg, value=-group["lr"])

        return loss
|
PreTrain_MeDSLIP/optim/sgdp.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
|
3 |
+
|
4 |
+
Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
|
5 |
+
Code: https://github.com/clovaai/AdamP
|
6 |
+
|
7 |
+
Copyright (c) 2020-present NAVER Corp.
|
8 |
+
MIT license
|
9 |
+
"""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.optim.optimizer import Optimizer, required
|
14 |
+
import math
|
15 |
+
|
16 |
+
|
17 |
+
class SGDP(Optimizer):
    """SGD with momentum plus the projection step of the AdamP/SGDP paper.

    For parameters with more than one dimension, when the gradient is nearly
    orthogonal to the weights (per output channel, or over the whole tensor),
    the component of the update along the weights is removed and the weight
    decay is scaled by ``wd_ratio``.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        weight_decay (float, optional): weight decay factor (default: 0)
        nesterov (bool, optional): use the Nesterov form of the update
            (default: False)
        eps (float, optional): term added to norms to avoid division by zero
            (default: 1e-8)
        delta (float, optional): threshold on the cosine similarity below
            which the projection is applied (default: 0.1)
        wd_ratio (float, optional): weight-decay multiplier used when the
            projection is applied (default: 0.1)
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        eps=1e-8,
        delta=0.1,
        wd_ratio=0.1,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            eps=eps,
            delta=delta,
            wd_ratio=wd_ratio,
        )
        super(SGDP, self).__init__(params, defaults)

    def _channel_view(self, x):
        """View ``x`` as 2-D with the first dimension kept and the rest flattened."""
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        """View ``x`` as a single row containing every element."""
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        """Return the absolute cosine similarity of ``x`` and ``y`` per row of ``view_func``."""
        x = view_func(x)
        y = view_func(y)

        x_norm = x.norm(dim=1).add_(eps)
        y_norm = y.norm(dim=1).add_(eps)
        dot = (x * y).sum(dim=1)

        return dot.abs() / x_norm / y_norm

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        """Project ``perturb`` away from the weight direction when applicable.

        The per-channel view is tried first, then the whole-layer view. For
        the first view in which the largest cosine similarity between
        ``grad`` and ``p.data`` is below ``delta / sqrt(row_length)``, the
        component of ``perturb`` along the normalized weights is subtracted
        in place and ``wd_ratio`` is returned as the weight-decay multiplier.

        Returns:
            ``(perturb, wd)`` where ``wd`` is ``wd_ratio`` if a projection was
            applied and ``1`` otherwise.
        """
        wd = 1
        # Shape that broadcasts one value per leading index over the other dims.
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                # In place: ``perturb`` may alias the momentum buffer.
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): if given, it is called once and its
                return value is returned by this method.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["momentum"] = torch.zeros_like(p.data)

                # SGD
                buf = state["momentum"]
                # Keyword alpha= replaces the deprecated add_(scalar, tensor) overload.
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
                if nesterov:
                    d_p = grad + momentum * buf
                else:
                    d_p = buf

                # Projection
                wd_ratio = 1
                if len(p.shape) > 1:
                    d_p, wd_ratio = self._projection(
                        p, grad, d_p, group["delta"], group["wd_ratio"], group["eps"]
                    )

                # Weight decay
                if weight_decay != 0:
                    p.data.mul_(
                        1
                        - group["lr"]
                        * group["weight_decay"]
                        * wd_ratio
                        / (1 - momentum)
                    )

                # Step
                p.data.add_(d_p, alpha=-group["lr"])

        return loss
|
PreTrain_MeDSLIP/scheduler/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .cosine_lr import CosineLRScheduler
|
2 |
+
from .plateau_lr import PlateauLRScheduler
|
3 |
+
from .step_lr import StepLRScheduler
|
4 |
+
from .tanh_lr import TanhLRScheduler
|
5 |
+
from .scheduler_factory import create_scheduler
|
PreTrain_MeDSLIP/scheduler/cosine_lr.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Cosine Scheduler
|
2 |
+
|
3 |
+
Cosine LR schedule with warmup, cycle/restarts, noise.
|
4 |
+
|
5 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .scheduler import Scheduler
|
13 |
+
|
14 |
+
from pdb import set_trace as breakpoint
|
15 |
+
|
16 |
+
_logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class CosineLRScheduler(Scheduler):
    """Cosine learning-rate decay with optional linear warmup and restarts.

    The schedule follows https://arxiv.org/abs/1608.03983. Cycle ``i`` has a
    length of ``t_initial * t_mul ** i`` and both the peak values and
    ``lr_min`` are scaled by ``decay_rate ** i``. Once ``cycle_limit`` cycles
    (when non-zero) have elapsed, every group is held at ``lr_min``.

    Inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        t_initial: int,
        t_mul: float = 1.0,
        lr_min: float = 0.0,
        decay_rate: float = 1.0,
        warmup_t=0,
        warmup_lr_init=0,
        warmup_prefix=True,
        cycle_limit=0,
        t_in_epochs=True,
        noise_range_t=None,
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=42,
        initialize=True,
    ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        schedule_is_constant = t_initial == 1 and t_mul == 1 and decay_rate == 1
        if schedule_is_constant:
            _logger.warning(
                "Cosine annealing scheduler will have no effect on the learning "
                "rate since t_initial = t_mul = eta_mul = 1."
            )

        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            # No warmup: the increments are placeholders of the right length.
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # Per-group linear increment from warmup_lr_init up to the base value.
            self.warmup_steps = [
                (base - warmup_lr_init) / self.warmup_t for base in self.base_values
            ]
            super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Return one learning rate per param group for time index ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * inc for inc in self.warmup_steps]

        if self.warmup_prefix:
            # Warmup is counted in addition to the cosine schedule.
            t = t - self.warmup_t

        if self.t_mul != 1:
            # Geometric cycle lengths: solve for the index of the current cycle.
            cycle = math.floor(
                math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)
            )
            cycle_len = self.t_mul ** cycle * self.t_initial
            pos = t - (1 - self.t_mul ** cycle) / (1 - self.t_mul) * self.t_initial
        else:
            cycle = t // self.t_initial
            cycle_len = self.t_initial
            pos = t - (self.t_initial * cycle)

        scale = self.decay_rate ** cycle
        floor_lr = self.lr_min * scale
        peaks = [base * scale for base in self.base_values]

        within_limit = self.cycle_limit == 0 or (
            self.cycle_limit > 0 and cycle < self.cycle_limit
        )
        if not within_limit:
            return [self.lr_min for _ in self.base_values]

        cosine_term = 1 + math.cos(math.pi * pos / cycle_len)
        return [floor_lr + 0.5 * (peak - floor_lr) * cosine_term for peak in peaks]

    def get_epoch_values(self, epoch: int):
        """Return per-group values for ``epoch``, or ``None`` when scheduling per update."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return per-group values for ``num_updates``, or ``None`` when scheduling per epoch."""
        return None if self.t_in_epochs else self._get_lr(num_updates)

    def get_cycle_length(self, cycles=0):
        """Return the total length of ``cycles`` cycles.

        A falsy ``cycles`` falls back to ``cycle_limit``; at least one cycle
        is always counted.
        """
        cycles = max(1, cycles or self.cycle_limit)
        if self.t_mul == 1.0:
            return self.t_initial * cycles
        total = -self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)
        return int(math.floor(total))
|
PreTrain_MeDSLIP/scheduler/plateau_lr.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Plateau Scheduler
|
2 |
+
|
3 |
+
Adapts PyTorch plateau scheduler and allows application of noise, warmup.
|
4 |
+
|
5 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
6 |
+
"""
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from .scheduler import Scheduler
|
10 |
+
|
11 |
+
|
12 |
+
class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the validation loss plateaus.

    Wraps ``torch.optim.lr_scheduler.ReduceLROnPlateau`` and adds linear
    warmup plus optional multiplicative noise on the resulting LR.
    """

    def __init__(
        self,
        optimizer,
        decay_rate=0.1,
        patience_t=10,
        verbose=True,
        threshold=1e-4,
        cooldown_t=0,
        warmup_t=0,
        warmup_lr_init=0,
        lr_min=0,
        mode="max",
        noise_range_t=None,
        noise_type="normal",
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=None,
        initialize=True,
    ):
        super().__init__(optimizer, "lr", initialize=initialize)

        plateau_kwargs = dict(
            patience=patience_t,
            factor=decay_rate,
            verbose=verbose,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min,
        )
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, **plateau_kwargs
        )

        # Noise settings are stored here rather than passed to the base class.
        self.noise_range = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if not self.warmup_t:
            # No warmup: the increments are placeholders of the right length.
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # Per-group linear increment from warmup_lr_init up to the base value.
            self.warmup_steps = [
                (base - warmup_lr_init) / self.warmup_t for base in self.base_values
            ]
            super().update_groups(self.warmup_lr_init)
        # Noise-free LRs saved by _apply_noise, restored on the next step.
        self.restore_lr = None

    def state_dict(self):
        """Return the wrapped scheduler's ``best`` and ``last_epoch`` values."""
        wrapped = self.lr_scheduler
        return {"best": wrapped.best, "last_epoch": wrapped.last_epoch}

    def load_state_dict(self, state_dict):
        """Restore ``best`` and, when present, ``last_epoch`` on the wrapped scheduler."""
        self.lr_scheduler.best = state_dict["best"]
        if "last_epoch" in state_dict:
            self.lr_scheduler.last_epoch = state_dict["last_epoch"]

    # override the base class step fn completely
    def step(self, epoch, metric=None):
        """Advance the schedule for ``epoch`` using ``metric`` as the plateau signal."""
        if epoch <= self.warmup_t:
            warm_lrs = [self.warmup_lr_init + epoch * inc for inc in self.warmup_steps]
            super().update_groups(warm_lrs)
            return

        saved = self.restore_lr
        if saved is not None:
            # restore actual LR from before our last noise perturbation before stepping base
            for idx, group in enumerate(self.optimizer.param_groups):
                group["lr"] = saved[idx]
            self.restore_lr = None

        self.lr_scheduler.step(metric, epoch)  # step the base scheduler

        window = self.noise_range
        if window is None:
            return
        if isinstance(window, (list, tuple)):
            in_window = window[0] <= epoch < window[1]
        else:
            in_window = epoch >= window
        if in_window:
            self._apply_noise(epoch)

    def _apply_noise(self, epoch):
        """Perturb every group's LR by a seeded random fraction and remember the originals."""
        gen = torch.Generator()
        gen.manual_seed(self.noise_seed + epoch)
        if self.noise_type != "normal":
            noise = 2 * (torch.rand(1, generator=gen).item() - 0.5) * self.noise_pct
        else:
            # resample if noise out of percent limit, brute force but shouldn't spin much
            noise = torch.randn(1, generator=gen).item()
            while not abs(noise) < self.noise_pct:
                noise = torch.randn(1, generator=gen).item()

        # apply the noise on top of previous LR, cache the old value so we can restore for normal
        # stepping of base scheduler
        originals = []
        for group in self.optimizer.param_groups:
            current = float(group["lr"])
            originals.append(current)
            group["lr"] = current + current * noise
        self.restore_lr = originals
|
PreTrain_MeDSLIP/scheduler/scheduler.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class Scheduler:
    """Parameter Scheduler Base Class
    A scheduler base class that can be used to schedule any optimizer parameter groups.

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called
    * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
    * At the END of each optimizer update, after incrementing the update count, to calculate next update's value

    The schedulers built on this should try to remain as stateless as possible (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
    and -1 values for special behaviour. All epoch and update counts must be tracked in the training
    code and explicitly passed in to the schedulers on the corresponding step or step_update call.

    Based on ideas from:
     * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
     * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        param_group_field: str,
        noise_range_t=None,
        noise_type="normal",
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=None,
        initialize: bool = True,
    ) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"

        for i, group in enumerate(self.optimizer.param_groups):
            if initialize:
                # Record the starting value under "initial_<field>" unless already there.
                if param_group_field not in group:
                    raise KeyError(
                        f"{param_group_field} missing from param_groups[{i}]"
                    )
                group.setdefault(
                    self._initial_param_group_field, group[param_group_field]
                )
            elif self._initial_param_group_field not in group:
                # Without initialization the starting value must already exist.
                raise KeyError(
                    f"{self._initial_param_group_field} missing from param_groups[{i}]"
                )

        self.base_values = [
            group[self._initial_param_group_field]
            for group in self.optimizer.param_groups
        ]
        self.metric = None  # any point to having this for all?
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return every instance attribute except the optimizer."""
        snapshot = {}
        for key, value in self.__dict__.items():
            if key != "optimizer":
                snapshot[key] = value
        return snapshot

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Return per-group values for ``epoch``; this base implementation returns ``None``."""
        return None

    def get_update_values(self, num_updates: int):
        """Return per-group values for ``num_updates``; this base implementation returns ``None``."""
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        """Store ``metric`` and apply the epoch-level values, if any, with noise."""
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is None:
            return
        self.update_groups(self._add_noise(values, epoch))

    def step_update(self, num_updates: int, metric: float = None):
        """Store ``metric`` and apply the update-level values, if any, with noise."""
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is None:
            return
        self.update_groups(self._add_noise(values, num_updates))

    def update_groups(self, values):
        """Write ``values`` into the scheduled field of each param group.

        A value that is not a list or tuple is applied to every group.
        """
        groups = self.optimizer.param_groups
        if not isinstance(values, (list, tuple)):
            values = [values] * len(groups)
        for group, new_value in zip(groups, values):
            group[self.param_group_field] = new_value

    def _add_noise(self, lrs, t):
        """Return ``lrs`` scaled by a seeded random factor when ``t`` is in the noise range."""
        window = self.noise_range_t
        if window is None:
            return lrs

        if isinstance(window, (list, tuple)):
            in_window = window[0] <= t < window[1]
        else:
            in_window = t >= window
        if not in_window:
            return lrs

        # Seeding with noise_seed + t makes the noise reproducible per time index.
        gen = torch.Generator()
        gen.manual_seed(self.noise_seed + t)
        if self.noise_type != "normal":
            noise = 2 * (torch.rand(1, generator=gen).item() - 0.5) * self.noise_pct
        else:
            # resample if noise out of percent limit, brute force but shouldn't spin much
            noise = torch.randn(1, generator=gen).item()
            while not abs(noise) < self.noise_pct:
                noise = torch.randn(1, generator=gen).item()
        return [v + v * noise for v in lrs]
|
PreTrain_MeDSLIP/scheduler/scheduler_factory.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Scheduler Factory
|
2 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
3 |
+
"""
|
4 |
+
from .cosine_lr import CosineLRScheduler
|
5 |
+
from .tanh_lr import TanhLRScheduler
|
6 |
+
from .step_lr import StepLRScheduler
|
7 |
+
from .plateau_lr import PlateauLRScheduler
|
8 |
+
|
9 |
+
|
10 |
+
def create_scheduler(args, optimizer):
    """Build the LR scheduler selected by ``args.sched``.

    Args:
        args: Attribute-style config. ``epochs`` and ``sched`` are always
            read; the remaining fields depend on the selected schedule.
        optimizer: Optimizer handed to the scheduler constructor.

    Returns:
        ``(lr_scheduler, num_epochs)``. ``lr_scheduler`` is ``None`` when
        ``args.sched`` is not one of "cosine", "tanh", "step" or "plateau".
        For "cosine" and "tanh", ``num_epochs`` becomes the scheduler's cycle
        length plus ``args.cooldown_epochs``; otherwise it is ``args.epochs``.
    """
    num_epochs = args.epochs

    # `lr_noise` is given as fraction(s) of the run; scale to epoch units.
    lr_noise = getattr(args, "lr_noise", None)
    if lr_noise is None:
        noise_range = None
    elif isinstance(lr_noise, (list, tuple)):
        noise_range = [n * num_epochs for n in lr_noise]
        if len(noise_range) == 1:
            noise_range = noise_range[0]
    else:
        noise_range = lr_noise * num_epochs

    # Noise settings are identical for every scheduler type.
    noise_kwargs = dict(
        noise_range_t=noise_range,
        noise_pct=getattr(args, "lr_noise_pct", 0.67),
        noise_std=getattr(args, "lr_noise_std", 1.0),
        noise_seed=getattr(args, "seed", 42),
    )

    sched = args.sched
    lr_scheduler = None
    if sched == "cosine":
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, "lr_cycle_mul", 1.0),
            lr_min=args.min_lr,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, "lr_cycle_limit", 1),
            t_in_epochs=True,
            **noise_kwargs,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif sched == "tanh":
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, "lr_cycle_mul", 1.0),
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, "lr_cycle_limit", 1),
            t_in_epochs=True,
            **noise_kwargs,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif sched == "step":
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_kwargs,
        )
    elif sched == "plateau":
        # A metric whose name contains "loss" is minimised; anything else is maximised.
        mode = "min" if "loss" in getattr(args, "eval_metric", "") else "max"
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_kwargs,
        )

    return lr_scheduler, num_epochs
|
PreTrain_MeDSLIP/scheduler/step_lr.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Step Scheduler
|
2 |
+
|
3 |
+
Basic step LR schedule with warmup, noise.
|
4 |
+
|
5 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
6 |
+
"""
|
7 |
+
import math
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from .scheduler import Scheduler
|
11 |
+
|
12 |
+
|
13 |
+
class StepLRScheduler(Scheduler):
    """Step-decay learning-rate schedule with optional linear warmup and noise.

    After warmup, each base value is multiplied by ``decay_rate`` once for
    every full ``decay_t`` interval that has elapsed.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        decay_t: float,
        decay_rate: float = 1.0,
        warmup_t=0,
        warmup_lr_init=0,
        t_in_epochs=True,
        noise_range_t=None,
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=42,
        initialize=True,
    ) -> None:
        """Store the schedule settings and, if warming up, start at ``warmup_lr_init``."""
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        self.decay_t = decay_t
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            # No warmup: the increments are placeholders that `_get_lr` never uses.
            self.warmup_steps = [1 for _ in self.base_values]
            return

        # Per-group linear increment taking the LR from warmup_lr_init to its base value.
        self.warmup_steps = [
            (base - warmup_lr_init) / self.warmup_t for base in self.base_values
        ]
        super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Return the scheduled value of every parameter group at time ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * inc for inc in self.warmup_steps]
        return [
            base * (self.decay_rate ** (t // self.decay_t))
            for base in self.base_values
        ]

    def get_epoch_values(self, epoch: int):
        """Return the values for ``epoch`` when scheduling per epoch, else ``None``."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return the values for ``num_updates`` when scheduling per update, else ``None``."""
        return None if self.t_in_epochs else self._get_lr(num_updates)
|
PreTrain_MeDSLIP/scheduler/tanh_lr.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" TanH Scheduler
|
2 |
+
|
3 |
+
TanH schedule with warmup, cycle/restarts, noise.
|
4 |
+
|
5 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .scheduler import Scheduler
|
13 |
+
|
14 |
+
|
15 |
+
_logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class TanhLRScheduler(Scheduler):
    """
    Hyperbolic-tangent decay with warmup and restarts.
    This is described in the paper https://arxiv.org/abs/1806.01593
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        t_initial: int,
        lb: float = -6.0,
        ub: float = 4.0,
        t_mul: float = 1.0,
        lr_min: float = 0.0,
        decay_rate: float = 1.0,
        warmup_t=0,
        warmup_lr_init=0,
        warmup_prefix=False,
        cycle_limit=0,
        t_in_epochs=True,
        noise_range_t=None,
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=42,
        initialize=True,
    ) -> None:
        """Validate and store the schedule settings; start at ``warmup_lr_init`` if warming up."""
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        assert lb < ub
        assert cycle_limit >= 0
        assert warmup_t >= 0
        assert warmup_lr_init >= 0

        self.lb = lb
        self.ub = ub
        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            # No warmup: the increments are placeholders that `_get_lr` never uses.
            self.warmup_steps = [1 for _ in self.base_values]
            return

        # With a warmup prefix the ramp ends at the base values; otherwise it
        # ends where the tanh curve already is at t == warmup_t.
        if self.warmup_prefix:
            warmup_targets = self.base_values
        else:
            warmup_targets = self._get_lr(self.warmup_t)
        self.warmup_steps = [
            (target - warmup_lr_init) / self.warmup_t for target in warmup_targets
        ]
        super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Return the scheduled value of every parameter group at time ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * inc for inc in self.warmup_steps]

        if self.warmup_prefix:
            t = t - self.warmup_t

        # Locate the cycle index `i`, its length `t_i` and the offset `t_curr` within it.
        if self.t_mul != 1:
            i = math.floor(
                math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)
            )
            t_i = self.t_mul ** i * self.t_initial
            t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
        else:
            i = t // self.t_initial
            t_i = self.t_initial
            t_curr = t - (self.t_initial * i)

        within_limit = self.cycle_limit == 0 or (
            self.cycle_limit > 0 and i < self.cycle_limit
        )
        if not within_limit:
            # Past the last allowed cycle: hold the decayed minimum.
            floor_lr = self.lr_min * (self.decay_rate ** self.cycle_limit)
            return [floor_lr for _ in self.base_values]

        gamma = self.decay_rate ** i
        lr_min = self.lr_min * gamma
        tr = t_curr / t_i
        shape = 1 - math.tanh(self.lb * (1.0 - tr) + self.ub * tr)
        return [
            lr_min + 0.5 * (base * gamma - lr_min) * shape
            for base in self.base_values
        ]

    def get_epoch_values(self, epoch: int):
        """Return the values for ``epoch`` when scheduling per epoch, else ``None``."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return the values for ``num_updates`` when scheduling per update, else ``None``."""
        return None if self.t_in_epochs else self._get_lr(num_updates)

    def get_cycle_length(self, cycles=0):
        """Return the total length of ``cycles`` cycles.

        A falsy ``cycles`` falls back to ``cycle_limit``; at least one cycle
        is always counted.
        """
        cycles = max(1, cycles or self.cycle_limit)
        if self.t_mul == 1.0:
            return self.t_initial * cycles
        # Geometric series over cycle lengths t_initial * t_mul**k.
        total = -self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)
        return int(math.floor(total))
|
PreTrain_MeDSLIP/train_MeDSLIP.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import ruamel_yaml as yaml
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
import datetime
|
8 |
+
import json
|
9 |
+
from pathlib import Path
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
warnings.filterwarnings("ignore")
|
13 |
+
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
import torch.backends.cudnn as cudnn
|
19 |
+
|
20 |
+
from tensorboardX import SummaryWriter
|
21 |
+
|
22 |
+
import utils
|
23 |
+
from scheduler import create_scheduler
|
24 |
+
from optim import create_optimizer
|
25 |
+
from dataset.dataset import MeDSLIP_Dataset
|
26 |
+
from models.model_MeDSLIP import MeDSLIP
|
27 |
+
from models.tokenization_bert import BertTokenizer
|
28 |
+
|
29 |
+
|
30 |
+
def get_tokenizer(tokenizer, target_text):
    """Tokenize ``target_text`` with ``tokenizer`` and return the result.

    The texts are passed as a list and encoded with ``max_length`` padding,
    truncation to 128 tokens, and ``"pt"`` tensors.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
        "return_tensors": "pt",
    }
    return tokenizer(list(target_text), **encode_options)
|
41 |
+
|
42 |
+
|
43 |
+
def train(
    model,
    data_loader,
    optimizer,
    epoch,
    warmup_steps,
    device,
    scheduler,
    args,
    config,
    writer,
):
    """Run one training epoch and return the formatted global meter averages.

    Args:
        model: Called with the image batch and label/index/matrix keyword
            arguments; must return six losses (total, pathology CE/CL,
            anatomy CE/CL, and ``loss_ap``).
        data_loader: Yields dict samples with the keys ``image``,
            ``label_pathology``, ``label_anatomy``, ``index_pathology``,
            ``index_anatomy`` and ``matrix``.
        optimizer: Optimizer that is stepped once per batch.
        epoch: Zero-based epoch index; warmup stepping only happens in epoch 0.
        warmup_steps: Number of warmup scheduler steps, one per 100 iterations.
        device: Device the batch tensors are moved to.
        scheduler: LR scheduler, stepped here only during warmup.
        args: Unused; kept for interface compatibility.
        config: Provides ``no_cl`` and ``exclude_class`` for the model call.
        writer: Summary writer receiving the per-iteration loss scalars.

    Returns:
        Dict mapping each meter name to its global average formatted with
        three decimals.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    meter_names = (
        "lr",
        "loss",
        "loss_ce_p",
        "loss_cl_p",
        "loss_ce_a",
        "loss_cl_a",
        "loss_ap",
    )
    for meter_name in meter_names:
        metric_logger.add_meter(
            meter_name, utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
        )
    # Give every loss meter an initial value before the first batch is logged.
    for meter_name in meter_names[1:]:
        metric_logger.update(**{meter_name: 1.0})
    # Log the rate the optimizer actually uses. `scheduler._get_lr(epoch)`
    # does not match it: warmup steps the scheduler with `i // step_size`,
    # not with `epoch`.
    metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    header = "Train Epoch: [{}]".format(epoch)
    print_freq = 1
    step_size = 100
    warmup_iterations = warmup_steps * step_size
    scalar_step = epoch * len(data_loader)

    for i, sample in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        images = sample["image"].to(device)
        labels_pathology = sample["label_pathology"].to(device)
        labels_anatomy = sample["label_anatomy"].to(device)
        index_pathology = sample["index_pathology"].to(device)
        index_anatomy = sample["index_anatomy"].to(device)
        matrix = sample["matrix"].to(device)

        optimizer.zero_grad()

        (
            loss,
            loss_ce_pathology,
            loss_cl_pathology,
            loss_ce_anatomy,
            loss_cl_anatomy,
            loss_ap,
        ) = model(
            images,
            labels_pathology=labels_pathology,
            labels_anatomy=labels_anatomy,
            matrix=matrix,
            sample_index_pathology=index_pathology,
            sample_index_anatomy=index_anatomy,
            is_train=True,
            no_cl=config["no_cl"],
            exclude_class=config["exclude_class"],
        )
        loss.backward()
        optimizer.step()

        writer.add_scalar("loss/loss", loss, scalar_step)
        writer.add_scalar("loss/loss_ce_pathology", loss_ce_pathology, scalar_step)
        writer.add_scalar("loss/loss_cl_pathology", loss_cl_pathology, scalar_step)
        writer.add_scalar("loss/loss_ce_anatomy", loss_ce_anatomy, scalar_step)
        writer.add_scalar("loss/loss_cl_anatomy", loss_cl_anatomy, scalar_step)
        writer.add_scalar("loss/loss_ap", loss_ap, scalar_step)
        scalar_step += 1

        metric_logger.update(loss_ce_p=loss_ce_pathology.item())
        metric_logger.update(loss_cl_p=loss_cl_pathology.item())
        metric_logger.update(loss_ce_a=loss_ce_anatomy.item())
        metric_logger.update(loss_cl_a=loss_cl_anatomy.item())
        metric_logger.update(loss_ap=loss_ap.item())
        metric_logger.update(loss=loss.item())

        # Warmup: during the first epoch advance the scheduler every
        # `step_size` iterations until `warmup_iterations` is reached.
        if epoch == 0 and i % step_size == 0 and i <= warmup_iterations:
            scheduler.step(i // step_size)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {
        k: "{:.3f}".format(meter.global_avg)
        for k, meter in metric_logger.meters.items()
    }
|
150 |
+
|
151 |
+
|
152 |
+
def valid(model, data_loader, epoch, device, config, writer):
    """Compute the losses on ``data_loader`` and return the mean total loss.

    The model is put in eval mode and called under ``torch.no_grad()`` with
    ``is_train=True``. Each batch's six losses are written to ``writer``
    under ``val_loss/*``, with a step counter starting at
    ``epoch * len(data_loader)``.
    """
    model.eval()
    global_step = epoch * len(data_loader)
    batch_losses = []
    scalar_tags = (
        "val_loss/loss",
        "val_loss/loss_ce_pathology",
        "val_loss/loss_cl_pathology",
        "val_loss/loss_ce_anatomy",
        "val_loss/loss_cl_anatomy",
        "val_loss/loss_ap",
    )

    for sample in data_loader:
        images = sample["image"].to(device)
        labels_pathology = sample["label_pathology"].to(device)
        labels_anatomy = sample["label_anatomy"].to(device)
        index_pathology = sample["index_pathology"].to(device)
        index_anatomy = sample["index_anatomy"].to(device)
        matrix = sample["matrix"].to(device)

        with torch.no_grad():
            (
                loss,
                loss_ce_pathology,
                loss_cl_pathology,
                loss_ce_anatomy,
                loss_cl_anatomy,
                loss_ap,
            ) = model(
                images,
                labels_pathology=labels_pathology,
                labels_anatomy=labels_anatomy,
                matrix=matrix,
                sample_index_pathology=index_pathology,
                sample_index_anatomy=index_anatomy,
                is_train=True,
                no_cl=config["no_cl"],
                exclude_class=config["exclude_class"],
            )
            batch_losses.append(loss.item())

        scalar_values = (
            loss,
            loss_ce_pathology,
            loss_cl_pathology,
            loss_ce_anatomy,
            loss_cl_anatomy,
            loss_ap,
        )
        for tag, scalar in zip(scalar_tags, scalar_values):
            writer.add_scalar(tag, scalar, global_step)
        global_step += 1

    return np.array(batch_losses).mean()
|
202 |
+
|
203 |
+
|
204 |
+
def main(args, config):
    """Build the data, model, optimizer and scheduler, then run the training loop.

    Args:
        args: Parsed command-line arguments; ``computing``, ``checkpoint``
            and ``output_dir`` are read here.
        config: Loaded YAML config; the keys read here are ``schedular``,
            ``optimizer``, ``train_file``, ``valid_file``, ``label_file``,
            ``batch_size``, ``pathology_book`` and ``text_encoder``.

    After every epoch the main process writes ``checkpoint_state.pth`` and
    appends a JSON line to ``log.txt``; after epoch 15 a per-epoch checkpoint
    is saved as well.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.computing == "parallel":
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        device = torch.device("cuda", rank)
        print("World size: ", world_size, "; Rank: ", rank)

    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")
    cudnn.benchmark = True

    start_epoch = 0
    max_epoch = config["schedular"]["epochs"]
    warmup_steps = config["schedular"]["warmup_epochs"]

    #### Dataset ####
    print("Creating dataset")
    train_datasets = MeDSLIP_Dataset(
        config["train_file"], config["label_file"], mode="train"
    )
    val_datasets = MeDSLIP_Dataset(
        config["valid_file"], config["label_file"], mode="train"
    )
    if args.computing == "parallel":
        # Shuffled, per-rank shards of each dataset.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_datasets, num_replicas=world_size, rank=rank, shuffle=True
        )
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_datasets, num_replicas=world_size, rank=rank, shuffle=True
        )
    else:
        train_sampler = torch.utils.data.RandomSampler(train_datasets)
        val_sampler = torch.utils.data.RandomSampler(val_datasets)
    train_dataloader = DataLoader(
        train_datasets,
        batch_size=config["batch_size"],
        num_workers=30,
        pin_memory=True,
        sampler=train_sampler,
        collate_fn=None,
        drop_last=True,
    )

    val_dataloader = DataLoader(
        val_datasets,
        batch_size=config["batch_size"],
        num_workers=30,
        pin_memory=True,
        sampler=val_sampler,
        collate_fn=None,
        drop_last=True,
    )

    print("Creating book")
    # Close the file deterministically instead of leaking the handle.
    with open(config["pathology_book"], "r") as book_file:
        json_book = json.load(book_file)
    pathology_book = [json_book[i] for i in json_book]
    anatomy_list = [
        "trachea",
        "left_hilar",
        "right_hilar",
        "hilar_unspec",
        "left_pleural",
        "right_pleural",
        "pleural_unspec",
        "heart_size",
        "heart_border",
        "left_diaphragm",
        "right_diaphragm",
        "diaphragm_unspec",
        "retrocardiac",
        "lower_left_lobe",
        "upper_left_lobe",
        "lower_right_lobe",
        "middle_right_lobe",
        "upper_right_lobe",
        "left_lower_lung",
        "left_mid_lung",
        "left_upper_lung",
        "left_apical_lung",
        "left_lung_unspec",
        "right_lower_lung",
        "right_mid_lung",
        "right_upper_lung",
        "right_apical_lung",
        "right_lung_unspec",
        "lung_apices",
        "lung_bases",
        "left_costophrenic",
        "right_costophrenic",
        "costophrenic_unspec",
        "cardiophrenic_sulcus",
        "mediastinal",
        "spine",
        "clavicle",
        "rib",
        "stomach",
        "right_atrium",
        "right_ventricle",
        "aorta",
        "svc",
        "interstitium",
        "parenchymal",
        "cavoatrial_junction",
        "cardiopulmonary",
        "pulmonary",
        "lung_volumes",
        "unspecified",
        "other",
    ]
    # One short location sentence per anatomy label.
    anatomy_book = ["It is located at " + name + ". " for name in anatomy_list]

    tokenizer = BertTokenizer.from_pretrained(config["text_encoder"])
    anatomy_book_tokenizer = get_tokenizer(tokenizer, anatomy_book).to(device)
    pathology_book_tokenizer = get_tokenizer(tokenizer, pathology_book).to(device)
    print("Creating model")
    model = MeDSLIP(
        config, anatomy_book_tokenizer, pathology_book_tokenizer, mode="train"
    )
    model = model.to(device)
    if args.computing == "parallel":
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], find_unused_parameters=True
        )

    arg_opt = utils.AttrDict(config["optimizer"])
    optimizer = create_optimizer(arg_opt, model)
    arg_sche = utils.AttrDict(config["schedular"])
    lr_scheduler, _ = create_scheduler(arg_sche, optimizer)

    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        state_dict = checkpoint["model"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(state_dict)
        print("load checkpoint from %s" % args.checkpoint)

    def build_checkpoint(current_epoch):
        """Collect everything needed to resume training after ``current_epoch``."""
        return {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "config": config,
            "epoch": current_epoch,
        }

    print("Start training")
    start_time = time.time()

    writer = SummaryWriter(os.path.join(args.output_dir, "log"))
    for epoch in range(start_epoch, max_epoch):
        if epoch > 0:
            # Epoch 0 is stepped inside train() during warmup; later epochs
            # continue from the end of the warmup.
            lr_scheduler.step(epoch + warmup_steps)
        train_stats = train(
            model,
            train_dataloader,
            optimizer,
            epoch,
            warmup_steps,
            device,
            lr_scheduler,
            args,
            config,
            writer,
        )

        # Use the total-loss meter explicitly. The original looped over all
        # meters and kept the last one, which logged `loss_ap` as the epoch loss.
        train_loss_epoch = train_stats["loss"]

        writer.add_scalar("loss/train_loss_epoch", float(train_loss_epoch), epoch)
        # Log the rate the optimizer actually holds. `_get_lr(epoch)` ignores
        # the `warmup_steps` offset and any noise applied by `step`. The tag
        # spelling is kept so existing logs stay comparable.
        writer.add_scalar(
            "loss/leaning_rate", optimizer.param_groups[0]["lr"], epoch
        )

        val_loss = valid(model, val_dataloader, epoch, device, config, writer)
        writer.add_scalar("loss/val_loss_epoch", val_loss, epoch)

        if utils.is_main_process():
            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                "epoch": epoch,
                "val_loss": val_loss.item(),
            }
            torch.save(
                build_checkpoint(epoch),
                os.path.join(args.output_dir, "checkpoint_state.pth"),
            )

            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

        # NOTE(review): the source's indentation was lost; this per-epoch save
        # is assumed to sit at loop level (run on every rank) -- confirm.
        if epoch > 15:
            torch.save(
                build_checkpoint(epoch),
                os.path.join(args.output_dir, "checkpoint_" + str(epoch) + ".pth"),
            )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
|
410 |
+
|
411 |
+
|
412 |
+
if __name__ == "__main__":
    # Script entry point: parse arguments, prepare a timestamped output
    # directory, load the config, and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", default="PreTrain_MeDSLIP/configs/Pretrain_MeDSLIP.yaml"
    )
    parser.add_argument("--checkpoint", default="")
    parser.add_argument("--output_dir", default="runs/")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--world_size", default=1, type=int)
    parser.add_argument(
        "--computing", type=str, default="single", help="number of gpus"
    )
    args = parser.parse_args()

    # Each run gets its own timestamped sub-directory.
    args.output_dir = os.path.join(
        args.output_dir, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    )

    # More than one visible GPU switches to distributed training.
    gpus = torch.cuda.device_count()
    if gpus > 1:
        args.computing = "parallel"

    # NOTE: yaml.Loader can construct arbitrary objects; only load trusted configs.
    with open(args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Keep a copy of the config with the run; the context manager guarantees
    # the file is flushed and closed.
    with open(os.path.join(args.output_dir, "config.yaml"), "w") as config_copy:
        yaml.dump(config, config_copy)

    if args.computing == "parallel":
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    main(args, config)
|
PreTrain_MeDSLIP/utils.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from collections import defaultdict, deque
|
6 |
+
import datetime
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.distributed as dist
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
warnings.filterwarnings("ignore")
|
15 |
+
|
16 |
+
|
17 |
+
class SmoothedValue(object):
    """Keep a running series of values with windowed and global statistics.

    The most recent ``window_size`` values live in ``deque``; ``total`` and
    ``count`` accumulate over the whole series.
    """

    def __init__(self, window_size=20, fmt=None):
        """Create an empty tracker; ``fmt`` is the template used by ``__str__``."""
        template = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = template

    def update(self, value, n=1):
        """Append ``value`` to the window and add it ``n`` times to the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes (no-op when not distributed).

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor(
            [self.count, self.total], dtype=torch.float64, device="cuda"
        )
        dist.barrier()
        dist.all_reduce(stats)
        reduced_count, reduced_total = stats.tolist()
        self.count = int(reduced_count)
        self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        """Mean over the whole series; ``total`` itself when nothing was counted."""
        return self.total if self.count == 0 else self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently added value."""
        return self.deque[-1]

    def __str__(self):
        """Render the statistics through the ``fmt`` template."""
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)
|
81 |
+
|
82 |
+
|
83 |
+
class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with tqdm-based logging."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed each keyword value (number or one-element tensor) to its meter."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs after normal attribute lookup has failed.
        # Read ``meters`` through __dict__: touching ``self.meters`` here would
        # re-enter __getattr__ and recurse forever on an instance whose
        # __init__ has not run yet (e.g. during copy or unpickling).
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def global_avg(self):
        """Return ``name: global_avg`` for every meter, joined by the delimiter."""
        return self.delimiter.join(
            "{}: {:.4f}".format(name, meter.global_avg)
            for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` behind a tqdm progress bar.

        ``header`` becomes the bar description. Every ``print_freq`` iterations,
        and on the last one, the bar postfix is refreshed with the current
        meters, followed by the peak allocated CUDA memory in MB when CUDA is
        available.
        """
        use_cuda = torch.cuda.is_available()
        log_msg = ["{meters}"]
        if use_cuda:
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0

        loop = tqdm(iterable)
        loop.set_description(header if header else "")

        i = 0
        for obj in loop:
            yield obj
            if i % print_freq == 0 or i == len(loop) - 1:
                fields = {"meters": str(self)}
                if use_cuda:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                loop.set_postfix_str(log_msg.format(**fields))
            i += 1
|
173 |
+
|
174 |
+
|
175 |
+
class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the dict itself as the attribute namespace, so d.key and
        # d["key"] always refer to the same entry.
        self.__dict__ = self
|
179 |
+
|
180 |
+
|
181 |
+
def compute_acc(logits, label, reduction="mean"):
    """Compare the argmax over dim 1 of ``logits`` with ``label``.

    Args:
        logits: tensor of scores; the prediction is ``argmax`` along dim 1.
        label: tensor compared element-wise with the predictions.
        reduction: ``"mean"`` returns the accuracy as a Python float,
            ``"none"`` returns the detached per-sample 0/1 float tensor.

    Raises:
        ValueError: if ``reduction`` is neither ``"mean"`` nor ``"none"``
            (previously such values silently returned ``None``).
    """
    correct = (torch.argmax(logits, dim=1) == label).float()
    if reduction == "none":
        return correct.detach()
    if reduction == "mean":
        return correct.mean().item()
    raise ValueError(
        "reduction must be 'mean' or 'none', got {!r}".format(reduction)
    )
|
187 |
+
|
188 |
+
|
189 |
+
def compute_n_params(model, return_str=True):
    """Count the elements of all tensors in ``model.parameters()``.

    Returns the integer count when ``return_str`` is false; otherwise a string
    with one decimal, in millions (``"M"``) from 1e6 upwards and in thousands
    (``"K"``) below that.
    """
    total = sum(param.numel() for param in model.parameters())
    if not return_str:
        return total
    if total >= 1e6:
        return "{:.1f}M".format(total / 1e6)
    return "{:.1f}K".format(total / 1e3)
|
203 |
+
|
204 |
+
|
205 |
+
def setup_for_distributed(is_master):
    """Replace ``builtins.print`` so that non-master processes stay silent.

    After this call ``print`` only produces output when ``is_master`` is true
    or when the call passes ``force=True``; the ``force`` keyword is always
    removed before delegating to the original ``print``.
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = gated_print
|
219 |
+
|
220 |
+
|
221 |
+
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()
|
227 |
+
|
228 |
+
|
229 |
+
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
|
233 |
+
|
234 |
+
|
235 |
+
def get_rank():
    """Return this process's distributed rank, or 0 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
|
239 |
+
|
240 |
+
|
241 |
+
def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0
|
243 |
+
|
244 |
+
|
245 |
+
def save_on_master(*args, **kwargs):
    """Forward the arguments to ``torch.save`` on the main process only."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
248 |
+
|
249 |
+
|
250 |
+
def init_distributed_mode(args):
    """Initialise torch.distributed from the launcher's environment.

    Reads ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` when the first two are set,
    otherwise ``SLURM_PROCID``; with neither present it prints a notice, sets
    ``args.distributed = False`` and returns. In distributed mode it selects
    the CUDA device ``args.local_rank``, initialises the NCCL process group
    from ``args.dist_url``, waits on a barrier, and silences ``print`` on every
    rank except 0.

    ``args`` is mutated in place: ``rank``, ``local_rank``, ``distributed``,
    ``dist_backend`` and, where available, ``world_size``.
    """
    env = os.environ
    if "RANK" in env and "WORLD_SIZE" in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.local_rank = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        args.local_rank = args.rank % torch.cuda.device_count()
        # This branch used to leave world_size unset, so init_process_group
        # below failed unless the caller had supplied it. Fill it in from
        # Slurm only when the caller has not provided a value.
        if not hasattr(args, "world_size") and "SLURM_NTASKS" in env:
            args.world_size = int(env["SLURM_NTASKS"])
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.local_rank)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
README.md
CHANGED
@@ -1,3 +1,49 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MeDSLIP: Medical Knowledge Enhanced Language-Image Pre-Training in Radiology
|
2 |
+
|
3 |
+
## Introduction:
|
4 |
+
|
5 |
+
The official implementation code for "MeDSLIP: Medical Knowledge Enhanced Language-Image Pre-Training in Radiology".
|
6 |
+
|
7 |
+
[**Arxiv Version**](https://arxiv.org/abs/2403.10635)
|
8 |
+
|
9 |
+
## Quick Start:
|
10 |
+
Check checkpoints directory to download our pre-trained model from [Hugging Face: MeDSLIP](https://huggingface.co/pykale/MeDSLIP). It can be used for all zero-shot and finetuning tasks.
|
11 |
+
|
12 |
+
* **Zero-Shot Classification:**
|
13 |
+
|
14 |
+
We give an example on CXR14 in ```Sample_Zero-Shot_Classification_CXR14```. Change the data paths, and test our model by ```python test.py```.
|
15 |
+
We give an example on RSNA in ```Sample_Zero-Shot_Classification_RSNA```. Change the data paths, and test our model by ```python test.py```.
|
16 |
+
|
17 |
+
* **Zero-Shot Grounding:**
|
18 |
+
|
19 |
+
We give an example on RSNA_Pneumonia in ```Sample_Zero-Shot_Grounding_RSNA```. Change the data paths, and test our model by ```python test.py```.
|
20 |
+
|
21 |
+
* **Finetuning:**
|
22 |
+
|
23 |
+
We give segmentation and classification finetune code on SIIM_ACR dataset in ```Sample_Finetuning_SIIMACR```. Change the data paths, and finetune our model by ```python I1_classification/train_res_ft.py``` or ```python I2_segementation/train_res_ft.py```.
|
24 |
+
|
25 |
+
## Pre-train:
|
26 |
+
### Data Preparation
|
27 |
+
All files needed for data preparation can be downloaded from [Hugging Face: MeDSLIP](https://huggingface.co/pykale/MeDSLIP).
|
28 |
+
- Extracted triplets: `landmark_observation_adj_mtx.npy`
|
29 |
+
- Training list: `train.json`
|
30 |
+
- Validation list: `valid.json`
|
31 |
+
- Test list: `test.json`
|
32 |
+
|
33 |
+
### Pre-training
|
34 |
+
Our pre-train code is given in ```PreTrain_MeDSLIP```.
|
35 |
+
* Check the ```PreTrain_MeDSLIP/data_file``` dir and download the files for data preparation.
|
36 |
+
* Change the data and preparation file paths as you desire in ```PreTrain_MeDSLIP/configs/Pretrain_MeDSLIP.yaml```, then run ```python PreTrain_MeDSLIP/train_MeDSLIP.py``` to pre-train.
|
37 |
+
|
38 |
+
## Reference
|
39 |
+
```
|
40 |
+
@article{fan2024medslip,
|
41 |
+
title={MeDSLIP: Medical Dual-Stream Language-Image Pre-training for Fine-grained Alignment},
|
42 |
+
author={Fan, Wenrui and Suvon, Mohammod Naimul Islam and Zhou, Shuo and Liu, Xianyuan and Alabed, Samer and Osmani, Venet and Swift, Andrew and Chen, Chen and Lu, Haiping},
|
43 |
+
journal={arXiv preprint arXiv:2403.10635},
|
44 |
+
year={2024}
|
45 |
+
}
|
46 |
+
```
|
47 |
+
|
48 |
+
## Contact
|
49 |
+
If you have any questions, please feel free to contact winslow.fan@outlook.com.
|
Sample_Finetuning_SIIMACR/I1_classification/configs/Res_train.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_file: "SIIM-CLS/siim-acr-pneumothorax/stage_1_train_images.csv"
|
2 |
+
valid_file: "SIIM-CLS/siim-acr-pneumothorax/stage_1_test_images.csv"
|
3 |
+
test_file: "SIIM-CLS/siim-acr-pneumothorax/stage_1_test_images.csv"
|
4 |
+
|
5 |
+
image_res: 224
|
6 |
+
batch_size: 64
|
7 |
+
test_batch_size: 64
|
8 |
+
num_classes: 1
|
9 |
+
temp: 0.07
|
10 |
+
mlm_probability: 0.15
|
11 |
+
queue_size: 8192
|
12 |
+
momentum: 0.995
|
13 |
+
alpha: 0.4
|
14 |
+
percentage: 1.0
|
15 |
+
|
16 |
+
optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02}
|
17 |
+
schedular: {sched: cosine, lr: 1e-5, epochs: 200, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 20, cooldown_epochs: 0}
|
Sample_Finetuning_SIIMACR/I1_classification/dataset/dataset_siim_acr.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cmath import nan
|
2 |
+
import csv
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import pydicom
|
8 |
+
|
9 |
+
from abc import abstractmethod
|
10 |
+
from itertools import islice
|
11 |
+
from typing import List, Tuple, Dict, Any
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
import PIL
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
import numpy as np
|
16 |
+
import pandas as pd
|
17 |
+
from torchvision import transforms
|
18 |
+
from PIL import Image
|
19 |
+
from skimage import exposure
|
20 |
+
import torch
|
21 |
+
from torchvision.transforms import InterpolationMode
|
22 |
+
from dataset.randaugment import RandomAugment
|
23 |
+
|
24 |
+
|
25 |
+
class SIIM_ACR_Dataset(Dataset):
    """SIIM-ACR pneumothorax images with a binary label derived from the mask.

    The CSV at ``csv_path`` must have an ``image_path`` column; only the file
    name (text after the last ``/``) is used, and it is looked up under
    ``img_root`` for the image and ``seg_root`` for the mask. A sample is
    labelled 1 when its mask has any non-zero pixel, otherwise 0.

    Args:
        csv_path: CSV file listing the images.
        is_train: when truthy, a random ``percentage`` of the rows is sampled
            without replacement and training augmentations are applied;
            otherwise every row is used with a plain resize.
        percentage: fraction of rows kept in training mode.
    """

    def __init__(self, csv_path, is_train=True, percentage=0.01):
        data_info = pd.read_csv(csv_path)
        image_paths = data_info["image_path"]
        # Use plain truthiness here, exactly like the transform selection
        # below, so both decisions always agree for any ``is_train`` value.
        if is_train:
            total_len = int(percentage * len(data_info))
            choice_list = np.random.choice(
                range(len(data_info)), size=total_len, replace=False
            )
            # choice_list holds row positions, so index positionally.
            self.img_path_list = image_paths.iloc[choice_list].tolist()
        else:
            self.img_path_list = image_paths.tolist()

        self.img_root = "SIIM-CLS/siim-acr-pneumothorax/png_images/"
        self.seg_root = "SIIM-CLS/siim-acr-pneumothorax/png_masks/"  # We have pre-processed the original SIIM_ACR data, you may change this to fix your data

        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        if is_train:
            self.transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        224, scale=(0.2, 1.0), interpolation=Image.BICUBIC
                    ),
                    transforms.RandomHorizontalFlip(),
                    RandomAugment(
                        2,
                        7,
                        isPIL=True,
                        augs=[
                            "Identity",
                            "AutoContrast",
                            "Equalize",
                            "Brightness",
                            "Sharpness",
                            "ShearX",
                            "ShearY",
                            "TranslateX",
                            "TranslateY",
                            "Rotate",
                        ],
                    ),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        else:
            self.transform = transforms.Compose(
                [transforms.Resize([224, 224]), transforms.ToTensor(), normalize]
            )

        # Nearest-neighbour resize of the mask.
        self.seg_transfrom = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize([224, 224], interpolation=InterpolationMode.NEAREST),
            ]
        )

    def __getitem__(self, index):
        """Return ``{"image": tensor, "label": np.array([0 or 1])}``."""
        file_name = self.img_path_list[index].split("/")[-1]
        img_path = self.img_root + file_name
        seg_path = self.seg_root + file_name  # We have pre-processed the original SIIM_ACR data, you may change this to fix your data

        img = PIL.Image.open(img_path).convert("RGB")
        image = self.transform(img)

        seg_map = PIL.Image.open(seg_path)
        seg_map = self.seg_transfrom(seg_map)
        seg_map = (seg_map > 0).type(torch.int)
        # Positive as soon as a single mask pixel is set.
        class_label = np.array([int(torch.sum(seg_map) > 0)])
        return {"image": image, "label": class_label}

    def __len__(self):
        return len(self.img_path_list)
|
98 |
+
|
99 |
+
|
100 |
+
def create_loader_RSNA(
    datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
    """Build one ``DataLoader`` per entry of the parallel argument lists.

    Training loaders drop the last incomplete batch and shuffle only when no
    sampler is given; evaluation loaders never shuffle and keep every batch.
    All loaders use ``pin_memory=True``.
    """
    loaders = []
    configs = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    for dataset, sampler, bs, n_worker, is_train, collate_fn in configs:
        shuffle = (sampler is None) if is_train else False
        drop_last = True if is_train else False
        loaders.append(
            DataLoader(
                dataset,
                batch_size=bs,
                num_workers=n_worker,
                pin_memory=True,
                sampler=sampler,
                shuffle=shuffle,
                collate_fn=collate_fn,
                drop_last=drop_last,
            )
        )
    return loaders
|
Sample_Finetuning_SIIMACR/I1_classification/dataset/randaugment.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
## aug functions
|
6 |
+
def identity_func(img):
    """Return ``img`` unchanged (the no-op augmentation)."""
    return img
|
8 |
+
|
9 |
+
|
10 |
+
def autocontrast_func(img, cutoff=0):
    """Stretch each channel's intensity range to 0..255.

    Intended by the original author to match ``PIL.ImageOps.autocontrast``.
    ``cutoff`` is the percentage of pixels ignored at each end of the
    histogram; with 0 the channel's minimum and maximum are used directly.
    """
    n_bins = 256

    def tune_channel(ch):
        cut = cutoff * ch.size // 100
        if cut == 0:
            # ch.min()/ch.max() are uint8 scalars, and negating a uint8 wraps
            # (-10 becomes 246), which pushed the offset far out of range for
            # any channel whose minimum is above 0. Use Python ints instead.
            low, high = int(ch.min()), int(ch.max())
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            from_low = np.argwhere(np.cumsum(hist) > cut)
            low = int(from_low[0, 0]) if from_low.shape[0] else 0
            from_high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = (
                n_bins - 1 - int(from_high[0, 0])
                if from_high.shape[0]
                else n_bins - 1
            )
        if high <= low:
            # Flat (or degenerate) channel: identity lookup table.
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    return cv2.merge(channels)
|
41 |
+
|
42 |
+
|
43 |
+
def equalize_func(img):
    """Histogram-equalise every channel of ``img`` via a lookup table.

    Per the original author this reproduces ``PIL.ImageOps.equalize``, whose
    implementation differs from OpenCV's own equalisation.
    """
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        populated = hist[hist != 0].reshape(-1)
        step = np.sum(populated[:-1]) // (n_bins - 1)
        if step == 0:
            # Nothing to spread out: leave the channel as it is.
            return ch
        shifted = np.empty_like(hist)
        shifted[0] = step // 2
        shifted[1:] = hist[:-1]
        lut = (np.cumsum(shifted) // step).clip(0, 255).astype(np.uint8)
        return lut[ch]

    return cv2.merge([tune_channel(ch) for ch in cv2.split(img)])
|
65 |
+
|
66 |
+
|
67 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
    """Rotate ``img`` about its centre by ``degree`` degrees (not radians).

    The output keeps the input's width and height; pixels uncovered by the
    rotation take the border value ``fill``.
    """
    height, width = img.shape[:2]
    rotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
    return cv2.warpAffine(img, rotation, (width, height), borderValue=fill)
|
76 |
+
|
77 |
+
|
78 |
+
def solarize_func(img, thresh=128):
    """Invert every pixel value that is at or above ``thresh``.

    Values below ``thresh`` are kept; a value ``v >= thresh`` becomes
    ``255 - v``. ``img`` must hold integer values in 0..255 because it is used
    to index a 256-entry lookup table.
    """
    levels = np.arange(256)
    lut = np.where(levels < thresh, levels, 255 - levels)
    lut = lut.clip(0, 255).astype(np.uint8)
    return lut[img]
|
86 |
+
|
87 |
+
|
88 |
+
def color_func(img, factor):
    """Adjust colour saturation with a single 3x3 matrix multiplication.

    Per the original author this matches ``PIL.ImageEnhance.Color``; it is the
    fast form of blending the image with its grey-scale version, using the
    channel weights 0.114 / 0.587 / 0.299. The result is clipped to 0..255 and
    returned as uint8.
    """
    chroma = np.float32(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
    )
    luma = np.float32([[0.114], [0.587], [0.299]])
    mix = chroma * factor + luma
    blended = np.matmul(img, mix)
    return blended.clip(0, 255).astype(np.uint8)
|
104 |
+
|
105 |
+
|
106 |
+
def contrast_func(img, factor):
    """Scale pixel values around the image's weighted mean intensity.

    Per the original author this matches ``PIL.ImageEnhance.Contrast``. The
    mean is the per-channel mean weighted by 0.114 / 0.587 / 0.299; each level
    ``v`` maps to ``(v - mean) * factor + mean``, clipped to 0..255.
    """
    channel_means = np.mean(img, axis=(0, 1))
    mean = np.sum(channel_means * np.array([0.114, 0.587, 0.299]))
    lut = ((np.arange(256) - mean) * factor + mean).clip(0, 255).astype(np.uint8)
    return lut[img]
|
118 |
+
|
119 |
+
|
120 |
+
def brightness_func(img, factor):
    """Multiply every pixel value by ``factor``, saturating at 0 and 255.

    Intended to mirror ``PIL.ImageEnhance.Brightness``. ``img`` must hold
    integer values in 0..255 because it indexes a 256-entry lookup table.
    """
    scaled = np.arange(256, dtype=np.float32) * factor
    lut = scaled.clip(0, 255).astype(np.uint8)
    return lut[img]
|
127 |
+
|
128 |
+
|
129 |
+
def sharpness_func(img, factor):
    """Blend ``img`` with a smoothed copy to soften or sharpen it.

    ``factor == 0`` returns the smoothed image, ``factor == 1`` returns
    ``img`` itself, and other values inter- or extrapolate between the two on
    the interior pixels only (the one-pixel border is left untouched). The
    original author notes that the result differs from PIL only on the four
    image boundaries.
    """
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        # Extrapolation (factor > 1) can leave the 0..255 range; clip before
        # the cast so values saturate instead of wrapping around.
        out = out.clip(0, 255).astype(np.uint8)
    return out
|
148 |
+
|
149 |
+
|
150 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """Shear ``img`` along the x axis by ``factor``; new pixels get ``fill``."""
    rows, cols = img.shape[:2]
    shear = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, shear, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
|
157 |
+
|
158 |
+
|
159 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """Translate ``img`` horizontally by ``-offset`` pixels.

    Intended by the original author to match ``PIL.Image.transform``; pixels
    shifted in from outside the image take the value ``fill``.
    """
    rows, cols = img.shape[:2]
    shift = np.float32([[1, 0, -offset], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, shift, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
|
169 |
+
|
170 |
+
|
171 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """Translate ``img`` vertically by ``-offset`` pixels.

    Intended by the original author to match ``PIL.Image.transform``; pixels
    shifted in from outside the image take the value ``fill``.
    """
    rows, cols = img.shape[:2]
    shift = np.float32([[1, 0, 0], [0, 1, -offset]])
    warped = cv2.warpAffine(
        img, shift, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
|
181 |
+
|
182 |
+
|
183 |
+
def posterize_func(img, bits):
    """Keep only the ``bits`` most significant bits of every pixel value.

    Intended to match ``PIL.ImageOps.posterize``. ``bits`` is expected to be
    in 0..8; ``img`` should be a uint8 array.
    """
    # 255 << (8 - bits) exceeds the uint8 range for bits < 8, and converting
    # such a value with np.uint8(...) raises OverflowError on NumPy 2. Reduce
    # it to its low byte first, which is the mask the wrapped value used to be.
    mask = (255 << (8 - bits)) & 0xFF
    return np.bitwise_and(img, np.uint8(mask))
|
189 |
+
|
190 |
+
|
191 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """Shear ``img`` along the y axis by ``factor``; new pixels get ``fill``."""
    rows, cols = img.shape[:2]
    shear = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(
        img, shear, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
|
198 |
+
|
199 |
+
|
200 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """Return a copy of ``img`` with a random rectangle set to ``replace``.

    The rectangle is centred on a uniformly random pixel, extends
    ``pad_size // 2`` pixels in each direction and is cropped to the image.
    The input array is not modified.
    """
    fill = np.array(replace, dtype=np.uint8)
    height, width = img.shape[0], img.shape[1]
    frac_h, frac_w = np.random.random(2)
    half = pad_size // 2
    centre_row = int(frac_h * height)
    centre_col = int(frac_w * width)
    top, bottom = max(centre_row - half, 0), min(centre_row + half, height)
    left, right = max(centre_col - half, 0), min(centre_col + half, width)
    patched = img.copy()
    patched[top:bottom, left:right, :] = fill
    return patched
|
211 |
+
|
212 |
+
|
213 |
+
### level to args
|
214 |
+
def enhance_level_to_args(MAX_LEVEL):
    """Build a converter from a level to a one-element enhancement-factor tuple.

    Level 0 maps to a factor of 0.1 and ``MAX_LEVEL`` to about 1.9.
    """
    return lambda level: ((level / MAX_LEVEL) * 1.8 + 0.1,)
|
219 |
+
|
220 |
+
|
221 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
    """Build a converter from a level to ``(shear factor, fill value)``.

    The magnitude scales linearly up to 0.3 at ``MAX_LEVEL``; the sign is
    flipped when a fresh ``np.random.random()`` draw exceeds 0.5.
    """

    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        flip = np.random.random() > 0.5
        return (-magnitude if flip else magnitude, replace_value)

    return level_to_args
|
229 |
+
|
230 |
+
|
231 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Build a converter from a level to ``(pixel offset, fill value)``.

    The magnitude scales linearly up to ``translate_const`` at ``MAX_LEVEL``;
    the sign is flipped when a fresh ``np.random.random()`` draw exceeds 0.5.
    """

    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * float(translate_const)
        flip = np.random.random() > 0.5
        return (-magnitude if flip else magnitude, replace_value)

    return level_to_args
|
239 |
+
|
240 |
+
|
241 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Build a converter from a level to ``(integer cutout size, fill value)``.

    The size scales linearly up to ``cutout_const`` at ``MAX_LEVEL`` and is
    truncated to an int.
    """
    return lambda level: (int((level / MAX_LEVEL) * cutout_const), replace_value)
|
247 |
+
|
248 |
+
|
249 |
+
def solarize_level_to_args(MAX_LEVEL):
    """Build a converter from a level to a one-element solarize-threshold tuple.

    The threshold scales linearly up to 256 at ``MAX_LEVEL`` and is truncated
    to an int.
    """
    return lambda level: (int((level / MAX_LEVEL) * 256),)
|
255 |
+
|
256 |
+
|
257 |
+
def none_level_to_args(level):
    """Return an empty tuple: these operations take no level-derived args."""
    return tuple()
|
259 |
+
|
260 |
+
|
261 |
+
def posterize_level_to_args(MAX_LEVEL):
    """Build a converter from a level to a one-element posterize-bits tuple.

    The bit count scales linearly up to 4 at ``MAX_LEVEL`` and is truncated to
    an int.
    """
    return lambda level: (int((level / MAX_LEVEL) * 4),)
|
267 |
+
|
268 |
+
|
269 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """Build a converter from a level to ``(rotation in degrees, fill value)``.

    The magnitude scales linearly up to 30 degrees at ``MAX_LEVEL``; the sign
    is flipped when a fresh ``np.random.random()`` draw is below 0.5.
    """

    def level_to_args(level):
        degrees = (level / MAX_LEVEL) * 30
        flip = np.random.random() < 0.5
        return (-degrees if flip else degrees, replace_value)

    return level_to_args
|
277 |
+
|
278 |
+
|
279 |
+
func_dict = {
|
280 |
+
"Identity": identity_func,
|
281 |
+
"AutoContrast": autocontrast_func,
|
282 |
+
"Equalize": equalize_func,
|
283 |
+
"Rotate": rotate_func,
|
284 |
+
"Solarize": solarize_func,
|
285 |
+
"Color": color_func,
|
286 |
+
"Contrast": contrast_func,
|
287 |
+
"Brightness": brightness_func,
|
288 |
+
"Sharpness": sharpness_func,
|
289 |
+
"ShearX": shear_x_func,
|
290 |
+
"TranslateX": translate_x_func,
|
291 |
+
"TranslateY": translate_y_func,
|
292 |
+
"Posterize": posterize_func,
|
293 |
+
"ShearY": shear_y_func,
|
294 |
+
}
|
295 |
+
|
296 |
+
translate_const = 10
|
297 |
+
MAX_LEVEL = 10
|
298 |
+
replace_value = (128, 128, 128)
|
299 |
+
arg_dict = {
|
300 |
+
"Identity": none_level_to_args,
|
301 |
+
"AutoContrast": none_level_to_args,
|
302 |
+
"Equalize": none_level_to_args,
|
303 |
+
"Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
|
304 |
+
"Solarize": solarize_level_to_args(MAX_LEVEL),
|
305 |
+
"Color": enhance_level_to_args(MAX_LEVEL),
|
306 |
+
"Contrast": enhance_level_to_args(MAX_LEVEL),
|
307 |
+
"Brightness": enhance_level_to_args(MAX_LEVEL),
|
308 |
+
"Sharpness": enhance_level_to_args(MAX_LEVEL),
|
309 |
+
"ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
|
310 |
+
"TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
|
311 |
+
"TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
|
312 |
+
"Posterize": posterize_level_to_args(MAX_LEVEL),
|
313 |
+
"ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
|
314 |
+
}
|
315 |
+
|
316 |
+
|
317 |
+
class RandomAugment(object):
    """Apply ``N`` randomly chosen augmentations at magnitude level ``M``.

    Args:
        N: number of operations sampled (with replacement) per call.
        M: level handed to each operation's level-to-args converter.
        isPIL: when true, the input is converted with ``np.array`` first.
        augs: names of the operations to sample from; when empty or omitted,
            every key of ``arg_dict`` is used.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        # ``augs`` defaults to None rather than a shared mutable list.
        self.N = N
        self.M = M
        self.isPIL = isPIL
        self.augs = augs if augs else list(arg_dict.keys())

    def get_random_ops(self):
        """Sample ``N`` ``(name, probability, level)`` triples; probability is 0.5."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        """Return the augmented image as a NumPy array.

        Each sampled operation is skipped whenever a uniform random draw
        exceeds its probability.
        """
        if self.isPIL:
            img = np.array(img)
        for name, prob, level in self.get_random_ops():
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
|
341 |
+
|
342 |
+
|
343 |
+
if __name__ == "__main__":
    # Smoke test. Several augmentations index 256-entry lookup tables with the
    # pixel values, so the sample image must hold uint8 values; the previous
    # float image from randn cannot be used as an index.
    a = RandomAugment()
    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    a(img)
|
Sample_Finetuning_SIIMACR/I1_classification/models/resnet.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torchvision.models as models
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
|
8 |
+
class ModelRes_ft(nn.Module):
    """ResNet backbone with a linear head for fine-tuning.

    Args:
        res_base_model: ``"resnet18"`` or ``"resnet50"``.
        out_size: number of outputs of the final linear layer.
        imagenet_pretrain: forwarded as ``pretrained`` to the torchvision
            constructor.
        linear_probe: accepted for compatibility; not used by ``__init__``.
        use_base: when true the plain backbone (all children except ``fc``) is
            the feature extractor; otherwise the backbone is split three
            children from the end and the pathology projection layers are
            inserted between the two parts (see ``image_encoder``).
    """

    def __init__(
        self,
        res_base_model,
        out_size,
        imagenet_pretrain=False,
        linear_probe=False,
        use_base=True,
    ):
        super(ModelRes_ft, self).__init__()
        # Map names to constructors so only the requested backbone is built
        # (and, with imagenet_pretrain, only its weights are fetched).
        self.resnet_dict = {
            "resnet18": models.resnet18,
            "resnet50": models.resnet50,
        }
        self.imagenet_pretrain = imagenet_pretrain
        resnet = self._get_res_basemodel(res_base_model)
        self.use_base = use_base

        if not self.use_base:
            num_ftrs = int(resnet.fc.in_features / 2)
            self.res_features = nn.Sequential(*list(resnet.children())[:-3])
            self.res_l1_anatomy = nn.Linear(num_ftrs, num_ftrs)
            self.res_l2_anatomy = nn.Linear(num_ftrs, 256)
            self.res_l1_pathology = nn.Linear(num_ftrs, num_ftrs)
            self.res_l2_pathology = nn.Linear(num_ftrs, 256)

            self.mask_generator = nn.Linear(num_ftrs, num_ftrs)
            self.back = nn.Linear(256, num_ftrs)
            self.last_res = nn.Sequential(*list(resnet.children())[-3:-1])
        else:
            self.res_features = nn.Sequential(*list(resnet.children())[:-1])
        # forward() applies the head in both modes, so it is created
        # unconditionally.
        self.res_out = nn.Linear(int(resnet.fc.in_features), out_size)

    def _get_res_basemodel(self, res_model_name):
        """Instantiate the backbone called ``res_model_name``.

        Raises:
            ValueError: if the name is not a key of ``self.resnet_dict``.
        """
        try:
            builder = self.resnet_dict[res_model_name]
        except KeyError as err:
            raise ValueError(
                "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
            ) from err
        print("Image feature extractor:", res_model_name)
        return builder(pretrained=self.imagenet_pretrain)

    def image_encoder(self, xis):
        """Encode images through the pathology branch (``use_base=False``).

        Patch features from the truncated backbone are gated by a learned
        mask, projected to 256 dimensions, projected back, reshaped into a
        feature map and passed through the remaining backbone layers.

        Shapes recorded by the original author for a batch of 16:
            backbone output [16, 1024, 14, 14] -> patches [16, 196, 1024]
            -> flattened [3136, 1024] -> embeddings [16, 196, 256]

        Returns a tensor with the batch dimension first and all remaining
        dimensions flattened.
        """
        batch_size = xis.shape[0]
        res_fea = self.res_features(xis)  # batch_size,feature_size,patch_num,patch_num
        # Read the grid size from the feature map instead of assuming 14x14.
        n1, n2 = res_fea.shape[2], res_fea.shape[3]
        res_fea = rearrange(res_fea, "b d n1 n2 -> b (n1 n2) d")
        x = rearrange(res_fea, "b n d -> (b n) d")
        mask = self.mask_generator(x)
        x_pathology = mask * x
        x_pathology = self.res_l1_pathology(x_pathology)
        x_pathology = F.relu(x_pathology)

        x_pathology = self.res_l2_pathology(x_pathology)

        out_emb_pathology = rearrange(x_pathology, "(b n) d -> b n d", b=batch_size)
        out_emb_pathology = self.back(out_emb_pathology)
        out_emb_pathology = rearrange(
            out_emb_pathology, "b (n1 n2) d -> b d n1 n2", n1=n1, n2=n2
        )
        out_emb_pathology = self.last_res(out_emb_pathology)
        # flatten(1) keeps the batch dimension even for a single sample,
        # which squeeze() would drop.
        return out_emb_pathology.flatten(1)

    def forward(self, img, linear_probe=False):
        """Return the features (``linear_probe=True``) or the head's output."""
        if self.use_base:
            x = self.res_features(img)
        else:
            x = self.image_encoder(img)

        # Keep the batch dimension for batch size 1 (squeeze() removed it).
        x = x.flatten(1)
        if linear_probe:
            return x
        return self.res_out(x)
|
Sample_Finetuning_SIIMACR/I1_classification/optim/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .adamp import AdamP
|
2 |
+
from .adamw import AdamW
|
3 |
+
from .adafactor import Adafactor
|
4 |
+
from .adahessian import Adahessian
|
5 |
+
from .lookahead import Lookahead
|
6 |
+
from .nadam import Nadam
|
7 |
+
from .novograd import NovoGrad
|
8 |
+
from .nvnovograd import NvNovoGrad
|
9 |
+
from .radam import RAdam
|
10 |
+
from .rmsprop_tf import RMSpropTF
|
11 |
+
from .sgdp import SGDP
|
12 |
+
|
13 |
+
from .optim_factory import create_optimizer
|
Sample_Finetuning_SIIMACR/I1_classification/optim/adafactor.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Adafactor Optimizer
|
2 |
+
|
3 |
+
Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
4 |
+
|
5 |
+
Original header/copyright below.
|
6 |
+
|
7 |
+
"""
|
8 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
9 |
+
#
|
10 |
+
# This source code is licensed under the MIT license found in the
|
11 |
+
# LICENSE file in the root directory of this source tree.
|
12 |
+
import torch
|
13 |
+
import math
|
14 |
+
|
15 |
+
|
16 |
+
class Adafactor(torch.optim.Optimizer):
    """Implements Adafactor algorithm.
    This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
    (see https://arxiv.org/abs/1804.04235)

    Note that this optimizer internally adjusts the learning rate depending on the
    *scale_parameter* and *warmup_init* options, and on whether *lr* is given.

    To use a manual (external) learning rate schedule pass an explicit `lr`;
    the relative-step schedule is only used when `lr` is None.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): external learning rate; None enables the relative-step schedule (default: None)
        eps (float): regularization constant added to the squared gradient (default: 1e-30)
        eps_scale (float): lower bound for the parameter RMS used as learning-rate scale (default: 1e-3)
        clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
        decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
        betas (tuple, optional): only ``betas[0]`` is used, as the coefficient of the running
            average of the update; None disables the first moment (default: None)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
        warmup_init (bool): time-dependent learning rate computation depends on
            whether warm-up initialization is being used; requires `lr=None` (default: False)
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=1e-30,
        eps_scale=1e-3,
        clip_threshold=1.0,
        decay_rate=-0.8,
        betas=None,
        weight_decay=0.0,
        scale_parameter=True,
        warmup_init=False,
    ):
        relative_step = lr is None
        if warmup_init and not relative_step:
            raise ValueError("warmup_init requires relative_step=True")

        # make it compat with standard betas arg
        beta1 = None if betas is None else betas[0]
        defaults = dict(
            lr=lr,
            eps=eps,
            eps_scale=eps_scale,
            clip_threshold=clip_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step=relative_step,
            warmup_init=warmup_init,
        )
        super(Adafactor, self).__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        """Return the step size for this parameter.

        With ``relative_step`` the value is recomputed from the step count
        (and optionally the parameter RMS) and written back to
        ``param_group["lr"]``; otherwise the stored ``lr`` is returned.
        """
        if param_group["relative_step"]:
            min_step = (
                1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            )
            lr_t = min(min_step, 1.0 / math.sqrt(param_state["step"]))
            param_scale = 1.0
            if param_group["scale_parameter"]:
                param_scale = max(param_group["eps_scale"], param_state["RMS"])
            param_group["lr"] = lr_t * param_scale
        return param_group["lr"]

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return ``(factored, use_first_moment)`` for a parameter shape."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Root mean square of all elements of ``tensor``."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Combine the row and column statistics into an inverse-root scaling factor."""
        r_factor = (
            (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
            .rsqrt_()
            .unsqueeze(-1)
        )
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # Row statistics drop the last dim, column statistics
                        # drop the second-to-last dim.
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:]
                        ).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Keep the state on the same device/dtype as the gradient.
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                # Half-precision parameters are updated through a float32 copy
                # that is written back at the end.
                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr_t = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = grad ** 2 + group["eps"]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    # Keyword ``alpha`` form; the positional-scalar overload is deprecated.
                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=1.0 - beta2t
                    )
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=1.0 - beta2t
                    )

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Scale the update down when its RMS exceeds clip_threshold.
                update.div_(
                    (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
                )
                update.mul_(lr_t)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * lr_t)

                p_data_fp32.add_(-update)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
|
Sample_Finetuning_SIIMACR/I1_classification/optim/adahessian.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" AdaHessian Optimizer
|
2 |
+
|
3 |
+
Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
|
4 |
+
Originally licensed MIT, Copyright 2020, David Samuel
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class Adahessian(torch.optim.Optimizer):
    """
    Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"

    The gradients passed to :meth:`step` must have been produced with
    ``loss.backward(create_graph=True)`` so they can be differentiated again.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 0.1)
        betas ((float, float), optional): coefficients used for computing running averages of gradient and the
            squared hessian trace (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
        hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
        update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
            (to save time) (default: 1)
        n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
        avg_conv_kernel (bool, optional): if True, the hessian estimate of 4-D parameters is replaced by the
            mean of its absolute value over dims 2 and 3 (default: False)
    """

    def __init__(
        self,
        params,
        lr=0.1,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        hessian_power=1.0,
        update_each=1,
        n_samples=1,
        avg_conv_kernel=False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(f"Invalid Hessian power value: {hessian_power}")

        self.n_samples = n_samples
        self.update_each = update_each
        self.avg_conv_kernel = avg_conv_kernel

        # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
        self.seed = 2147483647
        self.generator = torch.Generator().manual_seed(self.seed)

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            hessian_power=hessian_power,
        )
        super(Adahessian, self).__init__(params, defaults)

        # The hessian estimate lives on the parameter itself; a float 0.0
        # marks "not computed yet" (see zero_hessian).
        for p in self.get_params():
            p.hess = 0.0
            self.state[p]["hessian step"] = 0

    @property
    def is_second_order(self):
        return True

    def get_params(self):
        """
        Gets all parameters in all param_groups that require gradients
        """

        return (
            p for group in self.param_groups for p in group["params"] if p.requires_grad
        )

    def zero_hessian(self):
        """
        Zeros out the accumulated hessian traces.
        """

        for p in self.get_params():
            # Only reset on steps where set_hessian will recompute the trace.
            if (
                not isinstance(p.hess, float)
                and self.state[p]["hessian step"] % self.update_each == 0
            ):
                p.hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """
        Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
        """

        params = []
        for p in filter(lambda p: p.grad is not None, self.get_params()):
            if (
                self.state[p]["hessian step"] % self.update_each == 0
            ):  # compute the trace only each `update_each` step
                params.append(p)
            self.state[p]["hessian step"] += 1

        if len(params) == 0:
            return

        if (
            self.generator.device != params[0].device
        ):  # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(self.seed)

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            # Rademacher distribution {-1.0, 1.0}
            zs = [
                torch.randint(0, 2, p.size(), generator=self.generator, device=p.device)
                * 2.0
                - 1.0
                for p in params
            ]
            # The graph is only kept alive while further samples still need it.
            h_zs = torch.autograd.grad(
                grads,
                params,
                grad_outputs=zs,
                only_inputs=True,
                retain_graph=i < self.n_samples - 1,
            )
            for h_z, z, p in zip(h_zs, zs, params):
                p.hess += (
                    h_z * z / self.n_samples
                )  # approximate the expected values of z*(H@z)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        Arguments:
            closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
        """

        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure has to build a graph and
            # call backward(), so gradient mode is re-enabled around it.
            with torch.enable_grad():
                loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group["params"]:
                # `hess` is only attached to parameters that required grad at
                # construction time, hence the getattr.
                if p.grad is None or getattr(p, "hess", None) is None:
                    continue

                if self.avg_conv_kernel and p.dim() == 4:
                    p.hess = (
                        torch.abs(p.hess)
                        .mean(dim=[2, 3], keepdim=True)
                        .expand_as(p.hess)
                        .clone()
                    )

                # Perform correct stepweight decay as in AdamW
                p.mul_(1 - group["lr"] * group["weight_decay"])

                state = self.state[p]

                # State initialization ("hessian step" may already be present,
                # so test for the key instead of the number of entries).
                if "step" not in state:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)
                    # Exponential moving average of Hessian diagonal square values
                    state["exp_hessian_diag_sq"] = torch.zeros_like(p)

                exp_avg = state["exp_avg"]
                exp_hessian_diag_sq = state["exp_hessian_diag_sq"]
                beta1, beta2 = group["betas"]
                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(
                    p.hess, p.hess, value=1 - beta2
                )

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                k = group["hessian_power"]
                denom = (
                    (exp_hessian_diag_sq / bias_correction2)
                    .pow_(k / 2)
                    .add_(group["eps"])
                )

                # make update
                step_size = group["lr"] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
|
Sample_Finetuning_SIIMACR/I1_classification/optim/adamp.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
|
3 |
+
|
4 |
+
Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
|
5 |
+
Code: https://github.com/clovaai/AdamP
|
6 |
+
|
7 |
+
Copyright (c) 2020-present NAVER Corp.
|
8 |
+
MIT license
|
9 |
+
"""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.optim.optimizer import Optimizer, required
|
14 |
+
import math
|
15 |
+
|
16 |
+
|
17 |
+
class AdamP(Optimizer):
    """AdamP optimizer.

    An Adam-style update whose step is, for parameters with more than one
    dimension, projected onto the space orthogonal to the weights when the
    gradient is nearly orthogonal to them (see ``_projection``).

    Arguments:
        params (iterable): parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients of the running averages
            of the gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to denominators for numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 0)
        delta (float, optional): cosine-similarity threshold that triggers the projection (default: 0.1)
        wd_ratio (float, optional): factor applied to the weight decay when the
            projection was triggered (default: 0.1)
        nesterov (bool, optional): use the Nesterov form of the update (default: False)
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        delta=0.1,
        wd_ratio=0.1,
        nesterov=False,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            delta=delta,
            wd_ratio=wd_ratio,
            nesterov=nesterov,
        )
        super(AdamP, self).__init__(params, defaults)

    def _channel_view(self, x):
        """View ``x`` as ``(x.size(0), -1)``: one row per leading index."""
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        """View ``x`` as a single row."""
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        """Absolute row-wise cosine similarity of ``x`` and ``y`` under ``view_func``."""
        x = view_func(x)
        y = view_func(y)

        x_norm = x.norm(dim=1).add_(eps)
        y_norm = y.norm(dim=1).add_(eps)
        dot = (x * y).sum(dim=1)

        return dot.abs() / x_norm / y_norm

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        """Remove from ``perturb`` its component along the weights when applicable.

        The channel view is tried first, then the layer view; the first view
        whose maximum cosine similarity falls below the threshold is used.

        Returns:
            ``(perturb, wd)`` where ``wd`` is ``wd_ratio`` if a projection was
            applied and 1 otherwise. ``perturb`` is modified in place.
        """
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                beta1, beta2 = group["betas"]
                nesterov = group["nesterov"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                # Adam
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Keyword alpha/value forms; the positional-scalar overloads
                # are deprecated.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                    group["eps"]
                )
                step_size = group["lr"] / bias_correction1

                if nesterov:
                    perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
                else:
                    perturb = exp_avg / denom

                # Projection
                wd_ratio = 1
                if len(p.shape) > 1:
                    perturb, wd_ratio = self._projection(
                        p,
                        grad,
                        perturb,
                        group["delta"],
                        group["wd_ratio"],
                        group["eps"],
                    )

                # Weight decay
                if group["weight_decay"] > 0:
                    p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio)

                # Step
                p.data.add_(perturb, alpha=-step_size)

        return loss
|
Sample_Finetuning_SIIMACR/I1_classification/optim/adamw.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" AdamW Optimizer
|
2 |
+
Impl copied from PyTorch master
|
3 |
+
"""
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
from torch.optim.optimizer import Optimizer
|
7 |
+
|
8 |
+
|
9 |
+
class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
        )
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        # Restored groups that lack the key get amsgrad=False.
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform stepweight decay (decoupled: applied to the weights
                # directly, before the gradient step)
                p.data.mul_(1 - group["lr"] * group["weight_decay"])

                # Perform optimization step
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsgrad = group["amsgrad"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                # (keyword alpha/value forms; the positional-scalar overloads
                # are deprecated)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                        group["eps"]
                    )
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                        group["eps"]
                    )

                step_size = group["lr"] / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
|
Sample_Finetuning_SIIMACR/I1_classification/optim/lookahead.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Lookahead Optimizer Wrapper.
|
2 |
+
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
|
3 |
+
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
|
4 |
+
|
5 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
6 |
+
"""
|
7 |
+
import torch
|
8 |
+
from torch.optim.optimizer import Optimizer
|
9 |
+
from collections import defaultdict
|
10 |
+
|
11 |
+
|
12 |
+
class Lookahead(Optimizer):
    """Wrap a base optimizer with Lookahead slow/fast weight averaging.

    Every ``k`` calls to :meth:`step`, each parameter's slow copy is moved a
    fraction ``alpha`` towards the current (fast) value and the fast value is
    then overwritten with that slow copy.

    Arguments:
        base_optimizer (Optimizer): optimizer that performs the fast updates;
            its ``param_groups`` list and ``defaults`` dict are shared with
            (and extended by) this wrapper
        alpha (float, optional): slow update rate, in [0, 1] (default: 0.5)
        k (int, optional): number of fast steps between slow updates
            (default: 6)

    Raises:
        ValueError: if ``alpha`` is outside [0, 1] or ``k`` is less than 1.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")
        if not 1 <= k:
            raise ValueError(f"Invalid lookahead steps: {k}")
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        # NOTE: Optimizer.__init__ is not called; the attributes used by this
        # class are populated by hand so that param_groups is the very same
        # list object as the base optimizer's.
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        """Pull the slow weights of ``group`` towards the fast ones and sync.

        Parameters without a gradient are skipped. The slow buffer is created
        lazily as a copy of the fast weights, so the first call for a
        parameter leaves its value unchanged.
        """
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            # slow += alpha * (fast - slow); keyword form replaces the
            # deprecated add_(scalar, tensor) overload.
            slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Run a slow-weight update on every param group immediately."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Step the base optimizer, then apply the slow update every k-th call.

        Arguments:
            closure (callable, optional): forwarded to the base optimizer.

        Returns:
            Whatever the base optimizer's ``step`` returns.
        """
        # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group["lookahead_step"] += 1
            if group["lookahead_step"] % group["lookahead_k"] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer's state dict plus a ``slow_state`` entry.

        ``slow_state`` maps ``id(param)`` (for tensor keys) to this wrapper's
        per-parameter state.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore the base optimizer's state and, when present, the slow state.

        If ``state_dict`` has no ``"slow_state"`` key (it was saved without
        Lookahead), an empty one is inserted into the passed dict and the
        lookahead defaults are re-applied to the param groups.
        """
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = False
        if "slow_state" not in state_dict:
            print("Loading state_dict from optimizer without Lookahead applied.")
            state_dict["slow_state"] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            "state": state_dict["slow_state"],
            # this is pointless but saves code
            "param_groups": state_dict["param_groups"],
        }
        super().load_state_dict(slow_state_dict)
        # make both ref same container
        self.param_groups = self.base_optimizer.param_groups
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)
|
Sample_Finetuning_SIIMACR/I1_classification/optim/nadam.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.optim import Optimizer
|
3 |
+
|
4 |
+
|
5 |
+
class Nadam(Optimizer):
    """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).

    It has been proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf

        Originally taken from: https://github.com/pytorch/pytorch/pull/1408
        NOTE: Has potential issues but does work well on some problems.
    """

    def __init__(
        self,
        params,
        lr=2e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        schedule_decay=4e-3,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            schedule_decay=schedule_decay,
        )
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["m_schedule"] = 1.0
                    state["exp_avg"] = torch.zeros_like(grad)
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                # Warming momentum schedule
                m_schedule = state["m_schedule"]
                schedule_decay = group["schedule_decay"]
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                eps = group["eps"]
                state["step"] += 1
                t = state["step"]

                if group["weight_decay"] != 0:
                    # Out-of-place so p.grad itself is left untouched.
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                momentum_cache_t = beta1 * (1.0 - 0.5 * (0.96 ** (t * schedule_decay)))
                momentum_cache_t_1 = beta1 * (
                    1.0 - 0.5 * (0.96 ** ((t + 1) * schedule_decay))
                )
                m_schedule_new = m_schedule * momentum_cache_t
                m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
                state["m_schedule"] = m_schedule_new

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                # Division allocates a new tensor, so the in-place sqrt_/add_
                # below do not modify the stored second moment.
                exp_avg_sq_prime = exp_avg_sq / (1.0 - beta2 ** t)
                denom = exp_avg_sq_prime.sqrt_().add_(eps)

                p.data.addcdiv_(
                    grad,
                    denom,
                    value=-group["lr"] * (1.0 - momentum_cache_t) / (1.0 - m_schedule_new),
                )
                p.data.addcdiv_(
                    exp_avg,
                    denom,
                    value=-group["lr"] * momentum_cache_t_1 / (1.0 - m_schedule_next),
                )

        return loss
|
Sample_Finetuning_SIIMACR/I1_classification/optim/novograd.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""NovoGrad Optimizer.
|
2 |
+
Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
|
3 |
+
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
|
4 |
+
- https://arxiv.org/abs/1905.11286
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.optim.optimizer import Optimizer
|
9 |
+
import math
|
10 |
+
|
11 |
+
|
12 |
+
class NovoGrad(Optimizer):
    """NovoGrad optimizer with a per-tensor (layer-wise) second moment.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        grad_averaging (bool, optional): scale the normalized gradient by
            ``1 - beta1`` before accumulating it (default: False)
        lr (float, optional): learning rate (default: 0.1)
        betas (Tuple[float, float], optional): decay coefficients for the
            first and second moments (default: (0.95, 0.98))
        eps (float, optional): term added to denominators (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 0)

    NOTE: ``betas``, ``eps``, ``weight_decay`` and ``grad_averaging`` are read
    from the constructor arguments for every parameter; only ``lr`` is taken
    from the individual param group.
    """

    def __init__(
        self,
        params,
        grad_averaging=False,
        lr=0.1,
        betas=(0.95, 0.98),
        eps=1e-8,
        weight_decay=0,
    ):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self._lr = lr
        self._beta1 = betas[0]
        self._beta2 = betas[1]
        self._eps = eps
        self._wd = weight_decay
        self._grad_averaging = grad_averaging

        # Kept for compatibility; state is now initialised per parameter.
        self._momentum_initialized = False

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("NovoGrad does not support sparse gradients")
                state = self.state[p]

                # Lazy per-parameter initialisation. A single global "first
                # step" pass would miss parameters whose gradient only shows
                # up on a later step and would clobber state restored through
                # load_state_dict.
                if len(state) == 0:
                    v = torch.norm(grad) ** 2
                    state["step"] = 0
                    state["v"] = v
                    state["m"] = grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
                    state["grad_ema"] = None

                state["step"] += 1
                step, v, m = state["step"], state["v"], state["m"]
                grad_ema = state["grad_ema"]

                g2 = torch.norm(grad) ** 2
                grad_ema = (
                    g2
                    if grad_ema is None
                    else grad_ema * self._beta2 + g2 * (1.0 - self._beta2)
                )
                # In-place: p.grad itself is rescaled here.
                grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)

                if self._grad_averaging:
                    grad *= 1.0 - self._beta1

                g2 = torch.norm(grad) ** 2
                v = self._beta2 * v + (1.0 - self._beta2) * g2
                m = self._beta1 * m + (
                    grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
                )
                bias_correction1 = 1 - self._beta1 ** step
                bias_correction2 = 1 - self._beta2 ** step
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

                state["v"], state["m"] = v, m
                state["grad_ema"] = grad_ema
                p.data.add_(m, alpha=-step_size)

        self._momentum_initialized = True
        return loss
|
Sample_Finetuning_SIIMACR/I1_classification/optim/nvnovograd.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Nvidia NovoGrad Optimizer.
|
2 |
+
Original impl by Nvidia from Jasper example:
|
3 |
+
- https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
|
4 |
+
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
|
5 |
+
- https://arxiv.org/abs/1905.11286
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.optim.optimizer import Optimizer
|
10 |
+
import math
|
11 |
+
|
12 |
+
|
13 |
+
class NvNovoGrad(Optimizer):
    """
    Implements Novograd algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging: gradient averaging
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or either beta is
            outside [0, 1).
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.95, 0.98),
        eps=1e-8,
        weight_decay=0,
        grad_averaging=False,
        amsgrad=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            grad_averaging=grad_averaging,
            amsgrad=amsgrad,
        )

        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Param groups restored without an "amsgrad" entry default to False.
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Sparse gradients are not supported.")
                amsgrad = group["amsgrad"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    # (a single scalar per parameter tensor)
                    state["exp_avg_sq"] = torch.zeros(
                        [], device=state["exp_avg"].device
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros(
                            [], device=state["exp_avg"].device
                        )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                norm = torch.sum(torch.pow(grad, 2))

                # While the running value is still zero, seed it with the
                # current squared norm instead of averaging against zero.
                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                # In-place: p.grad itself is normalised (and decayed) here.
                grad.div_(denom)
                if group["weight_decay"] != 0:
                    grad.add_(p.data, alpha=group["weight_decay"])
                if group["grad_averaging"]:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.data.add_(exp_avg, alpha=-group["lr"])

        return loss
|
Sample_Finetuning_SIIMACR/I1_classification/optim/optim_factory.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Optimizer Factory w/ Custom Weight Decay
|
2 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
from torch import optim as optim
|
6 |
+
|
7 |
+
from .adafactor import Adafactor
|
8 |
+
from .adahessian import Adahessian
|
9 |
+
from .adamp import AdamP
|
10 |
+
from .lookahead import Lookahead
|
11 |
+
from .nadam import Nadam
|
12 |
+
from .novograd import NovoGrad
|
13 |
+
from .nvnovograd import NvNovoGrad
|
14 |
+
from .radam import RAdam
|
15 |
+
from .rmsprop_tf import RMSpropTF
|
16 |
+
from .sgdp import SGDP
|
17 |
+
|
18 |
+
try:
|
19 |
+
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
20 |
+
|
21 |
+
has_apex = True
|
22 |
+
except ImportError:
|
23 |
+
has_apex = False
|
24 |
+
|
25 |
+
|
26 |
+
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split a model's trainable parameters into decay / no-decay groups.

    A parameter goes into the no-decay group when it is one-dimensional, its
    name ends with ``".bias"``, or its name is in ``skip_list``. Parameters
    with ``requires_grad == False`` are left out entirely.

    Arguments:
        model: object exposing ``named_parameters()``
        weight_decay (float, optional): decay applied to the decay group
            (default: 1e-5)
        skip_list (container of str, optional): parameter names that must not
            be decayed (default: ())

    Returns:
        A two-element list of param-group dicts: first the no-decay group
        (``weight_decay`` 0.0), then the decay group.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]
|
40 |
+
|
41 |
+
|
42 |
+
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build an optimizer for ``model`` from the settings on ``args``.

    ``args.opt`` is lower-cased and split on ``"_"``: the last piece selects
    the optimizer and, when there is more than one piece and the first is
    ``"lookahead"``, the result is wrapped in ``Lookahead``.

    Arguments:
        args: object providing ``opt``, ``lr`` and ``weight_decay``;
            ``momentum`` is read for the SGD/RMSprop families; ``opt_eps``,
            ``opt_betas`` and ``opt_args`` are used when present and not None
        model: module whose parameters are optimized; an optional
            ``no_weight_decay()`` method supplies names excluded from decay
        filter_bias_and_bn (bool, optional): when true and weight decay is
            non-zero, parameters are split with ``add_weight_decay`` and the
            optimizer-level weight decay is set to 0 (default: True)

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if the optimizer name is not recognised.
        AssertionError: if a fused optimizer is requested without APEX/CUDA.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = set()
        if hasattr(model, "no_weight_decay"):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        # Decay is now carried by the param groups themselves.
        weight_decay = 0.0
    else:
        parameters = filter(lambda p: p.requires_grad, model.parameters())

    if "fused" in opt_lower:
        assert (
            has_apex and torch.cuda.is_available()
        ), "APEX and CUDA required for fused optimizers"

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, "opt_eps") and args.opt_eps is not None:
        opt_args["eps"] = args.opt_eps
    if hasattr(args, "opt_betas") and args.opt_betas is not None:
        opt_args["betas"] = args.opt_betas
    if hasattr(args, "opt_args") and args.opt_args is not None:
        opt_args.update(args.opt_args)

    opt_split = opt_lower.split("_")
    opt_lower = opt_split[-1]
    if opt_lower == "sgd" or opt_lower == "nesterov":
        opt_args.pop("eps", None)
        optimizer = optim.SGD(
            parameters, momentum=args.momentum, nesterov=True, **opt_args
        )
    elif opt_lower == "momentum":
        opt_args.pop("eps", None)
        optimizer = optim.SGD(
            parameters, momentum=args.momentum, nesterov=False, **opt_args
        )
    elif opt_lower == "adam":
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == "adamw":
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == "nadam":
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == "radam":
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == "adamp":
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == "sgdp":
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == "adadelta":
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == "adafactor":
        if not args.lr:
            opt_args["lr"] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == "adahessian":
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == "rmsprop":
        optimizer = optim.RMSprop(
            parameters, alpha=0.9, momentum=args.momentum, **opt_args
        )
    elif opt_lower == "rmsproptf":
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == "novograd":
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == "nvnovograd":
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == "fusedsgd":
        opt_args.pop("eps", None)
        optimizer = FusedSGD(
            parameters, momentum=args.momentum, nesterov=True, **opt_args
        )
    elif opt_lower == "fusedmomentum":
        opt_args.pop("eps", None)
        optimizer = FusedSGD(
            parameters, momentum=args.momentum, nesterov=False, **opt_args
        )
    elif opt_lower == "fusedadam":
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == "fusedadamw":
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == "fusedlamb":
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == "fusednovograd":
        opt_args.setdefault("betas", (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # The former `assert False and "Invalid optimizer"` carried no message
        # and is stripped under -O; raise explicitly instead.
        raise ValueError(f"Invalid optimizer: {args.opt!r}")

    if len(opt_split) > 1:
        if opt_split[0] == "lookahead":
            optimizer = Lookahead(optimizer)

    return optimizer
|
Sample_Finetuning_SIIMACR/I1_classification/optim/radam.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""RAdam Optimizer.
|
2 |
+
Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
|
3 |
+
Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
|
4 |
+
"""
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from torch.optim.optimizer import Optimizer, required
|
8 |
+
|
9 |
+
|
10 |
+
class RAdam(Optimizer):
    """Rectified Adam optimizer.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient, applied
            directly to the weights scaled by ``lr`` (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Retained for compatibility with existing pickles / external readers;
        # step() no longer consults it (see the per-group cache there).
        self.buffer = [[None, None, None] for _ in range(10)]
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @staticmethod
    def _rectified_step(step, lr, beta1, beta2):
        """Return ``(N_sma, step_size)`` for the given step and hyper-parameters."""
        beta2_t = beta2 ** step
        N_sma_max = 2 / (1 - beta2) - 1
        N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)

        # more conservative since it's an approximated value
        if N_sma >= 5:
            step_size = (
                lr
                * math.sqrt(
                    (1 - beta2_t)
                    * (N_sma - 4)
                    / (N_sma_max - 4)
                    * (N_sma - 2)
                    / N_sma
                    * N_sma_max
                    / (N_sma_max - 2)
                )
                / (1 - beta1 ** step)
            )
        else:
            step_size = lr / (1 - beta1 ** step)
        return N_sma, step_size

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            # step -> (N_sma, step_size). The cache lives per group because
            # step_size depends on this group's lr and betas; a cache shared
            # across groups would hand later groups the first group's values.
            schedule = {}

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                # Work on an fp32 copy; it is copied back into p at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                step = state["step"]
                if step not in schedule:
                    schedule[step] = self._rectified_step(
                        step, group["lr"], beta1, beta2
                    )
                N_sma, step_size = schedule[step]

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
|
95 |
+
|
96 |
+
|
97 |
+
class PlainRAdam(Optimizer):
    """Rectified Adam that recomputes the rectification term for every parameter.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient, applied
            directly to the weights scaled by ``lr`` (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                # Work on an fp32 copy; it is copied back into p at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                beta2_t = beta2 ** state["step"]
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    step_size = (
                        group["lr"]
                        * math.sqrt(
                            (1 - beta2_t)
                            * (N_sma - 4)
                            / (N_sma_max - 4)
                            * (N_sma - 2)
                            / N_sma
                            * N_sma_max
                            / (N_sma_max - 2)
                        )
                        / (1 - beta1 ** state["step"])
                    )
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    step_size = group["lr"] / (1 - beta1 ** state["step"])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
|
Sample_Finetuning_SIIMACR/I1_classification/optim/rmsprop_tf.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" RMSProp modified to behave like Tensorflow impl
|
2 |
+
|
3 |
+
Originally cut & paste from PyTorch RMSProp
|
4 |
+
https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
|
5 |
+
Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
|
6 |
+
|
7 |
+
Modifications Copyright 2020 Ross Wightman
|
8 |
+
"""
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.optim import Optimizer
|
12 |
+
|
13 |
+
|
14 |
+
class RMSpropTF(Optimizer):
|
15 |
+
"""Implements RMSprop algorithm (TensorFlow style epsilon)
|
16 |
+
|
17 |
+
NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
|
18 |
+
and a few other modifications to closer match Tensorflow for matching hyper-params.
|
19 |
+
|
20 |
+
Noteworthy changes include:
|
21 |
+
1. Epsilon applied inside square-root
|
22 |
+
2. square_avg initialized to ones
|
23 |
+
3. LR scaling of update accumulated in momentum buffer
|
24 |
+
|
25 |
+
Proposed by G. Hinton in his
|
26 |
+
`course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
|
27 |
+
|
28 |
+
The centered version first appears in `Generating Sequences
|
29 |
+
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
|
30 |
+
|
31 |
+
Arguments:
|
32 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
33 |
+
parameter groups
|
34 |
+
lr (float, optional): learning rate (default: 1e-2)
|
35 |
+
momentum (float, optional): momentum factor (default: 0)
|
36 |
+
alpha (float, optional): smoothing (decay) constant (default: 0.9)
|
37 |
+
eps (float, optional): term added to the denominator to improve
|
38 |
+
numerical stability (default: 1e-10)
|
39 |
+
centered (bool, optional) : if ``True``, compute the centered RMSProp,
|
40 |
+
the gradient is normalized by an estimation of its variance
|
41 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
42 |
+
decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
|
43 |
+
lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
|
44 |
+
update as per defaults in Tensorflow
|
45 |
+
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
params,
|
51 |
+
lr=1e-2,
|
52 |
+
alpha=0.9,
|
53 |
+
eps=1e-10,
|
54 |
+
weight_decay=0,
|
55 |
+
momentum=0.0,
|
56 |
+
centered=False,
|
57 |
+
decoupled_decay=False,
|
58 |
+
lr_in_momentum=True,
|
59 |
+
):
|
60 |
+
if not 0.0 <= lr:
|
61 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
62 |
+
if not 0.0 <= eps:
|
63 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
64 |
+
if not 0.0 <= momentum:
|
65 |
+
raise ValueError("Invalid momentum value: {}".format(momentum))
|
66 |
+
if not 0.0 <= weight_decay:
|
67 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
68 |
+
if not 0.0 <= alpha:
|
69 |
+
raise ValueError("Invalid alpha value: {}".format(alpha))
|
70 |
+
|
71 |
+
defaults = dict(
|
72 |
+
lr=lr,
|
73 |
+
momentum=momentum,
|
74 |
+
alpha=alpha,
|
75 |
+
eps=eps,
|
76 |
+
centered=centered,
|
77 |
+
weight_decay=weight_decay,
|
78 |
+
decoupled_decay=decoupled_decay,
|
79 |
+
lr_in_momentum=lr_in_momentum,
|
80 |
+
)
|
81 |
+
super(RMSpropTF, self).__init__(params, defaults)
|
82 |
+
|
83 |
+
def __setstate__(self, state):
    """Restore optimizer state and back-fill missing per-group options.

    After delegating to the parent class, every param group that lacks a
    ``momentum`` or ``centered`` entry receives ``0`` / ``False``.
    """
    super(RMSpropTF, self).__setstate__(state)
    fallbacks = (("momentum", 0), ("centered", False))
    for param_group in self.param_groups:
        for option, value in fallbacks:
            param_group.setdefault(option, value)
|
88 |
+
|
89 |
+
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by ``closure``, or ``None`` if no closure was given.

    Raises:
        RuntimeError: if a parameter's gradient is sparse.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        # Per-group hyper-parameters are constant across the parameter loop.
        lr = group["lr"]
        momentum = group["momentum"]
        eps = group["eps"]
        centered = group["centered"]
        weight_decay = group["weight_decay"]
        one_minus_alpha = 1.0 - group["alpha"]
        # These two keys may be absent from a param group; treat that as off.
        decoupled_decay = group.get("decoupled_decay", False)
        lr_in_momentum = group.get("lr_in_momentum", False)

        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError("RMSprop does not support sparse gradients")
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                # Initialised to ones here; PyTorch's RMSprop inits to zero.
                state["square_avg"] = torch.ones_like(p.data)
                if momentum > 0:
                    state["momentum_buffer"] = torch.zeros_like(p.data)
                if centered:
                    state["grad_avg"] = torch.zeros_like(p.data)

            square_avg = state["square_avg"]
            state["step"] += 1

            if weight_decay != 0:
                if decoupled_decay:
                    # NOTE(review): the decay is applied directly to the
                    # weights and is not scaled by the learning rate.
                    p.data.add_(p.data, alpha=-weight_decay)
                else:
                    # Classic L2 penalty folded into the gradient (out of
                    # place, so p.grad itself is left untouched).
                    grad = grad.add(p.data, alpha=weight_decay)

            # Tensorflow order of ops for updating squared avg:
            # square_avg += (1 - alpha) * (grad^2 - square_avg)
            square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)

            if centered:
                grad_avg = state["grad_avg"]
                grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                # eps is added inside the sqrt (Tensorflow behaviour).
                avg = (
                    square_avg.addcmul(grad_avg, grad_avg, value=-1)
                    .add(eps)
                    .sqrt_()
                )
            else:
                # eps is added inside the sqrt (Tensorflow behaviour).
                avg = square_avg.add(eps).sqrt_()

            if momentum > 0:
                buf = state["momentum_buffer"]
                if lr_in_momentum:
                    # Tensorflow accumulates the LR scaling in the momentum
                    # buffer.
                    buf.mul_(momentum).addcdiv_(grad, avg, value=lr)
                    p.data.sub_(buf)
                else:
                    # PyTorch scales the param update by LR.
                    buf.mul_(momentum).addcdiv_(grad, avg)
                    p.data.add_(buf, alpha=-lr)
            else:
                p.data.addcdiv_(grad, avg, value=-lr)

    return loss
|