Ashok75 committed on
Commit
e944423
·
verified ·
1 Parent(s): 6b59904

Upload server_runtime.py

Browse files
Files changed (1) hide show
  1. server_runtime.py +43 -15
server_runtime.py CHANGED
@@ -145,6 +145,9 @@ def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
145
  join_timeout = float(os.getenv("HF_GENERATION_JOIN_TIMEOUT_SECONDS", "180"))
146
  max_input_tokens = int(os.getenv("HF_MAX_INPUT_TOKENS", str(config.max_input_tokens)))
147
  max_new_tokens_limit = int(os.getenv("HF_MAX_NEW_TOKENS", str(config.max_new_tokens)))
 
 
 
148
 
149
  base_dir = os.path.dirname(os.path.abspath(__file__))
150
 
@@ -366,32 +369,57 @@ def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
366
  nonlocal model, tokenizer, worker_tasks, max_workers, device
367
 
368
  logger.info("Loading model %s on %s", config.model_name, device)
369
- tokenizer_kwargs: Dict[str, Any] = {"trust_remote_code": True}
 
 
 
370
  if config.tokenizer_use_fast is not None:
371
  tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
372
- tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
373
  model_load_kwargs: Dict[str, Any] = {
374
  "trust_remote_code": True,
375
  "device_map": "auto" if device == "cuda" else None,
 
376
  }
377
  if device == "cuda":
378
  model_load_kwargs["dtype"] = "auto"
379
  else:
380
  model_load_kwargs["torch_dtype"] = torch.float32
381
 
382
- try:
383
- model = AutoModelForCausalLM.from_pretrained(
384
- config.model_name,
385
- **model_load_kwargs,
386
- )
387
- except TypeError:
388
- # Backward compatibility for older transformers that do not accept `dtype`.
389
- if "dtype" in model_load_kwargs:
390
- model_load_kwargs["torch_dtype"] = model_load_kwargs.pop("dtype")
391
- model = AutoModelForCausalLM.from_pretrained(
392
- config.model_name,
393
- **model_load_kwargs,
394
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
396
  if device != "cuda":
397
  model = model.to("cpu")
 
145
  join_timeout = float(os.getenv("HF_GENERATION_JOIN_TIMEOUT_SECONDS", "180"))
146
  max_input_tokens = int(os.getenv("HF_MAX_INPUT_TOKENS", str(config.max_input_tokens)))
147
  max_new_tokens_limit = int(os.getenv("HF_MAX_NEW_TOKENS", str(config.max_new_tokens)))
148
+ model_load_retries = max(1, int(os.getenv("HF_MODEL_LOAD_RETRIES", "4")))
149
+ model_load_retry_delay = max(1.0, float(os.getenv("HF_MODEL_LOAD_RETRY_DELAY_SECONDS", "8")))
150
+ local_files_only = _is_truthy(os.getenv("HF_LOCAL_FILES_ONLY", "0"))
151
 
152
  base_dir = os.path.dirname(os.path.abspath(__file__))
153
 
 
369
  nonlocal model, tokenizer, worker_tasks, max_workers, device
370
 
371
  logger.info("Loading model %s on %s", config.model_name, device)
372
+ tokenizer_kwargs: Dict[str, Any] = {
373
+ "trust_remote_code": True,
374
+ "local_files_only": local_files_only,
375
+ }
376
  if config.tokenizer_use_fast is not None:
377
  tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
 
378
  model_load_kwargs: Dict[str, Any] = {
379
  "trust_remote_code": True,
380
  "device_map": "auto" if device == "cuda" else None,
381
+ "local_files_only": local_files_only,
382
  }
383
  if device == "cuda":
384
  model_load_kwargs["dtype"] = "auto"
385
  else:
386
  model_load_kwargs["torch_dtype"] = torch.float32
387
 
388
+ last_load_error: Optional[Exception] = None
389
+ for attempt in range(1, model_load_retries + 1):
390
+ try:
391
+ tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
392
+ try:
393
+ model = AutoModelForCausalLM.from_pretrained(
394
+ config.model_name,
395
+ **model_load_kwargs,
396
+ )
397
+ except TypeError:
398
+ # Backward compatibility for older transformers that do not accept `dtype`.
399
+ if "dtype" in model_load_kwargs:
400
+ model_load_kwargs["torch_dtype"] = model_load_kwargs.pop("dtype")
401
+ model = AutoModelForCausalLM.from_pretrained(
402
+ config.model_name,
403
+ **model_load_kwargs,
404
+ )
405
+ break
406
+ except Exception as exc:
407
+ last_load_error = exc
408
+ logger.warning(
409
+ "Model load attempt %d/%d failed: %s",
410
+ attempt,
411
+ model_load_retries,
412
+ str(exc),
413
+ )
414
+ if attempt < model_load_retries:
415
+ await asyncio.sleep(model_load_retry_delay)
416
+ else:
417
+ logger.error(
418
+ "Model loading failed after %d attempts (local_files_only=%s)",
419
+ model_load_retries,
420
+ str(local_files_only),
421
+ )
422
+ raise last_load_error
423
 
424
  if device != "cuda":
425
  model = model.to("cpu")