Rajan Sharma committed
Commit 74d8604 · verified · 1 Parent(s): 44a97e9

Update llm_router.py: route chat and embeddings through the Cohere API, with retry/backoff and an optional local fallback

Files changed (1)
  1. llm_router.py +97 -32
llm_router.py CHANGED
@@ -1,34 +1,99 @@
- from typing import Optional
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS
-
- class LocalLLM:
-     def __init__(self):
-         self.pipe = None
-         self._load_any()
-
-     def _load_any(self):
-         for mid in OPEN_LLM_CANDIDATES:
-             try:
-                 tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
-                 mdl = AutoModelForCausalLM.from_pretrained(
-                     mid, device_map="auto",
-                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                     trust_remote_code=True
-                 )
-                 self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
-                 return
-             except Exception:
-                 continue
-
-     def chat(self, prompt: str) -> Optional[str]:
-         if not self.pipe: return None
-         out = self.pipe(
-             prompt, max_new_tokens=LOCAL_MAX_NEW_TOKENS,
-             do_sample=True, temperature=0.3, top_p=0.9, repetition_penalty=1.12,
-             eos_token_id=self.pipe.tokenizer.eos_token_id
-         )
-         text = out[0]["generated_text"]
-         return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()
-
+ from typing import Optional, List
+ import time
+ import cohere
+ from settings import (
+     COHERE_API_KEY, COHERE_API_URL, COHERE_MODEL_PRIMARY, COHERE_EMBED_MODEL,
+     MODEL_SETTINGS, USE_OPEN_FALLBACKS, COHERE_TIMEOUT_S
+ )
+
+ try:
+     from local_llm import LocalLLM
+     _HAS_LOCAL = True
+ except Exception:
+     _HAS_LOCAL = False
+
+ _client: Optional[cohere.Client] = None  # lazily created singleton client
+
+ def _co_client() -> Optional[cohere.Client]:
+     global _client
+     if _client is not None:
+         return _client
+     if not COHERE_API_KEY:
+         return None
+     kwargs = {"api_key": COHERE_API_KEY, "timeout": COHERE_TIMEOUT_S}
+     if COHERE_API_URL:
+         kwargs["base_url"] = COHERE_API_URL
+     _client = cohere.Client(**kwargs)
+     return _client
+
+ def _retry(fn, attempts=3, backoff=0.8):
+     last = None
+     for i in range(attempts):
+         try:
+             return fn()
+         except Exception as e:
+             last = e
+             time.sleep(backoff * (2 ** i))  # exponential backoff: 0.8s, 1.6s, 3.2s, ...
+     raise last if last else RuntimeError("Unknown error")
+
+ def cohere_chat(prompt: str) -> Optional[str]:
+     cli = _co_client()
+     if not cli: return None
+     def _call():
+         resp = cli.chat(
+             model=COHERE_MODEL_PRIMARY,
+             message=prompt,
+             temperature=MODEL_SETTINGS["temperature"],
+             max_tokens=MODEL_SETTINGS["max_new_tokens"],
+         )
+         return getattr(resp, "text", None) or getattr(resp, "reply", None) \
+             or (resp.generations[0].text if getattr(resp, "generations", None) else None)
+     try:
+         return _retry(_call, attempts=2)
+     except Exception as e:
+         from audit_log import log_event; log_event("cohere_chat_error", None, {"err": str(e)})
+         return None
+
+ def open_fallback_chat(prompt: str) -> Optional[str]:
+     if not USE_OPEN_FALLBACKS or not _HAS_LOCAL:
+         return None
+     try:
+         return LocalLLM().chat(prompt)
+     except Exception:
+         return None
+
+ def cohere_embed(texts: List[str]) -> List[List[float]]:
+     cli = _co_client()
+     if not cli or not texts:
+         return []
+     def _call():
+         resp = cli.embed(texts=texts, model=COHERE_EMBED_MODEL)
+         return getattr(resp, "embeddings", None) or getattr(resp, "data", []) or []
+     try:
+         return _retry(_call, attempts=2)
+     except Exception as e:
+         from audit_log import log_event; log_event("cohere_embed_error", None, {"err": str(e)})
+         return []
+
+ def generate_narrative(scenario_text: str, structured_sections_md: str, rag_snippets: List[str]) -> str:
+     grounding = "\n\n".join([f"[RAG {i+1}]\n{t}" for i, t in enumerate(rag_snippets or [])])
+     prompt = f"""You are a Canadian healthcare operations copilot.
+ Follow the scenario's requested deliverables exactly. Use the structured computations provided (already calculated deterministically) and the RAG snippets for grounding.
+
+ # Scenario
+ {scenario_text}
+
+ # Deterministic Results (already computed)
+ {structured_sections_md}
+
+ # Grounding (Canadian sources, snippets)
+ {grounding}
+
+ Write a concise, decision-ready report tailored to provincial operations leaders.
+ Do not invent numbers. If data are missing, say so clearly.
+ """
+     out = cohere_chat(prompt)
+     if out: return out
+     out = open_fallback_chat(prompt)
+     if out: return out
+     return "Unable to generate narrative at this time."