harshinde commited on
Commit
11f9299
·
verified ·
1 Parent(s): 3e79cf5

Upload src/model_downloader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/model_downloader.py +28 -9
src/model_downloader.py CHANGED
@@ -11,22 +11,22 @@ class ModelDownloader:
11
  self.models_dir = Path("/app/models")
12
  self.models_dir.mkdir(exist_ok=True)
13
 
14
- # Kaggle model repository details
15
- self.kaggle_model_url = "https://www.kaggle.com/models/harshshinde8/sims/frameworks/PyTorch/serve"
16
 
17
- # Model mapping with Kaggle model IDs
18
  self.model_files = {
19
  "deeplabv3plus": {
20
  "file": "deeplabv3.pth",
21
- "id": "deeplabv3"
22
  },
23
  "densenet121": {
24
  "file": "densenet121.pth",
25
- "id": "densenet121"
26
  },
27
  "efficientnetb0": {
28
- "file": "effucientnetb0.pth",
29
- "id": "effucientnetb0"
30
  },
31
  "inceptionresnetv2": {
32
  "file": "inceptionresnetv2.pth",
@@ -70,9 +70,9 @@ class ModelDownloader:
70
  }
71
  }
72
 
73
- def download_from_kaggle(self, model_name):
74
  """
75
- Download model from Kaggle Models repository
76
  Args:
77
  model_name (str): Name of the model to download
78
  Returns:
@@ -83,6 +83,25 @@ class ModelDownloader:
83
 
84
  model_info = self.model_files[model_name]
85
  model_path = self.models_dir / model_info['file']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  # If model already exists, return path
88
  if model_path.exists():
 
11
  self.models_dir = Path("/app/models")
12
  self.models_dir.mkdir(exist_ok=True)
13
 
14
+ # HuggingFace model repository details
15
+ self.hf_model_url = "https://huggingface.co/harshinde/Sims/resolve/main/models/"
16
 
17
+ # Model mapping with file names
18
  self.model_files = {
19
  "deeplabv3plus": {
20
  "file": "deeplabv3.pth",
21
+ "url": f"{self.hf_model_url}deeplabv3.pth"
22
  },
23
  "densenet121": {
24
  "file": "densenet121.pth",
25
+ "url": f"{self.hf_model_url}densenet121.pth"
26
  },
27
  "efficientnetb0": {
28
+ "file": "efficientnetb0.pth",
29
+ "url": f"{self.hf_model_url}efficientnetb0.pth"
30
  },
31
  "inceptionresnetv2": {
32
  "file": "inceptionresnetv2.pth",
 
70
  }
71
  }
72
 
73
+ def download_model(self, model_name):
74
  """
75
+ Download model from Hugging Face Models repository
76
  Args:
77
  model_name (str): Name of the model to download
78
  Returns:
 
83
 
84
  model_info = self.model_files[model_name]
85
  model_path = self.models_dir / model_info['file']
86
+
87
+ if not model_path.exists():
88
+ print(f"Downloading {model_name} model...")
89
+ response = requests.get(model_info['url'], stream=True)
90
+ response.raise_for_status()
91
+
92
+ total_size = int(response.headers.get('content-length', 0))
93
+ with open(model_path, 'wb') as f, tqdm(
94
+ total=total_size,
95
+ unit='iB',
96
+ unit_scale=True,
97
+ unit_divisor=1024,
98
+ ) as pbar:
99
+ for data in response.iter_content(chunk_size=1024):
100
+ size = f.write(data)
101
+ pbar.update(size)
102
+ print(f"Model downloaded successfully to {model_path}")
103
+
104
+ return str(model_path)
105
 
106
  # If model already exists, return path
107
  if model_path.exists():