Soutrik committed
Commit 8931c9e · Parent: beb5662

infer script
artifacts/image_prediction.png ADDED
configs/infer.yaml CHANGED
@@ -4,8 +4,8 @@
 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - data: dogbreed
-  - model: dogbreed_classifier
+  - data: catdog
+  - model: catdog_model
   - callbacks: default
   - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
   - trainer: default
@@ -13,7 +13,7 @@ defaults:
   - hydra: default
 # experiment configs allow for version control of specific hyperparameters
 # e.g. best hyperparameters for given model and datamodule
-  - experiment: dogbreed_experiment
+  - experiment: catdog_experiment
 # debugging config (enable through command line, e.g. `python train.py debug=default)
   - debug: null
 
@@ -39,4 +39,4 @@ ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt
 seed: 42
 
 # name of the experiment
-name: "dogbreed_experiment"
+name: "catdog_experiment"
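With these defaults the inference entry point picks up the new catdog config groups automatically; any key can still be overridden from the command line in the usual Hydra style. A minimal usage sketch (the alternate checkpoint path in the second command is only a placeholder, not part of this commit):

python src/infer.py
python src/infer.py ckpt_path=/path/to/another-checkpoint.ckpt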
configs/test.yaml DELETED
@@ -1,34 +0,0 @@
-defaults:
-  - _self_
-  - data: dogbreed
-  - model: dogbreed_classifier
-  - callbacks: default
-  - logger: null
-  - trainer: default
-  - paths: default # This should map to another config file if using hydra to merge
-
-task_name: train
-tags:
-  - dev
-train: true
-test: true
-ckpt_path: null
-seed: 42
-
-# Ensure paths section is present
-paths:
-  root_dir: ./ # Project root directory
-  data_dir: ./data # Path to your dataset
-  log_dir: ./logs # Path to logs directory
-  ckpt_dir: ./checkpoints # Path to checkpoints
-  artifact_dir: ./artifacts # Path to save artifacts
-  kaggle_dir: khushikhushikhushi/dog-breed-image-dataset # Path for Kaggle dataset
-
-# Ensure data section is present
-data:
-  num_workers: 4
-  batch_size: 32
-  image_size: 224
-  train_split: 0.8
-  val_split: 0.1
-  test_split: 0.1
configs/train.yaml CHANGED
@@ -38,7 +38,7 @@ ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt
 seed: 42
 
 # name of the experiment
-name: "dogbreed_experiment"
+name: "catdog_experiment"
 
 # optimization metric
 optimization_metric: "val_acc"
image.jpg ADDED
src/infer.py ADDED
@@ -0,0 +1,128 @@
+from pathlib import Path
+import requests
+import torch
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+from PIL import Image
+from torchvision import transforms
+from src.models.catdog_model import ViTTinyClassifier
+from src.utils.logging_utils import setup_logger, task_wrapper, get_rich_progress
+import hydra
+from omegaconf import DictConfig, OmegaConf
+from dotenv import load_dotenv, find_dotenv
+import rootutils
+import time
+from loguru import logger
+
+# Load environment variables
+load_dotenv(find_dotenv(".env"))
+
+# Setup root directory
+root = rootutils.setup_root(__file__, indicator=".project-root")
+
+
+@task_wrapper
+def load_image(image_path: str, image_size: int):
+    """Load and preprocess an image."""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.Compose(
+        [
+            transforms.Resize((image_size, image_size)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
+    return img, transform(img).unsqueeze(0)
+
+
+@task_wrapper
+def infer(model: torch.nn.Module, image_tensor: torch.Tensor, classes: list):
+    """Perform inference on the provided image tensor."""
+    model.eval()
+    with torch.no_grad():
+        output = model(image_tensor)
+        probabilities = F.softmax(output, dim=1)
+        predicted_class = torch.argmax(probabilities, dim=1).item()
+
+    predicted_label = classes[predicted_class]
+    confidence = probabilities[0][predicted_class].item()
+    return predicted_label, confidence
+
+
+@task_wrapper
+def save_prediction_image(
+    image: Image.Image, predicted_label: str, confidence: float, output_path: Path
+):
+    """Save the image with the prediction overlay."""
+    plt.figure(figsize=(10, 6))
+    plt.imshow(image)
+    plt.axis("off")
+    plt.title(f"Predicted: {predicted_label} (Confidence: {confidence:.2f})")
+    plt.tight_layout()
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    plt.savefig(output_path, dpi=300, bbox_inches="tight")
+    plt.close()
+
+
+@task_wrapper
+def download_image(cfg: DictConfig):
+    """Download an image from the web for inference."""
+    url = "https://github.com/laxmimerit/dog-cat-full-dataset/raw/master/data/train/dogs/dog.1.jpg"
+    headers = {
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36",
+    }
+    response = requests.get(url, headers=headers, allow_redirects=True)
+    if response.status_code == 200:
+        image_path = Path(cfg.paths.root_dir) / "image.jpg"
+        with open(image_path, "wb") as file:
+            file.write(response.content)
+        time.sleep(5)
+        print(f"Image downloaded successfully as {image_path}!")
+    else:
+        logger.error(f"Failed to download image. Status code: {response.status_code}")
+
+
+@hydra.main(config_path="../configs", config_name="infer", version_base="1.1")
+def main_infer(cfg: DictConfig):
+    # Print the configuration
+    logger.info(OmegaConf.to_yaml(cfg))
+    setup_logger(Path(cfg.paths.log_dir) / "infer.log")
+
+    # Remove the train_done flag if it exists
+    flag_file = Path(cfg.paths.ckpt_dir) / "train_done.flag"
+    if flag_file.exists():
+        flag_file.unlink()
+
+    # Load the trained model
+    model = ViTTinyClassifier.load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
+    classes = ["dog", "cat"]
+
+    # Download an image for inference
+    download_image(cfg)
+
+    # Load images from directory and perform inference
+    image_files = [
+        f
+        for f in Path(cfg.paths.root_dir).iterdir()
+        if f.suffix in {".jpg", ".jpeg", ".png"}
+    ]
+
+    with get_rich_progress() as progress:
+        task = progress.add_task("[green]Processing images...", total=len(image_files))
+
+        for image_file in image_files:
+            img, img_tensor = load_image(image_file, cfg.data.image_size)
+            predicted_label, confidence = infer(
+                model, img_tensor.to(model.device), classes
+            )
+            output_file = (
+                Path(cfg.paths.artifact_dir) / f"{image_file.stem}_prediction.png"
+            )
+            save_prediction_image(img, predicted_label, confidence, output_file)
+            progress.advance(task)
+
+            logger.info(f"Processed {image_file}: {predicted_label} ({confidence:.2f})")
+
+
+if __name__ == "__main__":
+    main_infer()
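For reference, a minimal standalone sketch of how the new helpers could be reused outside the Hydra entry point. It assumes task_wrapper passes return values through unchanged, that a trained checkpoint exists at checkpoints/best-checkpoint.ckpt, and that the 224-pixel input size from the removed test.yaml still applies; none of these are guaranteed by this commit.

from src.infer import load_image, infer
from src.models.catdog_model import ViTTinyClassifier

# Load the trained Lightning checkpoint (path is an assumption, not part of this commit)
model = ViTTinyClassifier.load_from_checkpoint("checkpoints/best-checkpoint.ckpt")
model.eval()

# Preprocess one image and run a single forward pass
img, img_tensor = load_image("image.jpg", image_size=224)
label, confidence = infer(model, img_tensor.to(model.device), ["dog", "cat"])
print(f"{label}: {confidence:.2f}")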