sdutta28 committed on
Commit 32cc554
1 Parent(s): 293f4bc

HF Changes

.gitattributes ADDED
@@ -0,0 +1,39 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ service/static/PlantDiseaseClassificationModel/best.ckpt filter=lfs diff=lfs merge=lfs -text
37
+ service/static/PlantDiseaseOODModel filter=lfs diff=lfs merge=lfs -text
38
+ service/static/PlantDiseaseOODModel/best.ckpt filter=lfs diff=lfs merge=lfs -text
39
+ service/static/PlantDiseaseClassificationModel filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,15 @@
1
+ .venv
2
+ ml/input
3
+ .DS_Store
4
+ logs
5
+
6
+ .env
7
+ lightning_logs
8
+ input/*
9
+ !input/.gitkeep
10
+ logs/*
11
+ *.pyc
12
+ venv
13
+ .venv
14
+
15
+ __pycache__
README.md ADDED
@@ -0,0 +1,95 @@
1
+ ---
2
+ title: CDIApp
3
+ emoji: 🏆
4
+ colorFrom: red
5
+ colorTo: black
6
+ sdk: gradio
7
+ sdk_version: 5.4.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ # Plant Disease Classification
14
+ ## Generated by Claude v3
15
+
16
+ This is a deep learning project for classifying plant diseases from leaf images. It fine-tunes a pretrained torchvision backbone (MobileNetV3-Small by default) on a dataset of plant disease images, uses a small convolutional autoencoder to flag out-of-distribution inputs, and queries Gemini for a remedy for the predicted disease.
17
+
18
+ ## Features
19
+
20
+ - Train a disease classification model on your own dataset
21
+ - Evaluate model performance on a test set
22
+ - Run inference on new images through a web interface
23
+
24
+ ## Installation
25
+
26
+ 1. Clone the repository:
27
+
28
+ ```
29
+ git clone https://github.com/username/plant-disease-classifier.git
30
+ ```
31
+
32
+ 2. Install dependencies:
33
+
34
+ ```
35
+ cd plant-disease-classifier
36
+ pip install -r requirements.txt
37
+ ```
38
+
39
+ ## Usage
40
+
41
+ ### Data Preparation
42
+
43
+ Organize your image data into folders for each disease class, for example:
44
+
45
+ ```
46
+ data/
47
+ healthy/
48
+ image1.jpg
49
+ image2.jpg
50
+ ...
51
+ disease_a/
52
+ image1.jpg
53
+ image2.jpg
54
+ ...
55
+ disease_b/
56
+ ...
57
+ ```
58
+
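Label indices come from `torchvision.datasets.ImageFolder`, which sorts class-folder names alphabetically; the `ServiceConfig.ID2LABEL` tuple in `acfg/appconfig.py` has to follow that same order for predictions to be reported under the right name. A quick way to inspect the mapping (the `data/train` path is illustrative):

```python
from torchvision.datasets import ImageFolder

# ImageFolder assigns label index i to the i-th class folder in alphabetical order;
# index i must correspond to ServiceConfig.ID2LABEL[i] in acfg/appconfig.py.
dataset = ImageFolder(root="data/train")
for name, idx in sorted(dataset.class_to_idx.items(), key=lambda kv: kv[1]):
    print(idx, name)
```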
59
+ ### Training
60
+
61
+ To train the model, run:
62
+
63
+ ```
64
+ python train_classifier.py
65
+ ```
66
+
67
+ Training reads its data paths and hyperparameters from `acfg/modelconfig.py` rather than from command-line flags; the best checkpoint (selected on validation loss) is saved by the Lightning `ModelCheckpoint` callback. A sketch of the relevant config fields follows.
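The data module expects separate `train`, `valid`, and `test` directories, each laid out like the tree above. A minimal sketch of the relevant fields (field names taken from `acfg/modelconfig.py` in this commit; the `data/...` paths are illustrative placeholders for your own dataset):

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Excerpt of acfg/modelconfig.py; point the paths at your own dataset."""

    TRAIN_DATA_PATH: str = "data/train"
    VAL_DATA_PATH: str = "data/valid"
    TEST_DATA_PATH: str = "data/test"
    IMG_SIZE: int = 224
    BATCH_SIZE: int = 128
    NUM_OUTPUT_CLASSES: int = 38  # number of class folders in your dataset
```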
68
+
69
+ ### Evaluation
70
+
71
+ Evaluate the model on a test set:
72
+
73
+ ```
74
+ python evaluate.py --data_dir data/test/ --model models/classifier.pth
75
+ ```
76
+
77
+ This will print the classification metrics.
78
+
79
+ ### Inference
80
+
81
+ To launch the web interface for running inference on new images:
82
+
83
+ ```
84
+ python app.py
85
+ ```
86
+
87
+ Then open `http://localhost:7860` (Gradio's default port) in your web browser. You can upload an image and see the predicted disease class along with a suggested remedy.
88
+
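The same pipeline can also be called without the UI; a minimal sketch (it assumes a local image file `leaf.jpg`, which is not part of this repository, and a `GEMINI_API_KEY` entry in `.env` so the remedy lookup in `service/external.py` can reach Gemini):

```python
from PIL import Image

from service.predict import workflow

# workflow() returns the predicted disease name and a Markdown remedy string.
image = Image.open("leaf.jpg").convert("RGB")
disease_name, remedy = workflow(image)
print(disease_name)
print(remedy)
```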
89
+ ## Contributing
90
+
91
+ Contributions are welcome! Please open an issue or submit a pull request.
92
+
93
+ ## License
94
+
95
+ This project is licensed under the Apache License 2.0, as declared in the `license` field of the metadata above.
acfg/appconfig.py ADDED
@@ -0,0 +1,112 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from acfg.modelconfig import ModelConfig
6
+ from ml.app.anomaly import DiseaseOODModule
7
+ from ml.app.models.classification import DiseaseClassificationModel
8
+ from ml.app.models.ood import Autoencoder
9
+
10
+
11
+ def get_device():
12
+ """Gets the appropriate device for PyTorch operations.
13
+
14
+ Checks for CUDA GPU availability first, then Apple M1/M2 MPS, falling back to CPU.
15
+
16
+ Returns:
17
+ tuple: A tuple containing two strings:
18
+ - First string indicates the device type ('cuda', 'mps', or 'cpu')
19
+ - Second string indicates the specific device ('cuda:0', 'mps', or 'cpu')
20
+ """
21
+ if torch.cuda.is_available():
22
+ return "cuda", "cuda:0"
23
+ elif torch.backends.mps.is_available():
24
+ return "mps", "mps"
25
+ else:
26
+ return "cpu", "cpu"
27
+
28
+
29
+ @dataclass
30
+ class ServiceConfig:
31
+ LLM_MODEL_KEY = "gemini"
32
+ OOD_THRESHOLD = 0.034
33
+ ID2LABEL = (
34
+ "Apple scab",
35
+ "Apple Black rot",
36
+ "Apple Cedar rust",
37
+ "Apple healthy",
38
+ "Blueberry healthy",
39
+ "Cherry Powdery mildew",
40
+ "Cherry healthy",
41
+ "Corn Cercospora leaf spot Gray leaf spot",
42
+ "Corn Common rust",
43
+ "Corn Northern Leaf Blight",
44
+ "Corn healthy",
45
+ "Grape Black rot",
46
+ "Grape Esca Black Measles",
47
+ "Grape Leaf blight Isariopsis Leaf Spot",
48
+ "Grape healthy",
49
+ "Orange Haunglongbing Citrus greening",
50
+ "Peach Bacterial spot",
51
+ "Peach healthy",
52
+ "Pepper bell Bacterial spot",
53
+ "Pepper bell healthy",
54
+ "Potato Early blight",
55
+ "Potato Late blight",
56
+ "Potato healthy",
57
+ "Raspberry healthy",
58
+ "Soybean healthy",
59
+ "Squash Powdery mildew",
60
+ "Strawberry Leaf scorch",
61
+ "Strawberry healthy",
62
+ "Tomato Bacterial spot",
63
+ "Tomato Early blight",
64
+ "Tomato Late blight",
65
+ "Tomato Leaf Mold",
66
+ "Tomato Septoria leaf spot",
67
+ "Tomato Spider mites Two spotted spider mite",
68
+ "Tomato Target Spot",
69
+ "Tomato Yellow Leaf Curl Virus",
70
+ "Tomato mosaic virus",
71
+ "Tomato healthy",
72
+ )
73
+
74
+
75
+ def load_my_model(checkpoint_path, model):
76
+ """Loads a PyTorch model from a checkpoint file with state dict key remapping.
77
+
78
+ Args:
79
+ checkpoint_path (str): Path to the checkpoint file containing model weights
80
+ model (torch.nn.Module): PyTorch model instance to load the weights into
81
+
82
+ Returns:
83
+ torch.nn.Module: Model with loaded weights
84
+
85
+ Notes:
86
+ - Loads checkpoint using appropriate device (CUDA/MPS/CPU)
87
+ - Remaps state dict keys by removing 'model.model.' prefix
88
+ - Only keeps state dict entries that start with 'model.model.'
89
+ """
90
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(get_device()[1]))
91
+ state_dict = checkpoint["state_dict"]
92
+
93
+ # Create a new state dict with the correct keys
94
+ new_state_dict = {}
95
+ for k, v in state_dict.items():
96
+ if k.startswith("model.model."):
97
+ new_key = k.replace("model.model.", "model.")
98
+ new_state_dict[new_key] = v
99
+
100
+ # Load the new state dict
101
+ model.load_state_dict(new_state_dict)
102
+ return model
103
+
104
+
105
+ CLF_MODEL = DiseaseClassificationModel(model_name=ModelConfig.PRETRAINED_MODEL_NAME)
106
+ CLF_MODEL = load_my_model(ModelConfig.CLASSIFY_MODEL_CHECKPOINT, CLF_MODEL).to(
107
+ get_device()[1]
108
+ )
109
+
110
+ OOD_MODEL = DiseaseOODModule.load_from_checkpoint(
111
+ ModelConfig.OOD_MODEL_CHECKPOINT
112
+ ).model.to(get_device()[1])
acfg/modelconfig.py ADDED
@@ -0,0 +1,22 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+
5
+ @dataclass
6
+ class ModelConfig:
7
+ TRAIN_DATA_PATH: str = "ml/input/PlantDiseaseClassificationDataset/train"
8
+ VAL_DATA_PATH: str = "ml/input/PlantDiseaseClassificationDataset/valid"
9
+ TEST_DATA_PATH: str = "ml/input/PlantDiseaseClassificationDataset/test"
10
+ N_INPUT_CHANNELS = 3
11
+ IMG_SIZE: int = 224
12
+ BATCH_SIZE: int = 128
13
+ NUM_OUTPUT_CLASSES: int = 38
14
+ NUM_WORKERS: int = os.cpu_count() // 2
15
+ IMG_STD: tuple = (0.485, 0.456, 0.406)  # NOTE: the mean/std values appear swapped relative to the standard ImageNet statistics, but they are used consistently in both training and inference
16
+ IMG_MEAN: tuple = (0.229, 0.224, 0.225)
17
+ VAL_LOSS: str = "VL"
18
+ PRETRAINED_MODEL_NAME: str = "mobilenet_v3_small"
19
+ CLASSIFY_MODEL_CHECKPOINT: str = (
20
+ "service/static/PlantDiseaseClassificationModel/best.ckpt"
21
+ )
22
+ OOD_MODEL_CHECKPOINT: str = "service/static/PlantDiseaseOODModel/best.ckpt"
app.py ADDED
@@ -0,0 +1,29 @@
1
+ import gradio as gr
2
+
3
+ from service.predict import workflow
4
+
5
+
6
+ def process_image(image):
7
+ disease_name, remedy = workflow(image)
8
+ return disease_name, remedy
9
+
10
+
11
+ # Create the Gradio interface
12
+ iface = gr.Interface(
13
+ fn=process_image,
14
+ inputs=gr.Image(
15
+ image_mode="RGB",
16
+ sources="upload",
17
+ label="Upload Plant Disease Image",
18
+ show_download_button=True,
19
+ type="pil",
20
+ ),
21
+ outputs=[
22
+ gr.Textbox(label="Prediction", placeholder="Disease Prediction"),
23
+ gr.Markdown(label="Remedy"),
24
+ ],
25
+ title="Classify Plant Diseases and Get Remedies",
26
+ )
27
+
28
+ if __name__ == "__main__":
29
+ iface.launch()
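The call above uses Gradio's defaults (port 7860 locally). If a different port or a temporary public link is needed, `launch()` accepts the standard Gradio arguments, for example:

```python
# explicit port plus a temporary public share link (standard Gradio options)
iface.launch(server_port=7860, share=True)
```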
ml/.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ input/*
2
+ !input/.gitkeep
3
+ logs/*
4
+ *.pyc
5
+ venv
6
+ .venv
ml/README.md ADDED
@@ -0,0 +1 @@
1
+ # Model training
ml/app/__init__.py ADDED
File without changes
ml/app/anomaly.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from lightning import LightningModule
5
+ from torch.nn import functional as F
6
+
7
+ from ml.app.models.ood import Autoencoder
8
+
9
+
10
+ class DiseaseOODModule(LightningModule):
11
+ def __init__(self) -> None:
12
+ super().__init__()
13
+ self.model = Autoencoder(in_channels=3, out_channels=3)
14
+
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ return self.model(x)
17
+
18
+ def training_step(self, batch: Tuple[torch.Tensor, ...], batch_idx: int) -> torch.Tensor:
19
+ loss = self._loss(batch)
20
+ self.log("TL", loss, prog_bar=True)
21
+ return loss
22
+
23
+ def configure_optimizers(self):
24
+ optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
25
+ return optimizer
26
+
27
+ def validation_step(self, batch: Tuple[torch.Tensor, ...], batch_idx: int):
28
+ loss = self._loss(batch)
29
+ self.log("VL", loss, prog_bar=True)
30
+ return loss
31
+
32
+ def _loss(self, batch):
33
+ images, _ = batch
34
+ outputs = self(images)
35
+ loss = F.mse_loss(outputs, images)
36
+ return loss
ml/app/data.py ADDED
@@ -0,0 +1,91 @@
1
+ from lightning import LightningDataModule
2
+ from torch.utils.data import DataLoader
3
+ from torchvision import transforms as T
4
+ from torchvision.datasets import ImageFolder
5
+
6
+ from acfg.modelconfig import ModelConfig
7
+
8
+
9
+ class ImageDataModule(LightningDataModule):
10
+ def __init__(
11
+ self,
12
+ train_path: str,
13
+ val_path: str,
14
+ test_path: str,
15
+ batch_size: int,
16
+ img_size: int,
17
+ ):
18
+ super().__init__()
19
+ self.train_path = train_path
20
+ self.val_path = val_path
21
+ self.test_path = test_path
22
+ self.batch_size = batch_size
23
+ self.img_size = img_size
24
+ self.train_transforms = self._get_train_transforms()
25
+ self.val_transforms = self._get_val_transforms()
26
+ self.test_transforms = self._get_test_transforms()
27
+
28
+ def _get_train_transforms(self):
29
+ return T.Compose(
30
+ [
31
+ T.Resize(self.img_size),
32
+ T.RandomHorizontalFlip(p=0.5),
33
+ T.RandomVerticalFlip(p=0.5),
34
+ T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
35
+ T.ToTensor(),
36
+ T.Normalize(mean=ModelConfig.IMG_MEAN, std=ModelConfig.IMG_STD),
37
+ ]
38
+ )
39
+
40
+ def _get_val_transforms(self):
41
+ return T.Compose(
42
+ [
43
+ T.Resize(self.img_size),
44
+ T.ToTensor(),
45
+ T.Normalize(mean=ModelConfig.IMG_MEAN, std=ModelConfig.IMG_STD),
46
+ ]
47
+ )
48
+
49
+ def _get_test_transforms(self):
50
+ return T.Compose(
51
+ [
52
+ T.Resize(self.img_size),
53
+ T.ToTensor(),
54
+ T.Normalize(mean=ModelConfig.IMG_MEAN, std=ModelConfig.IMG_STD),
55
+ ]
56
+ )
57
+
58
+ def setup(self, stage=None):
59
+ if stage == "fit" or stage is None:
60
+ self.train_data = ImageFolder(root=self.train_path, transform=self.train_transforms)
61
+ self.val_data = ImageFolder(root=self.val_path, transform=self.val_transforms)
62
+ if stage == "test" or stage is None:
63
+ self.test_data = ImageFolder(root=self.test_path, transform=self.test_transforms)
64
+
65
+ def train_dataloader(self):
66
+ return DataLoader(
67
+ self.train_data,
68
+ batch_size=self.batch_size,
69
+ shuffle=True,
70
+ persistent_workers=True,
71
+ pin_memory=True,
72
+ num_workers=ModelConfig.NUM_WORKERS,
73
+ )
74
+
75
+ def val_dataloader(self):
76
+ return DataLoader(
77
+ self.val_data,
78
+ batch_size=self.batch_size,
79
+ persistent_workers=True,
80
+ pin_memory=True,
81
+ num_workers=ModelConfig.NUM_WORKERS,
82
+ )
83
+
84
+ def test_dataloader(self):
85
+ return DataLoader(
86
+ self.test_data,
87
+ batch_size=self.batch_size,
88
+ persistent_workers=True,
89
+ pin_memory=True,
90
+ num_workers=ModelConfig.NUM_WORKERS,
91
+ )
ml/app/lm.py ADDED
@@ -0,0 +1,54 @@
1
+ from typing import Tuple, Any
2
+
3
+ import torch
4
+ from lightning import LightningModule
5
+ from torch import nn, Tensor
6
+ from torch.nn import functional as F
7
+ from torch.optim import AdamW
8
+
9
+
10
+ class ClassificationModule(LightningModule):
11
+ def __init__(
12
+ self,
13
+ model: nn.Module,
14
+ num_classes: int,
15
+ ) -> None:
16
+ super().__init__()
17
+ self.model = model
18
+ self.num_classes = num_classes
19
+
20
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
21
+ return self.model(x)
22
+
23
+ def configure_optimizers(self):
24
+ # Low lr as we would be fine tuning a backbone
25
+ optimizer = AdamW(self.parameters(), lr=1e-5)
26
+ return optimizer
27
+
28
+ def training_step(self, batch: Tuple[torch.Tensor, ...], batch_idx: int) -> torch.Tensor:
29
+ images, labels = batch
30
+ outputs = self(images)
31
+ loss = F.cross_entropy(outputs, labels)
32
+ self.log("TL", loss, prog_bar=True)
33
+ return loss
34
+
35
+ def validation_step(self, batch: Tuple[torch.Tensor, ...], batch_idx: int) -> dict[str, Tensor | float | Any]:
36
+ images, labels = batch
37
+ outputs = self(images)
38
+ loss = F.cross_entropy(outputs, labels)
39
+ acc = self._accuracy(labels, outputs)
40
+ self.log("VL", loss, prog_bar=True)
41
+ self.log("VA", acc, prog_bar=True)
42
+ return {"VL": loss, "VA": acc}
43
+
44
+ @staticmethod
45
+ def _accuracy(labels, outputs):
46
+ preds = torch.argmax(outputs, dim=1)
47
+ acc = torch.sum(preds == labels).float() / len(labels)
48
+ return acc
49
+
50
+ def test_step(self, batch: Tuple[torch.Tensor, ...], batch_idx: int) -> None:
51
+ images, labels = batch
52
+ outputs = self(images)
53
+ loss = F.cross_entropy(outputs, labels)
54
+ self.log("TL", loss)
ml/app/models/__init__.py ADDED
File without changes
ml/app/models/classification.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torchvision.models import EfficientNet_B0_Weights, ResNet50_Weights, MobileNet_V3_Small_Weights
5
+
6
+ from acfg.modelconfig import ModelConfig
7
+
8
+
9
+ # TODO: Uncomment if needed
10
+ # Pytorch fix for hash mismatch
11
+ # def get_state_dict(self, *args, **kwargs):
12
+ # kwargs.pop("check_hash")
13
+ # return load_state_dict_from_url(self.url, *args, **kwargs)
14
+
15
+
16
+ # WeightsEnum.get_state_dict = get_state_dict
17
+
18
+
19
+ class MLPHead(nn.Module):
20
+ def __init__(self, in_features: int, num_output_classes: int) -> None:
21
+ super().__init__()
22
+ self.classifier = nn.Sequential(
23
+ nn.Linear(in_features, 2048),
24
+ nn.GELU(),
25
+ nn.Dropout(p=0.5),
26
+ nn.Linear(2048, num_output_classes),
27
+ )
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ return self.classifier(x)
31
+
32
+
33
+ class PretrainedModelFactory:
34
+ @staticmethod
35
+ def _freeze_pretrained_weights(model):
36
+ for param in model.parameters():
37
+ param.requires_grad = False
38
+
39
+ @staticmethod
40
+ def _efficientnet_b0():
41
+ model = torchvision.models.efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
42
+ model.classifier = MLPHead(
43
+ in_features=model.classifier[1].in_features,
44
+ num_output_classes=ModelConfig.NUM_OUTPUT_CLASSES,
45
+ )
46
+ return model
47
+
48
+ @staticmethod
49
+ def _resnet_50():
50
+ model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
51
+ model.fc = MLPHead(
52
+ in_features=model.fc.in_features,
53
+ num_output_classes=ModelConfig.NUM_OUTPUT_CLASSES,
54
+ )
55
+ return model
56
+
57
+ @staticmethod
58
+ def _mobilenet_v3_small():
59
+ model = torchvision.models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
60
+ model.classifier = MLPHead(
61
+ in_features=model.classifier[0].in_features,
62
+ num_output_classes=ModelConfig.NUM_OUTPUT_CLASSES,
63
+ )
64
+ return model
+
65
+ @staticmethod
66
+ def _vit_b_16():
67
+ raise NotImplementedError
68
+
69
+ def __init__(self):
70
+ self.available_models = {
71
+ "efficientnet_b0": self._efficientnet_b0,
72
+ "resnet_50": self._resnet_50,
73
+ "vit_b_16": self._vit_b_16,
74
+ "mobilenet_v3_small": self._mobilenet_v3_small,
75
+ }
76
+
77
+ def get_model(self, model_name: str) -> nn.Module:
78
+ if model_name not in self.available_models:
79
+ raise ValueError(f"Model '{model_name}' not available. Choose from {self.available_models.keys()}")
80
+ return self.available_models[model_name]()
81
+
82
+
83
+ class DiseaseClassificationModel(nn.Module):
84
+ def __init__(self, model_name: str) -> None:
85
+ super().__init__()
86
+ factory = PretrainedModelFactory()
87
+ self.model = factory.get_model(model_name)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ return self.model(x)
ml/app/models/ood.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class ConvADN(nn.Module):
6
+ def __init__(
7
+ self,
8
+ in_channels: int,
9
+ out_channels: int,
10
+ kernel=2,
11
+ stride=2,
12
+ dilation=1,
13
+ padding=0,
14
+ p_drop=0.2,
15
+ is_transpose: bool = False,
16
+ ):
17
+ super().__init__()
18
+ self.model = nn.Sequential(
19
+ (nn.Conv2d if not is_transpose else nn.ConvTranspose2d)(
20
+ in_channels=in_channels,
21
+ out_channels=out_channels,
22
+ kernel_size=kernel,
23
+ stride=stride,
24
+ dilation=dilation,
25
+ padding=padding,
26
+ ),
27
+ nn.GELU(),
28
+ nn.Dropout(p_drop),
29
+ nn.InstanceNorm3d(num_features=out_channels),  # NOTE: applied to 4D (N, C, H, W) activations; InstanceNorm2d would be the conventional choice here
30
+ )
31
+
32
+ def forward(self, x):
33
+ return self.model(x)
34
+
35
+
36
+ class Encoder(nn.Module):
37
+ def __init__(self, in_channels: int = 3):
38
+ super().__init__()
39
+ self.model = nn.Sequential(
40
+ ConvADN(in_channels, 32, kernel=2, stride=2, padding=0),
41
+ ConvADN(32, 64, kernel=2, stride=2, padding=0),
42
+ ConvADN(64, 128, kernel=2, stride=2, padding=0),
43
+ ConvADN(128, 256, kernel=2, stride=2, padding=0),
44
+ )
45
+
46
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
47
+ return self.model(x)
48
+
49
+
50
+ class Decoder(nn.Module):
51
+ def __init__(self, out_channels: int = 3):
52
+ super().__init__()
53
+ self.model = nn.Sequential(
54
+ ConvADN(256, 128, kernel=2, stride=2, padding=0, is_transpose=True),
55
+ ConvADN(128, 64, kernel=2, stride=2, padding=0, is_transpose=True),
56
+ ConvADN(64, 32, kernel=2, stride=2, padding=0, is_transpose=True),
57
+ ConvADN(32, out_channels, kernel=2, stride=2, padding=0, is_transpose=True),
58
+ )
59
+ self.output = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ x = self.model(x)
63
+ return self.output(x)
64
+
65
+
66
+ class Autoencoder(nn.Module):
67
+ def __init__(
68
+ self,
69
+ in_channels: int,
70
+ out_channels: int,
71
+ ) -> None:
72
+ super().__init__()
73
+ self.encoder = Encoder(in_channels)
74
+ self.decoder = Decoder(out_channels)
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ x = self.encoder(x)
78
+ x = self.decoder(x)
79
+ return x
ml/pyproject.toml ADDED
@@ -0,0 +1,2 @@
1
+ [tool.black]
2
+ line-length = 120
requirements.txt ADDED
@@ -0,0 +1,69 @@
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.3
3
+ aiohttp==3.10.10
4
+ aiosignal==1.3.1
5
+ annotated-types==0.7.0
6
+ anyio==4.6.2.post1
7
+ attrs==24.2.0
8
+ certifi==2024.8.30
9
+ charset-normalizer==3.4.0
10
+ click==8.1.7
11
+ fastapi==0.115.4
12
+ ffmpy==0.4.0
13
+ filelock==3.16.1
14
+ frozenlist==1.5.0
15
+ fsspec==2024.10.0
16
+ gradio==5.4.0
17
+ gradio_client==1.4.2
18
+ h11==0.14.0
19
+ httpcore==1.0.6
20
+ httpx==0.27.2
21
+ huggingface-hub==0.26.2
22
+ idna==3.10
23
+ Jinja2==3.1.4
24
+ lightning==2.4.0
25
+ lightning-utilities==0.11.8
26
+ markdown-it-py==3.0.0
27
+ MarkupSafe==2.1.5
28
+ mdurl==0.1.2
29
+ mpmath==1.3.0
30
+ multidict==6.1.0
31
+ networkx==3.4.2
32
+ numpy==1.26.4
33
+ orjson==3.10.11
34
+ packaging==24.1
35
+ pandas==2.2.3
36
+ pillow==11.0.0
37
+ propcache==0.2.0
38
+ pydantic==2.9.2
39
+ pydantic_core==2.23.4
40
+ pydub==0.25.1
41
+ Pygments==2.18.0
42
+ python-dateutil==2.9.0.post0
43
+ python-multipart==0.0.12
44
+ pytorch-lightning==2.4.0
45
+ pytz==2024.2
46
+ PyYAML==6.0.2
47
+ requests==2.32.3
48
+ rich==13.9.4
49
+ ruff==0.7.2
50
+ safehttpx==0.1.1
51
+ semantic-version==2.10.0
52
+ setuptools==75.3.0
53
+ shellingham==1.5.4
54
+ six==1.16.0
55
+ sniffio==1.3.1
56
+ starlette==0.41.2
57
+ sympy==1.13.1
58
+ tomlkit==0.12.0
59
+ torch==2.5.1
60
+ torchmetrics==1.5.1
61
+ torchvision==0.20.1
62
+ tqdm==4.66.6
63
+ typer==0.12.5
64
+ typing_extensions==4.12.2
65
+ tzdata==2024.2
66
+ urllib3==2.2.3
67
+ uvicorn==0.32.0
68
+ websockets==12.0
69
+ yarl==1.17.1
+ # Used by service/external.py; versions were not pinned in this commit
+ google-generativeai
+ python-dotenv
service/.gitignore ADDED
@@ -0,0 +1,264 @@
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python,jetbrains+all
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,jetbrains+all
3
+
4
+ ### JetBrains+all ###
5
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
6
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
7
+
8
+ # User-specific stuff
9
+ .idea/**/workspace.xml
10
+ .idea/**/tasks.xml
11
+ .idea/**/usage.statistics.xml
12
+ .idea/**/dictionaries
13
+ .idea/**/shelf
14
+
15
+ # AWS User-specific
16
+ .idea/**/aws.xml
17
+
18
+ # Generated files
19
+ .idea/**/contentModel.xml
20
+
21
+ # Sensitive or high-churn files
22
+ .idea/**/dataSources/
23
+ .idea/**/dataSources.ids
24
+ .idea/**/dataSources.local.xml
25
+ .idea/**/sqlDataSources.xml
26
+ .idea/**/dynamic.xml
27
+ .idea/**/uiDesigner.xml
28
+ .idea/**/dbnavigator.xml
29
+
30
+ # Gradle
31
+ .idea/**/gradle.xml
32
+ .idea/**/libraries
33
+
34
+ # Gradle and Maven with auto-import
35
+ # When using Gradle or Maven with auto-import, you should exclude module files,
36
+ # since they will be recreated, and may cause churn. Uncomment if using
37
+ # auto-import.
38
+ # .idea/artifacts
39
+ # .idea/compiler.xml
40
+ # .idea/jarRepositories.xml
41
+ # .idea/modules.xml
42
+ # .idea/*.iml
43
+ # .idea/modules
44
+ # *.iml
45
+ # *.ipr
46
+
47
+ # CMake
48
+ cmake-build-*/
49
+
50
+ # Mongo Explorer plugin
51
+ .idea/**/mongoSettings.xml
52
+
53
+ # File-based project format
54
+ *.iws
55
+
56
+ # IntelliJ
57
+ out/
58
+
59
+ # mpeltonen/sbt-idea plugin
60
+ .idea_modules/
61
+
62
+ # JIRA plugin
63
+ atlassian-ide-plugin.xml
64
+
65
+ # Cursive Clojure plugin
66
+ .idea/replstate.xml
67
+
68
+ # SonarLint plugin
69
+ .idea/sonarlint/
70
+
71
+ # Crashlytics plugin (for Android Studio and IntelliJ)
72
+ com_crashlytics_export_strings.xml
73
+ crashlytics.properties
74
+ crashlytics-build.properties
75
+ fabric.properties
76
+
77
+ # Editor-based Rest Client
78
+ .idea/httpRequests
79
+
80
+ # Android studio 3.1+ serialized cache file
81
+ .idea/caches/build_file_checksums.ser
82
+
83
+ ### JetBrains+all Patch ###
84
+ # Ignore everything but code style settings and run configurations
85
+ # that are supposed to be shared within teams.
86
+
87
+ .idea/*
88
+
89
+ !.idea/codeStyles
90
+ !.idea/runConfigurations
91
+
92
+ ### Python ###
93
+ # Byte-compiled / optimized / DLL files
94
+ __pycache__/
95
+ *.py[cod]
96
+ *$py.class
97
+
98
+ # C extensions
99
+ *.so
100
+
101
+ # Distribution / packaging
102
+ .Python
103
+ build/
104
+ develop-eggs/
105
+ dist/
106
+ downloads/
107
+ eggs/
108
+ .eggs/
109
+ lib/
110
+ lib64/
111
+ parts/
112
+ sdist/
113
+ var/
114
+ wheels/
115
+ share/python-wheels/
116
+ *.egg-info/
117
+ .installed.cfg
118
+ *.egg
119
+ MANIFEST
120
+
121
+ # PyInstaller
122
+ # Usually these files are written by a python script from a template
123
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
124
+ *.manifest
125
+ *.spec
126
+
127
+ # Installer logs
128
+ pip-log.txt
129
+ pip-delete-this-directory.txt
130
+
131
+ # Unit test / coverage reports
132
+ htmlcov/
133
+ .tox/
134
+ .nox/
135
+ .coverage
136
+ .coverage.*
137
+ .cache
138
+ nosetests.xml
139
+ coverage.xml
140
+ *.cover
141
+ *.py,cover
142
+ .hypothesis/
143
+ .pytest_cache/
144
+ cover/
145
+
146
+ # Translations
147
+ *.mo
148
+ *.pot
149
+
150
+ # Django stuff:
151
+ *.log
152
+ local_settings.py
153
+ db.sqlite3
154
+ db.sqlite3-journal
155
+
156
+ # Flask stuff:
157
+ instance/
158
+ .webassets-cache
159
+
160
+ # Scrapy stuff:
161
+ .scrapy
162
+
163
+ # Sphinx documentation
164
+ docs/_build/
165
+
166
+ # PyBuilder
167
+ .pybuilder/
168
+ target/
169
+
170
+ # Jupyter Notebook
171
+ .ipynb_checkpoints
172
+
173
+ # IPython
174
+ profile_default/
175
+ ipython_config.py
176
+
177
+ # pyenv
178
+ # For a library or package, you might want to ignore these files since the code is
179
+ # intended to run in multiple environments; otherwise, check them in:
180
+ # .python-version
181
+
182
+ # pipenv
183
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
184
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
185
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
186
+ # install all needed dependencies.
187
+ #Pipfile.lock
188
+
189
+ # poetry
190
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
191
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
192
+ # commonly ignored for libraries.
193
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
194
+ #poetry.lock
195
+
196
+ # pdm
197
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
198
+ #pdm.lock
199
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
200
+ # in version control.
201
+ # https://pdm.fming.dev/#use-with-ide
202
+ .pdm.toml
203
+
204
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
205
+ __pypackages__/
206
+
207
+ # Celery stuff
208
+ celerybeat-schedule
209
+ celerybeat.pid
210
+
211
+ # SageMath parsed files
212
+ *.sage.py
213
+
214
+ # Environments
215
+ .env
216
+ .venv
217
+ env/
218
+ venv/
219
+ ENV/
220
+ env.bak/
221
+ venv.bak/
222
+
223
+ # Spyder project settings
224
+ .spyderproject
225
+ .spyproject
226
+
227
+ # Rope project settings
228
+ .ropeproject
229
+
230
+ # mkdocs documentation
231
+ /site
232
+
233
+ # mypy
234
+ .mypy_cache/
235
+ .dmypy.json
236
+ dmypy.json
237
+
238
+ # Pyre type checker
239
+ .pyre/
240
+
241
+ # pytype static type analyzer
242
+ .pytype/
243
+
244
+ # Cython debug symbols
245
+ cython_debug/
246
+
247
+ # PyCharm
248
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
249
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
250
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
251
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
252
+ #.idea/
253
+
254
+ ### Python Patch ###
255
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
256
+ poetry.toml
257
+
258
+ # ruff
259
+ .ruff_cache/
260
+
261
+ # LSP config files
262
+ pyrightconfig.json
263
+
264
+ # End of https://www.toptal.com/developers/gitignore/api/python,jetbrains+all
service/README.md ADDED
@@ -0,0 +1 @@
1
+ # Crop Disease Prediction Backend
service/external.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import google.generativeai as genai
4
+
5
+ load_dotenv()
6
+
7
+ SYS_INSTR = "You are a plant disease expert. You will be given queries regarding plant diseases. Always respond in Markdown"
8
+ TXT_PROMPT = "Suggest remedy for the disease in bullet points"
9
+ IMG_TXT_PROMPT = "Based on the given image, suggest the possible disease the plant is suffering from, along with the remedy in 150 words"
10
+
11
+
12
+ def llm_strategy(llm_name, disease_name, image_file=None):
13
+ if llm_name.lower() == "gemini":
14
+ return get_response_from_gemini(disease_name, image_file)
15
+ else:
16
+ raise ValueError(f"LLM {llm_name} not supported")
17
+
18
+
19
+ def get_response_from_gemini(disease_name, image_file=None) -> str:
20
+ genai.configure(api_key=os.environ["GEMINI_API_KEY"])
21
+ model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=SYS_INSTR)
22
+
23
+ generation_config = genai.GenerationConfig(max_output_tokens=300)
24
+
25
+ prompt = [TXT_PROMPT, disease_name]
26
+ if image_file:
27
+ prompt = [IMG_TXT_PROMPT, image_file]
28
+ response = model.generate_content(prompt, generation_config=generation_config)
29
+ return response.text
service/predict.py ADDED
@@ -0,0 +1,71 @@
1
+ import PIL
2
+ import numpy as np
3
+ import torch
4
+ from acfg.modelconfig import ModelConfig
5
+ import torchvision.transforms.functional as F
6
+ from torch.nn import functional as Fx
7
+
8
+
9
+ from acfg.appconfig import CLF_MODEL, OOD_MODEL, ServiceConfig, get_device
10
+ from service.external import llm_strategy
11
+
12
+
13
+ def transform_for_prediction(img: PIL.Image):
14
+ """Transforms a PIL image for model prediction.
15
+
16
+ This function applies a series of transformations to prepare an image for model inference:
17
+ 1. Resizes the image to the model's expected input size
18
+ 2. Converts the image to a tensor
19
+ 3. Normalizes the tensor using preconfigured mean and std values
20
+
21
+ Args:
22
+ img (PIL.Image): Input image to transform
23
+
24
+ Returns:
25
+ torch.Tensor: Transformed image tensor ready for model inference
26
+ """
27
+ z = F.resize(img, [ModelConfig.IMG_SIZE, ModelConfig.IMG_SIZE])
29
+ z = F.to_tensor(z)
30
+ z = F.normalize(z, mean=ModelConfig.IMG_MEAN, std=ModelConfig.IMG_STD)
31
+ return z.to(get_device()[1])
32
+
33
+
34
+ def classify_disease(image):
35
+ image_tensor = transform_for_prediction(image).unsqueeze(0)
36
+
37
+ with torch.no_grad():
38
+ outputs = CLF_MODEL(image_tensor)
39
+ _, predicted = torch.max(outputs, 1)
40
+ prediction = predicted.item()
41
+
42
+ return ServiceConfig.ID2LABEL[prediction]
43
+
44
+
45
+ def img_in_distribution(image):
46
+ image_tensor = transform_for_prediction(image).unsqueeze(0)
47
+
48
+ with torch.no_grad():
49
+ output = OOD_MODEL(image_tensor)
50
+ mse_loss_value = Fx.mse_loss(output, image_tensor)
51
+ print("MSE", mse_loss_value)
52
+
53
+ return mse_loss_value < ServiceConfig.OOD_THRESHOLD
54
+
55
+
56
+ def workflow(image: PIL.Image.Image):  # Gradio passes a PIL image (type="pil" in app.py)
57
+ if not img_in_distribution(image):
58
+ disease_name = "Unknown"
59
+ remedy = "We do not know the remedy to this one. Sorry!"
60
+ else:
61
+ disease_name = classify_disease(image)
62
+ remedy = "No remedy needed. Plant is Healthy"
63
+ print(disease_name)
64
+
65
+ if "healthy" in disease_name:
66
+ return disease_name, remedy
67
+
68
+ else:
69
+ remedy = llm_strategy(ServiceConfig.LLM_MODEL_KEY, disease_name)
70
+
71
+ return disease_name, remedy
service/static/PlantDiseaseClassificationModel/best.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22331788f98e2379081c3734d2e7a2c820df3fbf2029240970e8695e70e10f9e
3
+ size 26490932
service/static/PlantDiseaseOODModel/best.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:995a05fa7e42e65f3f62f6fd9a777cce136a7930334a9f94ddb75a125e481f09
3
+ size 4172283
train_classifier.py ADDED
@@ -0,0 +1,47 @@
1
+ from lightning import Trainer, seed_everything
2
+ from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
3
+
4
+ from acfg.modelconfig import ModelConfig
5
+ from ml.app.data import ImageDataModule
6
+ from ml.app.lm import ClassificationModule
7
+ from ml.app.models.classification import DiseaseClassificationModel
8
+
9
+
10
+ ckpt_callback = ModelCheckpoint(
11
+ filename="classification" + "_{epoch:02d}_{VA:.2f}",
12
+ save_top_k=1,
13
+ mode="min",
14
+ monitor=ModelConfig.VAL_LOSS,
15
+ )
16
+
17
+ tqdm_callback = TQDMProgressBar(refresh_rate=10)
18
+
19
+
20
+ model = DiseaseClassificationModel(ModelConfig.PRETRAINED_MODEL_NAME)
21
+
22
+ datamodule = ImageDataModule(
23
+ train_path=ModelConfig.TRAIN_DATA_PATH,
24
+ val_path=ModelConfig.VAL_DATA_PATH,
25
+ test_path=ModelConfig.TEST_DATA_PATH,
26
+ batch_size=ModelConfig.BATCH_SIZE,
27
+ img_size=ModelConfig.IMG_SIZE,
28
+ )
29
+
30
+ l_module = ClassificationModule(
31
+ model=model,
32
+ num_classes=ModelConfig.NUM_OUTPUT_CLASSES,
33
+ )
34
+
35
+ seed_everything(42)
36
+ trainer = Trainer(
37
+ max_epochs=25,
38
+ callbacks=[ckpt_callback, tqdm_callback],
39
+ num_sanity_val_steps=2,
40
+ )
41
+
42
+
43
+ if __name__ == "__main__":
44
+ trainer.fit(
45
+ model=l_module,
46
+ datamodule=datamodule,
47
+ )
train_ood.py ADDED
@@ -0,0 +1,39 @@
1
+ from lightning import Trainer, seed_everything
2
+ from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
3
+
4
+ from acfg.modelconfig import ModelConfig
5
+ from ml.app.anomaly import DiseaseOODModule
6
+ from ml.app.data import ImageDataModule
7
+
8
+
9
+
10
+ ckpt_callback = ModelCheckpoint(
11
+ filename="ood" + "_{epoch:02d}_{VL:.2f}",
12
+ save_top_k=1,
13
+ mode="min",
14
+ monitor=ModelConfig.VAL_LOSS,
15
+ )
16
+
17
+ tqdm_callback = TQDMProgressBar(refresh_rate=10)
18
+
19
+ datamodule = ImageDataModule(
20
+ train_path=ModelConfig.TRAIN_DATA_PATH,
21
+ val_path=ModelConfig.VAL_DATA_PATH,
22
+ test_path=ModelConfig.TEST_DATA_PATH,
23
+ batch_size=ModelConfig.BATCH_SIZE,
24
+ img_size=ModelConfig.IMG_SIZE,
25
+ )
26
+
27
+ l_module = DiseaseOODModule()
28
+
29
+ seed_everything(42)
30
+
31
+ trainer = Trainer(
32
+ max_epochs=100,
33
+ callbacks=[ckpt_callback, tqdm_callback],
34
+ num_sanity_val_steps=2,
35
+ )
36
+
37
+
38
+ if __name__ == "__main__":
39
+ trainer.fit(model=l_module, datamodule=datamodule)
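The OOD decision in `service/predict.py` compares the autoencoder's reconstruction MSE against `ServiceConfig.OOD_THRESHOLD = 0.034`. The calibration step is not part of this commit; a plausible sketch, assuming the shipped OOD checkpoint (or one produced by `train_ood.py`) and the validation data defined in `ml/app/data.py`, is to take a high percentile of reconstruction errors on in-distribution validation images:

```python
import torch
from torch.nn import functional as F

from acfg.modelconfig import ModelConfig
from ml.app.anomaly import DiseaseOODModule
from ml.app.data import ImageDataModule

# Sketch: pick an OOD threshold from reconstruction errors on in-distribution validation images.
module = DiseaseOODModule.load_from_checkpoint(ModelConfig.OOD_MODEL_CHECKPOINT)
module.eval()

datamodule = ImageDataModule(
    train_path=ModelConfig.TRAIN_DATA_PATH,
    val_path=ModelConfig.VAL_DATA_PATH,
    test_path=ModelConfig.TEST_DATA_PATH,
    batch_size=ModelConfig.BATCH_SIZE,
    img_size=ModelConfig.IMG_SIZE,
)
datamodule.setup("fit")

errors = []
with torch.no_grad():
    for images, _ in datamodule.val_dataloader():
        images = images.to(module.device)
        recon = module(images)
        # batch-level MSE, mirroring img_in_distribution() in service/predict.py
        errors.append(F.mse_loss(recon, images).item())

# e.g. treat the 95th percentile of in-distribution errors as the cut-off
threshold = torch.quantile(torch.tensor(errors), 0.95).item()
print("suggested OOD_THRESHOLD:", threshold)
```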