SauravMaheshkar committed on
Commit
bfd8285
1 Parent(s): ca1a0ac

feat: add initial template

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. app.py +52 -0
  3. bin/dino.index +3 -0
  4. bin/model.ckpt +3 -0
  5. model.py +159 -0
  6. requirements.txt +8 -0
.gitattributes CHANGED
@@ -6,6 +6,7 @@
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
 
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.index filter=lfs diff=lfs merge=lfs -text
10
  *.joblib filter=lfs diff=lfs merge=lfs -text
11
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ from faiss import read_index
5
+ from PIL import Image, ImageOps
6
+ from datasets import load_dataset
7
+ import torchvision.transforms as T
8
+ from torchvision.models import resnet50
9
+
10
+ from model import DINO
11
+
12
# Preprocessing applied to every query image before it is embedded:
# tensor conversion, resize, centre crop to 224 and normalisation to [-1, 1].
# NOTE(review): Resize(244) may be a typo for 256 -- confirm against the
# preprocessing that was used when ./bin/dino.index was built.
transforms = T.Compose(
    [T.ToTensor(), T.Resize(244), T.CenterCrop(224), T.Normalize([0.5], [0.5])]
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset the retrieval results refer to (name fixed from the original
# misspelling ``datset``).
dataset = load_dataset("ethz/food101")

# Restore the trained DINO weights from the Lightning checkpoint.
model = DINO(batch_size_per_device=32, num_classes=1000).to(device)
model.load_state_dict(torch.load("./bin/model.ckpt", map_location=device)["state_dict"])
# Inference only: switch BatchNorm layers to their running statistics so a
# single-image batch yields deterministic embeddings.
model.eval()
22
+
23
+
24
def augment(img, transforms=transforms) -> torch.Tensor:
    """Turn a raw image array into a single-image batch for the model.

    Args:
        img: Image as an array accepted by ``PIL.Image.fromarray``.
        transforms: Callable applied to the PIL image; defaults to the
            module-level preprocessing pipeline.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    img = Image.fromarray(img)
    if img.mode != "RGB":
        # The backbone expects three channels. ``convert`` handles grayscale
        # (replicating the single channel) as well as RGBA/palette inputs.
        img = img.convert("RGB")
    return transforms(img).unsqueeze(0)
30
+
31
+
32
_INDEX_PATH = "./bin/dino.index"
_faiss_index = None  # populated lazily by _get_index()


def _get_index():
    """Return the FAISS index, reading it from disk only on first use."""
    global _faiss_index
    if _faiss_index is None:
        _faiss_index = read_index(_INDEX_PATH)
    return _faiss_index


def search_index(input_image, k: int):
    """Embed ``input_image`` and return the ids of its ``k`` nearest neighbours.

    Args:
        input_image: Query image as an array (as passed on to ``augment``).
            ``None`` (nothing uploaded) yields an empty result.
        k: Number of neighbours to retrieve; coerced to ``int`` because UI
            sliders may deliver floats.

    Returns:
        A comma-separated string of the retrieved index ids, nearest first.
    """
    if input_image is None:
        return ""
    k = int(k)

    with torch.no_grad():
        # The model lives on ``device``; the query batch must too.
        embedding = model(augment(input_image).to(device))

    # FAISS works on CPU float32 row vectors.
    query = embedding[0].reshape(1, -1).cpu().numpy().astype(np.float32)
    _, neighbours = _get_index().search(query, k)

    # FAISS pads with -1 when fewer than k neighbours exist; drop those.
    # TODO: presumably these ids are row numbers of the training split, so the
    # matching images could be returned instead of raw ids -- confirm.
    return ", ".join(str(int(i)) for i in neighbours[0] if i >= 0)
43
+
44
+
45
# Gradio front end: one image upload plus a slider, producing plain text.
app = gr.Interface(
    fn=search_index,
    inputs=[
        gr.Image(),
        gr.Slider(minimum=1, step=1, value=3),
    ],
    outputs="text",
)


# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch()
bin/dino.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19ebbf3848fc84c63cc7a50cc2e26a82a99018a3be2558ea4cca50b5f14f273d
3
+ size 620544045
bin/model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ab95b4201d663ba01d8fbc19643b99d4cccbf459ff10ca8455fa226950fd0f1
3
+ size 608315727
model.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import copy
from typing import Union, Tuple, List

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torch.optim import SGD
from torchvision.models import resnet50

from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import (
    activate_requires_grad,
    deactivate_requires_grad,
    get_weight_decay_parameters,
    update_momentum,
)
from lightly.transforms import DINOTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
21
+
22
+
23
class DINO(LightningModule):
    """DINO self-distillation module with a ResNet-50 backbone.

    ``backbone`` / ``projection_head`` form the teacher, which is updated as a
    momentum average of ``student_backbone`` / ``student_projection_head``.
    An online linear classifier is trained on detached teacher features.
    """

    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        """Build teacher, student, loss and online classifier.

        Args:
            batch_size_per_device: Per-device batch size; used to scale the
                learning rate in ``configure_optimizers``.
            num_classes: Number of classes for the online linear classifier.
        """
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head
        self.backbone = resnet
        self.projection_head = DINOProjectionHead(freeze_last_layer=1)
        # The student starts as an exact copy of the teacher backbone.
        self.student_backbone = copy.deepcopy(self.backbone)
        self.student_projection_head = DINOProjectionHead()
        self.criterion = DINOLoss(output_dim=65536)

        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return the teacher backbone's features for ``x``."""
        return self.backbone(x)

    def forward_student(self, x: Tensor) -> Tensor:
        """Return the student's projections for ``x``."""
        features = self.student_backbone(x).flatten(start_dim=1)
        projections = self.student_projection_head(features)
        return projections

    def on_train_start(self) -> None:
        """Disable gradients for the teacher while training runs."""
        deactivate_requires_grad(self.backbone)
        deactivate_requires_grad(self.projection_head)

    def on_train_end(self) -> None:
        """Re-enable gradients for the teacher once training has finished."""
        activate_requires_grad(self.backbone)
        activate_requires_grad(self.projection_head)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Run one DINO step and return the DINO loss plus classifier loss.

        Args:
            batch: ``batch[0]`` holds the augmented views (the first two are
                treated as global views, the rest as local views) and
                ``batch[1]`` the targets; any further entries are ignored.
            batch_idx: Index of the batch, forwarded to the online classifier.
        """
        # Momentum update teacher.
        momentum = cosine_schedule(
            step=self.trainer.global_step,
            max_steps=self.trainer.estimated_stepping_batches,
            start_value=0.996,
            end_value=1.0,
        )
        update_momentum(self.student_backbone, self.backbone, m=momentum)
        update_momentum(self.student_projection_head, self.projection_head, m=momentum)

        views, targets = batch[0], batch[1]
        global_views = torch.cat(views[:2])
        local_views = torch.cat(views[2:])

        # The teacher only sees the global views; the student sees all views.
        teacher_features = self.forward(global_views).flatten(start_dim=1)
        teacher_projections = self.projection_head(teacher_features)
        student_projections = torch.cat(
            [self.forward_student(global_views), self.forward_student(local_views)]
        )

        loss = self.criterion(
            teacher_out=teacher_projections.chunk(2),
            student_out=student_projections.chunk(len(views)),
            epoch=self.current_epoch,
        )
        self.log_dict(
            {"train_loss": loss, "ema_momentum": momentum},
            prog_bar=True,
            sync_dist=True,
            batch_size=len(targets),
        )

        # Online classification on the first global view's teacher features;
        # detached so the classifier loss does not reach the backbone.
        cls_loss, cls_log = self.online_classifier.training_step(
            (teacher_features.chunk(2)[0].detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Evaluate the online classifier on teacher features of the batch.

        Args:
            batch: ``batch[0]`` holds the images and ``batch[1]`` the targets.
            batch_idx: Index of the batch, forwarded to the online classifier.

        Returns:
            The classifier's validation loss.
        """
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        """Create the SGD optimizer and a per-step cosine warmup schedule.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair.
        """
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.student_backbone, self.student_projection_head]
        )
        # For ResNet50 we use SGD instead of AdamW/LARS as recommended by the authors:
        # https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings
        optimizer = SGD(
            [
                {"name": "dino", "params": params},
                {
                    "name": "dino_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            # Base rate of 0.03 scaled linearly with the global batch size
            # relative to 256.
            lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=1e-4,
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                # The scheduler is stepped per batch ("interval": "step"), so
                # both lengths are given in steps: ten epochs' worth of steps
                # for warmup, the full run for the cosine schedule.
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def configure_gradient_clipping(
        self,
        optimizer,
        gradient_clip_val: Union[int, float, None] = None,
        gradient_clip_algorithm: Union[str, None] = None,
    ) -> None:
        """Clip gradients by norm at 3.0 and apply the last-layer gradient rule.

        The ``gradient_clip_val`` and ``gradient_clip_algorithm`` arguments
        are accepted but not used; the values are fixed here.
        """
        self.clip_gradients(
            optimizer=optimizer,
            gradient_clip_val=3.0,
            gradient_clip_algorithm="norm",
        )
        self.student_projection_head.cancel_last_layer_gradients(self.current_epoch)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ datasets
2
+ faiss-cpu
3
+ gradio
4
+ lightly
5
+ lightning
6
+ numpy
7
+ Pillow
8
+ torchvision