johnbridges committed on
Commit 2dcb7ad · 1 Parent(s): f453fc2
Files changed (2)
  1. app.py +4 -2
  2. hf_backend.py +122 -0
app.py CHANGED
@@ -8,7 +8,9 @@ from listener import RabbitListenerBase
 from rabbit_repo import RabbitRepo
 from oa_server import OpenAIServers
 #from vllm_backend import VLLMChatBackend, StubImagesBackend
-from transformers_backend import TransformersChatBackend, StubImagesBackend
+#from transformers_backend import TransformersChatBackend, StubImagesBackend
+from hf_backend import HFChatBackend, StubImagesBackend
+
 
 logging.basicConfig(
     level=logging.INFO,
@@ -35,7 +37,7 @@ base = RabbitBase(exchange_type_resolver=resolver)
 
 servers = OpenAIServers(
     publisher,
-    chat_backend=TransformersChatBackend(),
+    chat_backend=HFChatBackend(),
     images_backend=StubImagesBackend()
 )
 
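For context, the new HFChatBackend plugs into OpenAIServers through the same interface the replaced TransformersChatBackend used. The interface itself lives in backends_base, which this commit does not touch; the following is only a hypothetical sketch of what that module presumably defines, inferred from how hf_backend.py below uses ChatBackend and ImagesBackend.

# Hypothetical sketch of the interfaces assumed by hf_backend.py; the real
# definitions are in backends_base, which is not shown in this commit.
from abc import ABC, abstractmethod
from typing import Any, AsyncIterable, Dict


class ChatBackend(ABC):
    @abstractmethod
    def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """Yield OpenAI-style chat.completion.chunk dicts for one request."""
        ...


class ImagesBackend(ABC):
    @abstractmethod
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Return a base64-encoded image for one request."""
        ...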
hf_backend.py ADDED
@@ -0,0 +1,122 @@
+# hf_backend.py
+import time, logging, os, contextlib
+from typing import Any, Dict, AsyncIterable, List
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from backends_base import ChatBackend, ImagesBackend
+from config import settings
+
+try:
+    import spaces
+except ImportError:
+    spaces = None
+
+logger = logging.getLogger(__name__)
+
+# --- Load model/tokenizer on CPU at import time (ZeroGPU safe) ---
+MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
+logger.info(f"Loading {MODEL_ID} on CPU at startup (ZeroGPU safe)...")
+
+tokenizer = None
+model = None
+load_error = None
+try:
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch.float32,  # CPU-safe default
+        trust_remote_code=True,
+    )
+    model.eval()
+except Exception as e:
+    load_error = f"Failed to load model/tokenizer: {e}"
+    logger.exception(load_error)
+
+
+# --- Device helpers ---
+def pick_device() -> str:
+    forced = os.getenv("FORCE_DEVICE", "").lower().strip()
+    if forced in {"cpu", "cuda", "mps"}:
+        return forced
+    if torch.cuda.is_available():
+        return "cuda"
+    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
+
+def pick_dtype(device: str) -> torch.dtype:
+    if device == "cuda":
+        major, _ = torch.cuda.get_device_capability()
+        return torch.bfloat16 if major >= 8 else torch.float16
+    if device == "mps":
+        return torch.float16
+    return torch.float32
+
+
+# --- Backend class ---
+class HFChatBackend(ChatBackend):
+    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
+        if load_error:
+            raise RuntimeError(load_error)
+
+        messages = request.get("messages", [])
+        prompt = messages[-1]["content"] if messages else "(empty)"
+        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
+        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))
+
+        rid = f"chatcmpl-hf-{int(time.time())}"
+        now = int(time.time())
+
+        if spaces:
+            @spaces.GPU(duration=120)  # allow longer run
+            def run_once(prompt: str) -> str:
+                device = pick_device()
+                dtype = pick_dtype(device)
+
+                # Move model to GPU if needed
+                model.to(device=device, dtype=dtype).eval()
+
+                inputs = tokenizer(prompt, return_tensors="pt").to(device)
+                with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
+                    outputs = model.generate(
+                        **inputs,
+                        max_new_tokens=max_tokens,
+                        temperature=temperature,
+                        do_sample=True,
+                    )
+                return tokenizer.decode(outputs[0], skip_special_tokens=True)
+        else:
+            def run_once(prompt: str) -> str:
+                inputs = tokenizer(prompt, return_tensors="pt")
+                with torch.inference_mode():
+                    outputs = model.generate(
+                        **inputs,
+                        max_new_tokens=max_tokens,
+                        temperature=temperature,
+                        do_sample=True,
+                    )
+                return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        try:
+            text = run_once(prompt)
+            yield {
+                "id": rid,
+                "object": "chat.completion.chunk",
+                "created": now,
+                "model": MODEL_ID,
+                "choices": [
+                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
+                ],
+            }
+        except Exception:
+            logger.exception("HF inference failed")
+            raise
+
+
+class StubImagesBackend(ImagesBackend):
+    async def generate_b64(self, request: Dict[str, Any]) -> str:
+        logger.warning("Image generation not supported in HF backend.")
+        return (
+            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
+        )