LogicGoInfotechSpaces commited on
Commit
a2d6cd7
·
1 Parent(s): 80080e1

List repository files to find actual model filename instead of guessing

Browse files
Files changed (1) hide show
  1. app/colorize_model.py +23 -7
app/colorize_model.py CHANGED
@@ -21,7 +21,7 @@ os.environ["XDG_CACHE_HOME"] = cache_dir
21
  import torch
22
  from PIL import Image
23
  from fastai.vision.all import *
24
- from huggingface_hub import from_pretrained_fastai, hf_hub_download
25
 
26
  from app.config import settings
27
 
@@ -64,21 +64,37 @@ class ColorizeModel:
64
  except Exception as e1:
65
  logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
66
  # Fallback: manually download and load the model file
67
- # Try common FastAI model file names
68
- model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl"]
69
- model_path = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
71
  for filename in model_filenames:
72
  try:
73
  model_path = hf_hub_download(
74
  repo_id=self.model_id,
75
  filename=filename,
76
  cache_dir=self.cache_dir,
77
- token=os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
78
  )
79
  logger.info("Found model file: %s", filename)
80
  break
81
- except Exception:
 
82
  continue
83
 
84
  if model_path and os.path.exists(model_path):
@@ -87,7 +103,7 @@ class ColorizeModel:
87
  self.learn = load_learner(model_path)
88
  logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
89
  else:
90
- # If no model file found, try listing repository files
91
  raise RuntimeError(
92
  f"Could not find model file in repository '{self.model_id}'. "
93
  f"Tried: {', '.join(model_filenames)}. "
 
21
  import torch
22
  from PIL import Image
23
  from fastai.vision.all import *
24
+ from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files
25
 
26
  from app.config import settings
27
 
 
64
  except Exception as e1:
65
  logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
66
  # Fallback: manually download and load the model file
67
+ # First, list files in the repository to find the actual model file
68
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
69
+ try:
70
+ repo_files = list_repo_files(repo_id=self.model_id, token=hf_token)
71
+ logger.info("Repository files: %s", repo_files)
72
+ # Look for .pkl files
73
+ pkl_files = [f for f in repo_files if f.endswith('.pkl')]
74
+ if not pkl_files:
75
+ # Also try common FastAI model file names
76
+ model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl"]
77
+ else:
78
+ model_filenames = pkl_files
79
+ logger.info("Found .pkl files in repository: %s", pkl_files)
80
+ except Exception as list_err:
81
+ logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
82
+ # Fallback to common filenames
83
+ model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl"]
84
 
85
+ model_path = None
86
  for filename in model_filenames:
87
  try:
88
  model_path = hf_hub_download(
89
  repo_id=self.model_id,
90
  filename=filename,
91
  cache_dir=self.cache_dir,
92
+ token=hf_token
93
  )
94
  logger.info("Found model file: %s", filename)
95
  break
96
+ except Exception as dl_err:
97
+ logger.debug("Failed to download %s: %s", filename, str(dl_err))
98
  continue
99
 
100
  if model_path and os.path.exists(model_path):
 
103
  self.learn = load_learner(model_path)
104
  logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
105
  else:
106
+ # If no model file found, raise error with more details
107
  raise RuntimeError(
108
  f"Could not find model file in repository '{self.model_id}'. "
109
  f"Tried: {', '.join(model_filenames)}. "