Johnyquest7 commited on
Commit
b66d95f
·
verified ·
1 Parent(s): 4be7566

Upload train_thyroid.py

Browse files
Files changed (1) hide show
  1. train_thyroid.py +204 -0
train_thyroid.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Thyroid Ultrasound Nodule Malignancy Classification
4
+ Dataset: BTX24/thyroid-cancer-classification-ultrasound-dataset
5
+ Binary classification: benign (0) vs malignant (1)
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import numpy as np
11
+ from collections import Counter
12
+
13
+ from datasets import load_dataset
14
+ from transformers import (
15
+ AutoImageProcessor,
16
+ AutoModelForImageClassification,
17
+ TrainingArguments,
18
+ Trainer,
19
+ DefaultDataCollator,
20
+ EarlyStoppingCallback,
21
+ )
22
+ import evaluate
23
+ import torch
24
+ from torchvision.transforms import (
25
+ Compose, Resize, RandomRotation, RandomHorizontalFlip,
26
+ RandomVerticalFlip, ColorJitter, ToTensor, Normalize
27
+ )
28
+
29
+ # ------------------------------------------------------------------
30
+ # Config
31
+ # ------------------------------------------------------------------
32
+ DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
33
+ MODEL_NAME = "microsoft/swinv2-base-patch4-window8-256"
34
+ OUTPUT_DIR = "./thyroid-swinv2-model"
35
+ HUB_MODEL_ID = "Johnyquest7/ML-Inter_thyroid"
36
+
37
+ NUM_LABELS = 2
38
+ ID2LABEL = {0: "benign", 1: "malignant"}
39
+ LABEL2ID = {"benign": 0, "malignant": 1}
40
+
41
+ # ------------------------------------------------------------------
42
+ # Metrics
43
+ # ------------------------------------------------------------------
44
+ accuracy = evaluate.load("accuracy")
45
+ f1 = evaluate.load("f1")
46
+ precision = evaluate.load("precision")
47
+ recall = evaluate.load("recall")
48
+ roc_auc = evaluate.load("roc_auc")
49
+
50
+ def compute_metrics(eval_pred):
51
+ logits, labels = eval_pred
52
+ preds = np.argmax(logits, axis=1)
53
+ probs = torch.softmax(torch.tensor(logits), dim=1)[:, 1].numpy()
54
+
55
+ result = {}
56
+ result.update(accuracy.compute(predictions=preds, references=labels))
57
+ result.update(f1.compute(predictions=preds, references=labels, average="binary"))
58
+ result.update(precision.compute(predictions=preds, references=labels, average="binary"))
59
+ result.update(recall.compute(predictions=preds, references=labels, average="binary"))
60
+ try:
61
+ result.update(roc_auc.compute(prediction_scores=probs, references=labels))
62
+ except Exception:
63
+ result["roc_auc"] = 0.0
64
+ return result
65
+
66
+ # ------------------------------------------------------------------
67
+ # Load dataset
68
+ # ------------------------------------------------------------------
69
+ print("Loading dataset...")
70
+ train_ds = load_dataset(DATASET_NAME, split="train")
71
+ test_ds = load_dataset(DATASET_NAME, split="test")
72
+
73
+ # Create validation split from train
74
+ train_val = train_ds.train_test_split(test_size=0.15, stratify_by_column="label", seed=42)
75
+ train_ds = train_val["train"]
76
+ val_ds = train_val["test"]
77
+
78
+ print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")
79
+ print(f"Train labels: {Counter(train_ds['label'])}")
80
+ print(f"Val labels: {Counter(val_ds['label'])}")
81
+ print(f"Test labels: {Counter(test_ds['label'])}")
82
+
83
+ # ------------------------------------------------------------------
84
+ # Image processor & transforms
85
+ # ------------------------------------------------------------------
86
+ print(f"Loading image processor from {MODEL_NAME}...")
87
+ image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
88
+
89
+ # Ultrasound images are grayscale (mode 'L') — convert to RGB for SwinV2
90
+ image_mean = image_processor.image_mean
91
+ image_std = image_processor.image_std
92
+ size = (
93
+ image_processor.size["shortest_edge"]
94
+ if "shortest_edge" in image_processor.size
95
+ else (image_processor.size["height"], image_processor.size["width"])
96
+ )
97
+
98
+ train_transforms = Compose([
99
+ Resize(size),
100
+ RandomRotation(degrees=10),
101
+ RandomHorizontalFlip(p=0.5),
102
+ RandomVerticalFlip(p=0.3),
103
+ ColorJitter(brightness=0.2, contrast=0.2),
104
+ ToTensor(),
105
+ Normalize(mean=image_mean, std=image_std),
106
+ ])
107
+
108
+ val_transforms = Compose([
109
+ Resize(size),
110
+ ToTensor(),
111
+ Normalize(mean=image_mean, std=image_std),
112
+ ])
113
+
114
+ def preprocess_train(examples):
115
+ # Convert grayscale to RGB
116
+ examples["pixel_values"] = [
117
+ train_transforms(img.convert("RGB")) for img in examples["image"]
118
+ ]
119
+ del examples["image"]
120
+ return examples
121
+
122
+ def preprocess_val(examples):
123
+ examples["pixel_values"] = [
124
+ val_transforms(img.convert("RGB")) for img in examples["image"]
125
+ ]
126
+ del examples["image"]
127
+ return examples
128
+
129
+ print("Applying transforms...")
130
+ train_ds = train_ds.with_transform(preprocess_train)
131
+ val_ds = val_ds.with_transform(preprocess_val)
132
+ test_ds = test_ds.with_transform(preprocess_val)
133
+
134
+ # ------------------------------------------------------------------
135
+ # Model
136
+ # ------------------------------------------------------------------
137
+ print(f"Loading model {MODEL_NAME}...")
138
+ model = AutoModelForImageClassification.from_pretrained(
139
+ MODEL_NAME,
140
+ num_labels=NUM_LABELS,
141
+ id2label=ID2LABEL,
142
+ label2id=LABEL2ID,
143
+ ignore_mismatched_sizes=True,
144
+ )
145
+
146
+ # ------------------------------------------------------------------
147
+ # Training arguments
148
+ # ------------------------------------------------------------------
149
+ training_args = TrainingArguments(
150
+ output_dir=OUTPUT_DIR,
151
+ remove_unused_columns=False,
152
+ eval_strategy="epoch",
153
+ save_strategy="epoch",
154
+ learning_rate=2e-5,
155
+ per_device_train_batch_size=16,
156
+ per_device_eval_batch_size=16,
157
+ gradient_accumulation_steps=2,
158
+ num_train_epochs=30,
159
+ warmup_steps=100,
160
+ weight_decay=0.01,
161
+ logging_strategy="steps",
162
+ logging_steps=10,
163
+ logging_first_step=True,
164
+ disable_tqdm=True,
165
+ load_best_model_at_end=True,
166
+ metric_for_best_model="eval_roc_auc",
167
+ greater_is_better=True,
168
+ push_to_hub=True,
169
+ hub_model_id=HUB_MODEL_ID,
170
+ report_to="trackio",
171
+ run_name="thyroid-swinv2-binary",
172
+ project="thyroid-malignancy",
173
+ seed=42,
174
+ bf16=True,
175
+ dataloader_num_workers=4,
176
+ )
177
+
178
+ # ------------------------------------------------------------------
179
+ # Trainer
180
+ # ------------------------------------------------------------------
181
+ data_collator = DefaultDataCollator()
182
+
183
+ trainer = Trainer(
184
+ model=model,
185
+ args=training_args,
186
+ data_collator=data_collator,
187
+ train_dataset=train_ds,
188
+ eval_dataset=val_ds,
189
+ processing_class=image_processor,
190
+ compute_metrics=compute_metrics,
191
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
192
+ )
193
+
194
+ print("Starting training...")
195
+ trainer.train()
196
+
197
+ print("Evaluating on test set...")
198
+ test_results = trainer.evaluate(test_ds, metric_key_prefix="test")
199
+ print("Test results:", test_results)
200
+
201
+ print("Pushing to Hub...")
202
+ trainer.push_to_hub()
203
+
204
+ print("Done!")