soutrik commited on
Commit
ea271d0
·
1 Parent(s): 4828471

added: gradio app file and tested on local

Browse files
Files changed (3) hide show
  1. .gradio/certificate.pem +31 -0
  2. app.py +115 -0
  3. src/utils/aws_s3_services.py +14 -4
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ from torchvision import transforms
7
+ from src.models.catdog_model_resnet import ResnetClassifier
8
+ from src.utils.aws_s3_services import S3Handler
9
+ from src.utils.logging_utils import setup_logger
10
+ from loguru import logger
11
+ import rootutils
12
+
13
+ # Load environment variables and configure logger
14
+ setup_logger(Path("./logs") / "gradio_app.log")
15
+ # Setup root directory
16
+ root = rootutils.setup_root(__file__, indicator=".project-root")
17
+
18
+
19
+ class ImageClassifier:
20
+ def __init__(self, cfg):
21
+ self.cfg = cfg
22
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ self.classes = cfg.labels
24
+
25
+ # Download and load model from S3
26
+ logger.info("Downloading model from S3...")
27
+ s3_handler = S3Handler(bucket_name="deep-bucket-s3")
28
+ s3_handler.download_folder("checkpoints", "checkpoints")
29
+
30
+ logger.info("Loading model checkpoint...")
31
+ self.model = ResnetClassifier.load_from_checkpoint(
32
+ checkpoint_path=cfg.ckpt_path
33
+ )
34
+ self.model = self.model.to(self.device)
35
+ self.model.eval()
36
+
37
+ # Image transform
38
+ self.transform = transforms.Compose(
39
+ [
40
+ transforms.Resize((cfg.data.image_size, cfg.data.image_size)),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize(
43
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
44
+ ),
45
+ ]
46
+ )
47
+
48
+ def predict(self, image):
49
+ if image is None:
50
+ return "No image provided.", None
51
+
52
+ # Preprocess the image
53
+ logger.info("Processing input image...")
54
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
55
+
56
+ # Inference
57
+ with torch.no_grad():
58
+ output = self.model(img_tensor)
59
+ probabilities = F.softmax(output, dim=1)
60
+ predicted_class_idx = torch.argmax(probabilities, dim=1).item()
61
+ confidence = probabilities[0][predicted_class_idx].item()
62
+
63
+ predicted_label = self.classes[predicted_class_idx]
64
+ logger.info(f"Prediction: {predicted_label} (Confidence: {confidence:.2f})")
65
+ return predicted_label, confidence
66
+
67
+
68
+ def create_gradio_app(cfg):
69
+ classifier = ImageClassifier(cfg)
70
+
71
+ def classify_image(image):
72
+ """Gradio interface function."""
73
+ predicted_label, confidence = classifier.predict(image)
74
+ if predicted_label:
75
+ return f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
76
+ return "Error during prediction."
77
+
78
+ # Create Gradio interface
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown(
81
+ """
82
+ # Cat vs Dog Classifier
83
+ Upload an image of a cat or a dog to classify it with confidence.
84
+ """
85
+ )
86
+
87
+ with gr.Row():
88
+ with gr.Column():
89
+ input_image = gr.Image(
90
+ label="Input Image", type="pil", image_mode="RGB"
91
+ )
92
+ predict_button = gr.Button("Classify")
93
+ with gr.Column():
94
+ output_text = gr.Textbox(label="Prediction")
95
+
96
+ # Define interaction
97
+ predict_button.click(
98
+ fn=classify_image, inputs=[input_image], outputs=[output_text]
99
+ )
100
+
101
+ return demo
102
+
103
+
104
+ # Hydra config wrapper for launching Gradio app
105
+ if __name__ == "__main__":
106
+ import hydra
107
+ from omegaconf import DictConfig
108
+
109
+ @hydra.main(config_path="configs", config_name="infer", version_base="1.3")
110
+ def main(cfg: DictConfig):
111
+ logger.info("Launching Gradio App...")
112
+ demo = create_gradio_app(cfg)
113
+ demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
114
+
115
+ main()
src/utils/aws_s3_services.py CHANGED
@@ -51,14 +51,24 @@ class S3Handler:
51
  s3_folder (str): Source folder in S3.
52
  dest_folder (str): Local destination folder path.
53
  """
54
- dest_folder = Path(dest_folder)
55
  paginator = self.s3.get_paginator("list_objects_v2")
56
 
57
  for page in paginator.paginate(Bucket=self.bucket_name, Prefix=s3_folder):
58
  for obj in page.get("Contents", []):
59
  s3_path = obj["Key"]
60
- local_path = dest_folder / Path(s3_path).relative_to(s3_folder)
 
 
 
 
 
 
 
 
61
  local_path.parent.mkdir(parents=True, exist_ok=True)
 
 
62
  self.s3.download_file(self.bucket_name, s3_path, str(local_path))
63
  print(f"Downloaded: {s3_path} to {local_path}")
64
 
@@ -71,8 +81,8 @@ if __name__ == "__main__":
71
  # Upload specific files
72
  s3_handler.upload_folder(
73
  "checkpoints",
74
- "checkpoints_test",
75
  )
76
 
77
  # Download example
78
- s3_handler.download_folder("checkpoints_test", "checkpoints")
 
51
  s3_folder (str): Source folder in S3.
52
  dest_folder (str): Local destination folder path.
53
  """
54
+ dest_folder = Path(dest_folder).resolve()
55
  paginator = self.s3.get_paginator("list_objects_v2")
56
 
57
  for page in paginator.paginate(Bucket=self.bucket_name, Prefix=s3_folder):
58
  for obj in page.get("Contents", []):
59
  s3_path = obj["Key"]
60
+ # Skip folder itself if returned by S3
61
+ if s3_path.endswith("/"):
62
+ continue
63
+
64
+ # Compute relative path and local destination
65
+ relative_path = Path(s3_path[len(s3_folder) :].lstrip("/"))
66
+ local_path = dest_folder / relative_path
67
+
68
+ # Create necessary local directories
69
  local_path.parent.mkdir(parents=True, exist_ok=True)
70
+
71
+ # Download file
72
  self.s3.download_file(self.bucket_name, s3_path, str(local_path))
73
  print(f"Downloaded: {s3_path} to {local_path}")
74
 
 
81
  # Upload specific files
82
  s3_handler.upload_folder(
83
  "checkpoints",
84
+ "checkpoints",
85
  )
86
 
87
  # Download example
88
+ s3_handler.download_folder("checkpoints", "checkpoints")