Spaces:
Runtime error
Runtime error
Soutrik
committed on
Commit
·
8931c9e
1
Parent(s):
beb5662
infer script
Browse files- artifacts/image_prediction.png +0 -0
- configs/infer.yaml +4 -4
- configs/test.yaml +0 -34
- configs/train.yaml +1 -1
- image.jpg +0 -0
- src/infer.py +128 -0
artifacts/image_prediction.png
ADDED
configs/infer.yaml
CHANGED
@@ -4,8 +4,8 @@
|
|
4 |
# order of defaults determines the order in which configs override each other
|
5 |
defaults:
|
6 |
- _self_
|
7 |
-
- data:
|
8 |
-
- model:
|
9 |
- callbacks: default
|
10 |
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
|
11 |
- trainer: default
|
@@ -13,7 +13,7 @@ defaults:
|
|
13 |
- hydra: default
|
14 |
# experiment configs allow for version control of specific hyperparameters
|
15 |
# e.g. best hyperparameters for given model and datamodule
|
16 |
-
- experiment:
|
17 |
# debugging config (enable through command line, e.g. `python train.py debug=default`)
|
18 |
- debug: null
|
19 |
|
@@ -39,4 +39,4 @@ ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt
|
|
39 |
seed: 42
|
40 |
|
41 |
# name of the experiment
|
42 |
-
name: "
|
|
|
4 |
# order of defaults determines the order in which configs override each other
|
5 |
defaults:
|
6 |
- _self_
|
7 |
+
- data: catdog
|
8 |
+
- model: catdog_model
|
9 |
- callbacks: default
|
10 |
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
|
11 |
- trainer: default
|
|
|
13 |
- hydra: default
|
14 |
# experiment configs allow for version control of specific hyperparameters
|
15 |
# e.g. best hyperparameters for given model and datamodule
|
16 |
+
- experiment: catdog_experiment
|
17 |
# debugging config (enable through command line, e.g. `python train.py debug=default`)
|
18 |
- debug: null
|
19 |
|
|
|
39 |
seed: 42
|
40 |
|
41 |
# name of the experiment
|
42 |
+
name: "catdog_experiment"
|
configs/test.yaml
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
defaults:
|
2 |
-
- _self_
|
3 |
-
- data: dogbreed
|
4 |
-
- model: dogbreed_classifier
|
5 |
-
- callbacks: default
|
6 |
-
- logger: null
|
7 |
-
- trainer: default
|
8 |
-
- paths: default # This should map to another config file if using hydra to merge
|
9 |
-
|
10 |
-
task_name: train
|
11 |
-
tags:
|
12 |
-
- dev
|
13 |
-
train: true
|
14 |
-
test: true
|
15 |
-
ckpt_path: null
|
16 |
-
seed: 42
|
17 |
-
|
18 |
-
# Ensure paths section is present
|
19 |
-
paths:
|
20 |
-
root_dir: ./ # Project root directory
|
21 |
-
data_dir: ./data # Path to your dataset
|
22 |
-
log_dir: ./logs # Path to logs directory
|
23 |
-
ckpt_dir: ./checkpoints # Path to checkpoints
|
24 |
-
artifact_dir: ./artifacts # Path to save artifacts
|
25 |
-
kaggle_dir: khushikhushikhushi/dog-breed-image-dataset # Path for Kaggle dataset
|
26 |
-
|
27 |
-
# Ensure data section is present
|
28 |
-
data:
|
29 |
-
num_workers: 4
|
30 |
-
batch_size: 32
|
31 |
-
image_size: 224
|
32 |
-
train_split: 0.8
|
33 |
-
val_split: 0.1
|
34 |
-
test_split: 0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/train.yaml
CHANGED
@@ -38,7 +38,7 @@ ckpt_path: ${paths.ckpt_dir}/best-checkpoint.ckpt
|
|
38 |
seed: 42
|
39 |
|
40 |
# name of the experiment
|
41 |
-
name: "
|
42 |
|
43 |
# optimization metric
|
44 |
optimization_metric: "val_acc"
|
|
|
38 |
seed: 42
|
39 |
|
40 |
# name of the experiment
|
41 |
+
name: "catdog_experiment"
|
42 |
|
43 |
# optimization metric
|
44 |
optimization_metric: "val_acc"
|
image.jpg
ADDED
src/infer.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import requests
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision import transforms
|
8 |
+
from src.models.catdog_model import ViTTinyClassifier
|
9 |
+
from src.utils.logging_utils import setup_logger, task_wrapper, get_rich_progress
|
10 |
+
import hydra
|
11 |
+
from omegaconf import DictConfig, OmegaConf
|
12 |
+
from dotenv import load_dotenv, find_dotenv
|
13 |
+
import rootutils
|
14 |
+
import time
|
15 |
+
from loguru import logger
|
16 |
+
|
17 |
+
# Load environment variables
|
18 |
+
load_dotenv(find_dotenv(".env"))
|
19 |
+
|
20 |
+
# Setup root directory
|
21 |
+
root = rootutils.setup_root(__file__, indicator=".project-root")
|
22 |
+
|
23 |
+
|
24 |
+
@task_wrapper
def load_image(image_path: str, image_size: int):
    """Load an image from disk and turn it into a single-item model batch.

    Args:
        image_path: Location of the image file; anything ``PIL.Image.open``
            accepts (the caller in this file passes a ``pathlib.Path``).
        image_size: Target edge length in pixels. The image is resized to
            ``image_size x image_size``, so the aspect ratio is not preserved.

    Returns:
        A ``(img, tensor)`` pair: the RGB ``PIL.Image`` at its original
        resolution, and the resized, normalized tensor with an extra leading
        dimension added by ``unsqueeze(0)``.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically; ``convert`` returns a separate, fully loaded image
    # that stays valid after the source image is closed.
    with Image.open(image_path) as raw_image:
        img = raw_image.convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Commonly used ImageNet channel statistics.
            # NOTE(review): presumably matches the training-time
            # preprocessing -- confirm against the datamodule.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return img, transform(img).unsqueeze(0)
|
36 |
+
|
37 |
+
|
38 |
+
@task_wrapper
def infer(model: torch.nn.Module, image_tensor: torch.Tensor, classes: list):
    """Classify one image tensor and report the most likely label.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        image_tensor: Input passed straight to ``model``. The ``.item()``
            call on the argmax requires the batch to hold exactly one sample.
        classes: Label names indexed by the model's output position.

    Returns:
        ``(predicted_label, confidence)`` where ``confidence`` is the softmax
        probability of the predicted class as a Python float.
    """
    model.eval()

    # Gradients are not needed for a forward-only pass.
    with torch.no_grad():
        logits = model(image_tensor)
        probs = F.softmax(logits, dim=1)
        top_index = probs.argmax(dim=1).item()

    label = classes[top_index]
    score = probs[0][top_index].item()
    return label, score
|
50 |
+
|
51 |
+
|
52 |
+
@task_wrapper
def save_prediction_image(
    image: Image.Image, predicted_label: str, confidence: float, output_path: Path
):
    """Render the image with its prediction as the title and save it to disk.

    Args:
        image: Picture to display.
        predicted_label: Class name shown in the title.
        confidence: Probability shown in the title, formatted to two decimals.
        output_path: Destination file; missing parent directories are created.
    """
    # Create the destination first so a filesystem error cannot leave an
    # open figure behind.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Use an explicit figure/axes pair instead of pyplot's implicit "current
    # figure", and always close it -- otherwise a failure while drawing or
    # saving would leak the figure for the rest of the process.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.imshow(image)
        ax.axis("off")
        ax.set_title(f"Predicted: {predicted_label} (Confidence: {confidence:.2f})")
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)
|
65 |
+
|
66 |
+
|
67 |
+
@task_wrapper
def download_image(cfg: DictConfig, timeout: float = 30.0):
    """Download a sample dog image and store it as ``image.jpg`` in the project root.

    Args:
        cfg: Hydra config; only ``cfg.paths.root_dir`` is read.
        timeout: Seconds to wait for the server before ``requests`` gives up.
            Without a timeout the request could block indefinitely.

    A non-200 response is logged as an error and nothing is written.
    Exceptions raised by ``requests.get`` (connection errors, timeouts) are
    not caught here.
    """
    url = "https://github.com/laxmimerit/dog-cat-full-dataset/raw/master/data/train/dogs/dog.1.jpg"
    headers = {
        # Browser-like User-Agent.
        # NOTE(review): presumably needed so the host does not reject the
        # default ``requests`` client -- confirm.
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36",
    }
    response = requests.get(
        url, headers=headers, allow_redirects=True, timeout=timeout
    )

    if response.status_code != 200:
        logger.error(f"Failed to download image. Status code: {response.status_code}")
        return

    image_path = Path(cfg.paths.root_dir) / "image.jpg"
    # The file is flushed and closed when the ``with`` block exits, so no
    # artificial delay is needed before it is read back.
    with open(image_path, "wb") as file:
        file.write(response.content)
    logger.info(f"Image downloaded successfully as {image_path}!")
|
83 |
+
|
84 |
+
|
85 |
+
@hydra.main(config_path="../configs", config_name="infer", version_base="1.1")
def main_infer(cfg: DictConfig):
    """Run inference on every image found directly in the project root.

    Steps: log the resolved config, set up the ``infer.log`` file logger,
    remove a leftover ``train_done.flag``, load the checkpointed model,
    download a sample image, then classify each image and save an annotated
    copy under ``cfg.paths.artifact_dir``.

    Args:
        cfg: Hydra config composed from ``configs/infer.yaml``. Reads
            ``paths.log_dir``, ``paths.ckpt_dir``, ``paths.root_dir``,
            ``paths.artifact_dir``, ``ckpt_path`` and ``data.image_size``.
    """
    # Print the configuration
    logger.info(OmegaConf.to_yaml(cfg))
    setup_logger(Path(cfg.paths.log_dir) / "infer.log")

    # Remove the train_done flag if it exists
    flag_file = Path(cfg.paths.ckpt_dir) / "train_done.flag"
    if flag_file.exists():
        flag_file.unlink()

    # Load the trained model
    model = ViTTinyClassifier.load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
    # NOTE(review): the label order must match the class indices used during
    # training. Folder-based datasets usually sort class names alphabetically
    # (which would give ["cat", "dog"]) -- confirm against the datamodule.
    classes = ["dog", "cat"]

    # Download an image for inference
    download_image(cfg)

    # Collect regular image files only. Compare suffixes case-insensitively so
    # files such as ``photo.JPG`` are not skipped, and sort the result so the
    # processing order is reproducible (``iterdir`` order is arbitrary).
    image_suffixes = {".jpg", ".jpeg", ".png"}
    image_files = sorted(
        candidate
        for candidate in Path(cfg.paths.root_dir).iterdir()
        if candidate.is_file() and candidate.suffix.lower() in image_suffixes
    )

    if not image_files:
        logger.warning(f"No images found in {cfg.paths.root_dir}; nothing to infer.")
        return

    with get_rich_progress() as progress:
        task = progress.add_task("[green]Processing images...", total=len(image_files))

        for image_file in image_files:
            img, img_tensor = load_image(image_file, cfg.data.image_size)
            # Move the input to the device the model lives on before the forward pass.
            predicted_label, confidence = infer(
                model, img_tensor.to(model.device), classes
            )
            output_file = (
                Path(cfg.paths.artifact_dir) / f"{image_file.stem}_prediction.png"
            )
            save_prediction_image(img, predicted_label, confidence, output_file)
            progress.advance(task)

            logger.info(f"Processed {image_file}: {predicted_label} ({confidence:.2f})")
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
main_infer()
|