Commit d279e64
1 Parent(s): 849364d

hf_backend.py  +40 -12  CHANGED
@@ -33,11 +33,10 @@ except Exception as e:
 
 # ---------------- helpers ----------------
 def _pick_cpu_dtype() -> torch.dtype:
-    # Prefer BF16 if CPU supports it
     if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
         try:
             if torch.cpu.is_bf16_supported():
-                logger.info("CPU BF16 supported,
+                logger.info("CPU BF16 supported, will attempt torch.bfloat16")
                 return torch.bfloat16
         except Exception:
             pass
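For context, _pick_cpu_dtype follows the standard defensive probe for CPU BF16 support: torch.cpu.is_bf16_supported() only exists on newer PyTorch builds, so both the attribute lookup and the call are guarded. A minimal self-contained sketch of that pattern is below; the FP32 default at the end is an assumption, since the rest of the function falls outside this hunk, and the logger setup is illustrative.

import logging

import torch

logger = logging.getLogger(__name__)


def pick_cpu_dtype() -> torch.dtype:
    # torch.cpu.is_bf16_supported() is only present on newer PyTorch builds,
    # so guard both the attribute lookup and the call itself.
    if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
        try:
            if torch.cpu.is_bf16_supported():
                logger.info("CPU BF16 supported, will attempt torch.bfloat16")
                return torch.bfloat16
        except Exception:
            pass
    # Assumed fallback: FP32 is always safe on CPU.
    return torch.float32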
@@ -57,17 +56,32 @@ def _get_model(device: str, dtype: torch.dtype):
     cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
     if hasattr(cfg, "quantization_config"):
         logger.warning("Removing quantization_config from model config")
-        delattr(cfg, "quantization_config")
+        delattr(cfg, "quantization_config")
+
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            config=cfg,
+            torch_dtype=dtype,
+            trust_remote_code=True,
+            device_map="auto" if device != "cpu" else {"": "cpu"},
+        )
+    except Exception as e:
+        if device == "cpu" and dtype == torch.bfloat16:
+            logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                config=cfg,
+                torch_dtype=torch.float32,
+                trust_remote_code=True,
+                device_map={"": "cpu"},
+            )
+            dtype = torch.float32
+        else:
+            raise
 
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        config=cfg,
-        torch_dtype=dtype,
-        trust_remote_code=True,
-        device_map="auto" if device != "cpu" else {"": "cpu"},
-    )
     model.eval()
-    _MODEL_CACHE[
+    _MODEL_CACHE[(device, dtype)] = model
     return model
 
 
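The new try/except is a load-then-downgrade pattern: attempt the requested dtype, and if a BF16 load fails on CPU, reload in FP32 and cache under the dtype that actually succeeded. A condensed sketch of that pattern, assuming a module-level _MODEL_CACHE dict and omitting the config handling shown above; the helper name is illustrative and device_map requires accelerate to be installed.

import logging

import torch
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
_MODEL_CACHE: dict[tuple[str, torch.dtype], torch.nn.Module] = {}


def load_with_fp32_fallback(model_id: str, device: str, dtype: torch.dtype):
    # Serve repeat requests for the same (device, dtype) pair from the cache.
    if (device, dtype) in _MODEL_CACHE:
        return _MODEL_CACHE[(device, dtype)]
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto" if device != "cpu" else {"": "cpu"},
        )
    except Exception as exc:
        if device == "cpu" and dtype == torch.bfloat16:
            # Some CPU/PyTorch combinations cannot materialize BF16 weights;
            # retry in FP32 and remember the dtype that actually loaded.
            logger.warning("BF16 load failed on CPU: %s. Retrying with FP32.", exc)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                device_map={"": "cpu"},
            )
            dtype = torch.float32
        else:
            raise
    model.eval()
    _MODEL_CACHE[(device, dtype)] = model
    return model

One consequence of caching under the downgraded dtype: if the lookup earlier in _get_model keys on the requested dtype, a later BF16 request on the same CPU misses the cache and repeats the failing load before falling back again. Whether that cost matters depends on how often the helper is called with BF16.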
@@ -78,7 +92,6 @@ class HFChatBackend(ChatBackend):
             raise RuntimeError(load_error)
 
         messages = request.get("messages", [])
-        prompt = messages[-1]["content"] if messages else "(empty)"
         temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
         max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))
 
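A side note on the unchanged default handling here: settings.LlmTemp or 0.7 falls back whenever the setting is falsy, so a deliberately configured temperature of 0.0 is silently replaced by 0.7. The snippet below only demonstrates that Python behaviour; the settings object is a stand-in, not the repository's settings module.

class FakeSettings:
    LlmTemp = 0.0  # operator explicitly asks for near-deterministic sampling


settings = FakeSettings()
request = {}  # no per-request override

temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
print(temperature)  # prints 0.7: the configured 0.0 is treated as "unset"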
@@ -91,6 +104,21 @@ class HFChatBackend(ChatBackend):
             zero_client.HEADERS["X-IP-Token"] = x_ip_token
             logger.debug("Injected X-IP-Token into ZeroGPU headers")
 
+        # Build prompt using chat template if available
+        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
+            try:
+                prompt = tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    add_generation_prompt=True,
+                )
+                logger.debug("Applied chat template for prompt")
+            except Exception as e:
+                logger.warning(f"Failed to apply chat template: {e}, using fallback")
+                prompt = messages[-1]["content"] if messages else "(empty)"
+        else:
+            prompt = messages[-1]["content"] if messages else "(empty)"
+
         def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
             model = _get_model(device, dtype)
             inputs = tokenizer(prompt, return_tensors="pt").to(device)
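The added prompt-building block is the usual tokenizer.apply_chat_template flow with a plain-text fallback when no template exists or rendering fails, and it is why the old prompt = messages[-1]["content"] line was dropped in the previous hunk: building the prompt from the whole messages list lets system and history turns reach the model instead of only the latest user message. A standalone sketch of that flow; the model name is only an example, and the repository's MODEL_ID and tokenizer wiring are assumed rather than shown.

from transformers import AutoTokenizer

# Example model; any chat model whose tokenizer ships a chat_template works.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

if getattr(tokenizer, "chat_template", None):
    # Render the full conversation into the model's prompt format and
    # append the assistant generation prefix.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
else:
    # Mirrors the diff's fallback: last message only, or a placeholder.
    prompt = messages[-1]["content"] if messages else "(empty)"

print(prompt)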