Rajan Sharma committed on
Commit
30d5bf6
·
verified ·
1 Parent(s): 04c0a3b

Update llm_router.py

Browse files
Files changed (1) hide show
  1. llm_router.py +68 -16
llm_router.py CHANGED
@@ -1,36 +1,87 @@
1
  from typing import Optional, List
 
2
  import cohere
3
- from settings import COHERE_API_KEY, COHERE_MODEL_PRIMARY, MODEL_SETTINGS
4
- from local_llm import LocalLLM
 
 
5
 
6
- _local = None
 
 
 
 
 
7
 
8
- def _local_llm() -> LocalLLM:
9
- global _local
10
- if _local is None:
11
- _local = LocalLLM()
12
- return _local
13
 
14
- def cohere_chat(prompt: str) -> Optional[str]:
 
 
 
15
  if not COHERE_API_KEY:
16
  return None
17
- try:
18
- cli = cohere.Client(api_key=COHERE_API_KEY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  resp = cli.chat(
20
  model=COHERE_MODEL_PRIMARY,
21
  message=prompt,
22
  temperature=MODEL_SETTINGS["temperature"],
23
  max_tokens=MODEL_SETTINGS["max_new_tokens"],
24
  )
25
- if hasattr(resp, "text") and resp.text: return resp.text
26
- if hasattr(resp, "reply") and resp.reply: return resp.reply
27
- if hasattr(resp, "generations") and resp.generations: return resp.generations[0].text
 
 
 
 
 
 
 
28
  except Exception:
29
  return None
30
- return None
31
 
32
  def open_fallback_chat(prompt: str) -> Optional[str]:
33
- return _local_llm().chat(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def generate_narrative(scenario_text: str, structured_sections_md: str, rag_snippets: List[str]) -> str:
36
  grounding = "\n\n".join([f"[RAG {i+1}]\n{t}" for i, t in enumerate(rag_snippets or [])])
@@ -55,3 +106,4 @@ Do not invent numbers. If data are missing, say so clearly.
55
  if out: return out
56
  return "Unable to generate narrative at this time."
57
 
 
 
1
  from typing import Optional, List
2
+ import time
3
  import cohere
4
+ from settings import (
5
+ COHERE_API_KEY, COHERE_API_URL, COHERE_MODEL_PRIMARY, COHERE_EMBED_MODEL,
6
+ MODEL_SETTINGS, USE_OPEN_FALLBACKS
7
+ )
8
 
9
+ # Optional open-model fallback (only used if USE_OPEN_FALLBACKS=True)
10
+ try:
11
+ from local_llm import LocalLLM
12
+ _HAS_LOCAL = True
13
+ except Exception:
14
+ _HAS_LOCAL = False
15
 
16
+ _client: Optional[cohere.Client] = None
 
 
 
 
17
 
18
def _co_client() -> Optional[cohere.Client]:
    """Return the process-wide Cohere client, constructing it lazily.

    Returns:
        The cached ``cohere.Client``, or ``None`` when no ``COHERE_API_KEY``
        is configured. The instance is memoized in the module-level
        ``_client`` so later calls reuse the same connection settings.
    """
    global _client
    if _client is None:
        if not COHERE_API_KEY:
            return None
        # Honor an explicit API base when one is configured; otherwise the
        # Cohere SDK selects its default endpoint on its own.
        kwargs = {
            "api_key": COHERE_API_KEY,
            "timeout": MODEL_SETTINGS.get("timeout_s", 45),
        }
        if COHERE_API_URL:
            kwargs["base_url"] = COHERE_API_URL
        _client = cohere.Client(**kwargs)
    return _client
30
+
31
+ def _retry(fn, attempts=3, backoff=0.8):
32
+ last = None
33
+ for i in range(attempts):
34
+ try:
35
+ return fn()
36
+ except Exception as e:
37
+ last = e
38
+ time.sleep(backoff * (2 ** i))
39
+ raise last if last else RuntimeError("Unknown error")
40
+
41
def cohere_chat(prompt: str) -> Optional[str]:
    """Send *prompt* to the primary Cohere chat model.

    Returns the model's reply text, or ``None`` when no client is
    configured, the response carries no recognizable text field, or every
    retry attempt fails.
    """
    client = _co_client()
    if not client:
        return None

    def _invoke():
        response = client.chat(
            model=COHERE_MODEL_PRIMARY,
            message=prompt,
            temperature=MODEL_SETTINGS["temperature"],
            max_tokens=MODEL_SETTINGS["max_new_tokens"],
        )
        # Different SDK versions surface the reply under different
        # attributes: .text, .reply, or a generations list.
        for field in ("text", "reply"):
            value = getattr(response, field, None)
            if value:
                return value
        generations = getattr(response, "generations", None)
        if generations:
            return generations[0].text
        return None

    try:
        return _retry(_invoke, attempts=3)
    except Exception:
        # Best-effort: callers fall back to other providers on None.
        return None
 
64
 
65
def open_fallback_chat(prompt: str) -> Optional[str]:
    """Answer *prompt* with the optional local open-model fallback.

    Returns ``None`` when fallbacks are disabled (``USE_OPEN_FALLBACKS`` is
    false), when ``local_llm`` could not be imported (``_HAS_LOCAL`` is
    false), or when the local model raises.
    """
    if not USE_OPEN_FALLBACKS or not _HAS_LOCAL:
        return None
    try:
        # Bug fix: memoize the LocalLLM instance. The previous code built a
        # fresh LocalLLM() on every call, re-loading the local model each
        # time; the earlier revision of this file cached it in a module
        # global for exactly this reason. A function attribute keeps the
        # cache without adding a new module-level name.
        llm = getattr(open_fallback_chat, "_llm", None)
        if llm is None:
            llm = LocalLLM()
            open_fallback_chat._llm = llm
        return llm.chat(prompt)
    except Exception:
        # Best-effort fallback: any local failure degrades to None.
        return None
72
+
73
def cohere_embed(texts: List[str]) -> List[List[float]]:
    """Embed *texts* with the configured Cohere embedding model.

    Returns the list of embedding vectors, or ``[]`` when the client is
    unavailable, *texts* is empty, or every retry attempt fails.
    """
    client = _co_client()
    if not client or not texts:
        return []

    def _embed_once():
        response = client.embed(texts=texts, model=COHERE_EMBED_MODEL)
        # Newer SDKs expose .embeddings, older ones .data; [] if neither.
        vectors = getattr(response, "embeddings", None)
        if not vectors:
            vectors = getattr(response, "data", [])
        return vectors or []

    try:
        return _retry(_embed_once, attempts=3)
    except Exception:
        return []
85
 
86
  def generate_narrative(scenario_text: str, structured_sections_md: str, rag_snippets: List[str]) -> str:
87
  grounding = "\n\n".join([f"[RAG {i+1}]\n{t}" for i, t in enumerate(rag_snippets or [])])
 
106
  if out: return out
107
  return "Unable to generate narrative at this time."
108
 
109
+