aknapitsch committed
Commit f7d060f · 1 Parent(s): 37de32d

load from checkpoint_path

Files changed (1): app.py +75 -6
app.py CHANGED
@@ -30,7 +30,6 @@ from hf_utils.css_and_html import (
     get_header_html,
 )
 from hf_utils.visual_util import predictions_to_glb
-from mapanything.models import MapAnything
 from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
 from mapanything.utils.image import load_images, rgb
 
@@ -53,7 +52,7 @@ print("Initializing and loading MapAnything model...")
 def load_hf_token():
     """Load HuggingFace access token from local file"""
     token_file_paths = [
-        "~/hf_token.txt",
+        "/home/aknapitsch/hf_token.txt",
     ]
 
     for token_path in token_file_paths:
@@ -66,6 +65,8 @@ def load_hf_token():
         except Exception as e:
             print(f"Error reading token from {token_path}: {e}")
             continue
+        else:
+            print(token_path, "token_path doesnt exist")
 
     # Also try environment variable
     # see https://huggingface.co/docs/hub/spaces-overview#managing-secrets on options
@@ -103,13 +104,14 @@ def init_hydra_config(config_path, overrides=None):
 # MapAnything Configuration
 high_level_config = {
     "path": "configs/train.yaml",
-    "hf_model_name": "facebook/map-anything",
+    "hf_model_name": "facebook/map-anything-apache",
     "config_overrides": [
         "machine=aws",
         "model=mapanything",
         "model/task=images_only",
         "model.encoder.uses_torch_hub=false",
     ],
+    "checkpoint_path": "https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth",
     "trained_with_amp": True,
     "trained_with_amp_dtype": "fp16",
     "data_norm_type": "dinov2",
@@ -130,6 +132,8 @@ def run_model(target_dir, model_placeholder, apply_mask=True, mask_edges=True):
     Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
     """
     global model
+    import torch  # Ensure torch is available in function scope
+
     print(f"Processing images from {target_dir}")
 
     # Device check
@@ -140,11 +144,76 @@ def run_model(target_dir, model_placeholder, apply_mask=True, mask_edges=True):
     if model is None:
         print("Initializing MapAnything model...")
 
-        print("Loading CC-BY-NC 4.0 licensed MapAnything model...")
-        model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
-            device
+        # Initialize Hydra config and create model from configuration
+        cfg = init_hydra_config(
+            high_level_config["path"], overrides=high_level_config["config_overrides"]
         )
 
+        print("Loading MapAnything model...")
+        # Create model from local configuration instead of using from_pretrained
+        from mapanything.models import init_model
+
+        model = init_model(
+            model_str=cfg.model.model_str,
+            model_config=cfg.model.model_config,
+            torch_hub_force_reload=high_level_config.get(
+                "torch_hub_force_reload", False
+            ),
+        )
+
+        # Load the pretrained weights from HuggingFace Hub
+        try:
+            from huggingface_hub import hf_hub_download, list_repo_files
+
+            # First, let's see what files are available in the repository
+            try:
+                repo_files = list_repo_files(
+                    repo_id=high_level_config["hf_model_name"], token=load_hf_token()
+                )
+                print(f"Available files in repository: {repo_files}")
+
+                checkpoint_filename = "model.safetensors"
+
+                # Download the model weights
+                checkpoint_path = hf_hub_download(
+                    repo_id=high_level_config["hf_model_name"],
+                    filename=checkpoint_filename,
+                    token=load_hf_token(),
+                )
+
+                # Load the weights
+                print("start loading checkpoint")
+                if checkpoint_filename.endswith(".safetensors"):
+                    from safetensors.torch import load_file
+
+                    checkpoint = load_file(checkpoint_path)
+                else:
+                    checkpoint = torch.load(
+                        checkpoint_path, map_location="cpu", weights_only=True
+                    )
+
+                print("start loading state_dict")
+                if "model" in checkpoint:
+                    model.load_state_dict(checkpoint["model"])
+                elif "state_dict" in checkpoint:
+                    model.load_state_dict(checkpoint["state_dict"])
+                else:
+                    model.load_state_dict(checkpoint)
+
+                print(
+                    f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})"
+                )
+
+            except Exception as inner_e:
+                print(f"Error listing repository files or loading weights: {inner_e}")
+                raise inner_e
+
+        except Exception as e:
+            print(f"Warning: Could not load pretrained weights: {e}")
+            print("Proceeding with randomly initialized model...")
+
+        model = model.to(device)
+
     else:
         model = model.to(device)
 
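
Note that the new "checkpoint_path" entry in high_level_config is not referenced by the hunks above; the committed code instead downloads model.safetensors from the hf_model_name repo via hf_hub_download. As a minimal sketch of how that URL could be consumed directly, assuming the .pth file it points to is a standard PyTorch checkpoint and with load_checkpoint_from_url as a hypothetical helper that is not part of this repo:

import torch

def load_checkpoint_from_url(model, url):
    # Hypothetical helper (not in this commit): fetch the .pth referenced by
    # high_level_config["checkpoint_path"] into the torch hub cache and deserialize on CPU.
    checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
    # Mirror the key handling in the diff: weights may sit under "model" or "state_dict".
    state_dict = checkpoint.get("model", checkpoint.get("state_dict", checkpoint))
    model.load_state_dict(state_dict)
    return model

# e.g. model = load_checkpoint_from_url(model, high_level_config["checkpoint_path"])

If the repository is gated, fetching the raw URL may still require an authenticated request, which is presumably why the committed code routes through hf_hub_download with the token returned by load_hf_token().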