github-actions[bot] committed on
Commit d778d65 · 1 Parent(s): cc25b3c

🚀 Auto-deploy backend from GitHub (3001f56)

Files changed (3)
  1. config/env.sample +5 -2
  2. main.py +41 -2
  3. tests/test_api.py +20 -0
config/env.sample CHANGED
@@ -38,17 +38,20 @@ INFERENCE_INTERACTIVE_TIMEOUT_SEC=55
 INFERENCE_BACKGROUND_TIMEOUT_SEC=120

 # model defaults
-INFERENCE_MODEL_ID=meta-llama/Llama-3.1-8B-Instruct
+# Leave empty unless you intentionally want one global model for every task.
+INFERENCE_MODEL_ID=
 INFERENCE_MAX_NEW_TOKENS=640
 INFERENCE_TEMPERATURE=0.2
 INFERENCE_TOP_P=0.9
-INFERENCE_CHAT_MODEL_ID=meta-llama/Llama-3.1-8B-Instruct
+INFERENCE_CHAT_MODEL_ID=Qwen/Qwen2.5-7B-Instruct
 INFERENCE_CHAT_HARD_MODEL_ID=meta-llama/Meta-Llama-3-70B-Instruct
 INFERENCE_CHAT_HARD_TRIGGER_ENABLED=true
 INFERENCE_CHAT_HARD_PROMPT_CHARS=650
 INFERENCE_CHAT_HARD_HISTORY_CHARS=1500
 INFERENCE_CHAT_HARD_KEYWORDS=step-by-step,show all steps,explain each step,justify each step,derive,derivation,proof,prove,rigorous,multi-step,word problem
 CHAT_MAX_NEW_TOKENS=768
+CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC=25
+CHAT_STREAM_TOTAL_TIMEOUT_SEC=120
 # Optional: force quiz-generation model. Leave empty to use routing.task_model_map.quiz_generation.
 HF_QUIZ_MODEL_ID=
 HF_QUIZ_JSON_REPAIR_MODEL_ID=meta-llama/Llama-3.1-8B-Instruct
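
The net effect of this change is that model selection becomes per-task: an empty INFERENCE_MODEL_ID defers to routing, while a non-empty value forces one model everywhere. A minimal sketch of that resolution order, assuming a task-to-model map like the routing.task_model_map referenced in the quiz comment (the helper and map below are illustrative, not the app's actual code; only the env var names come from the sample):

    import os

    # Hypothetical per-task defaults, standing in for routing.task_model_map.
    TASK_MODEL_MAP = {
        "chat": os.getenv("INFERENCE_CHAT_MODEL_ID", "Qwen/Qwen2.5-7B-Instruct"),
        "quiz_generation": os.getenv("HF_QUIZ_MODEL_ID", ""),
    }

    def resolve_model_id(task_type: str) -> str:
        # A non-empty INFERENCE_MODEL_ID forces one global model for every task.
        forced = os.getenv("INFERENCE_MODEL_ID", "").strip()
        if forced:
            return forced
        # Otherwise fall back to the per-task entry (routing would supply its
        # own default when the entry itself is empty, e.g. for quiz_generation).
        return TASK_MODEL_MAP.get(task_type, "")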
main.py CHANGED
@@ -206,6 +206,11 @@ FIREBASE_AUTH_PROJECT_ALLOWLIST: Set[str] = {
     if value.strip()
 }
 CHAT_MAX_NEW_TOKENS = max(256, int(os.getenv("CHAT_MAX_NEW_TOKENS", "576")))
+CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC = max(5, int(os.getenv("CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC", "25")))
+CHAT_STREAM_TOTAL_TIMEOUT_SEC = max(
+    CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC,
+    int(os.getenv("CHAT_STREAM_TOTAL_TIMEOUT_SEC", "120")),
+)

 ALLOWED_UPLOAD_EXTENSIONS: Set[str] = {".csv", ".xlsx", ".xls", ".pdf"}
 ALLOWED_UPLOAD_MIME_TYPES: Set[str] = {
@@ -1929,17 +1934,51 @@ async def chat_tutor_stream(request: ChatRequest):
         return "\n".join(body) + "\n\n"

     async def event_generator():
+        stream_iterator = None
+        stream_started_at = time.monotonic()
+        emitted_any_chunk = False
         try:
-            async for chunk in call_hf_chat_stream_async(
+            stream_iterator = call_hf_chat_stream_async(
                 messages,
                 max_tokens=CHAT_MAX_NEW_TOKENS,
                 temperature=0.3,
                 top_p=0.85,
                 task_type="chat",
-            ):
+            )
+
+            while True:
+                elapsed = time.monotonic() - stream_started_at
+                remaining_total = CHAT_STREAM_TOTAL_TIMEOUT_SEC - elapsed
+                if remaining_total <= 0:
+                    raise TimeoutError("Chat stream exceeded total timeout")
+
+                token_timeout = min(CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC, remaining_total)
+                try:
+                    chunk = await asyncio.wait_for(stream_iterator.__anext__(), timeout=token_timeout)
+                except StopAsyncIteration:
+                    break
+
+                if not chunk:
+                    continue
+
+                emitted_any_chunk = True
                 payload = json.dumps({"chunk": chunk}, ensure_ascii=False)
                 yield _sse("chunk", payload)
                 await asyncio.sleep(0)
+        except (asyncio.TimeoutError, TimeoutError):
+            logger.error(
+                "HF chat stream timed out (idle=%ss total=%ss)",
+                CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC,
+                CHAT_STREAM_TOTAL_TIMEOUT_SEC,
+            )
+            err_payload = json.dumps({
+                "detail": (
+                    "AI response stream timed out mid-response. Please retry."
+                    if emitted_any_chunk
+                    else "AI response stream timed out before any tokens were received. Please retry."
+                ),
+            })
+            yield _sse("error", err_payload)
         except Exception as hf_err:
             logger.error(f"HF chat stream failed: {hf_err}")
             err_payload = json.dumps({
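
The two timeouts compose as a watchdog: asyncio.wait_for() bounds the gap between tokens, the monotonic clock bounds the whole response, and whichever budget is tighter wins each iteration. The standalone sketch below (illustrative names only, not app code) exercises the same pattern against a simulated token stream whose third gap exceeds the idle limit:

    import asyncio
    import time

    async def fake_token_stream():
        # Simulated token source: the third gap (2.0 s) exceeds the idle limit below.
        for delay in (0.1, 0.2, 2.0):
            await asyncio.sleep(delay)
            yield f"token after {delay}s"

    async def consume(idle_timeout: float = 0.5, total_timeout: float = 5.0):
        stream = fake_token_stream()
        started = time.monotonic()
        try:
            while True:
                remaining = total_timeout - (time.monotonic() - started)
                if remaining <= 0:
                    raise TimeoutError("total budget exhausted")
                try:
                    # Per-token wait, capped by whatever is left of the total budget.
                    chunk = await asyncio.wait_for(
                        stream.__anext__(), timeout=min(idle_timeout, remaining)
                    )
                except StopAsyncIteration:
                    break  # stream finished normally
                print("got:", chunk)
        except (asyncio.TimeoutError, TimeoutError) as err:
            print("stream timed out:", err)

    asyncio.run(consume())

Driving the iterator manually via __anext__ (rather than async for) is what makes the per-token asyncio.wait_for() possible, which is presumably why the diff restructures the loop this way.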
tests/test_api.py CHANGED
@@ -286,6 +286,26 @@ class TestChatEndpoint:
         assert "event: error" in content
         assert "event: end" in content

+    @patch("main.call_hf_chat_stream_async")
+    def test_chat_stream_timeout_emits_error_and_end_events(self, mock_stream_async):
+        async def _slow_stream(*args, **kwargs):
+            await asyncio.sleep(0.05)
+            yield "late chunk"
+
+        mock_stream_async.return_value = _slow_stream()
+
+        with patch.object(main_module, "CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC", 0.01), patch.object(main_module, "CHAT_STREAM_TOTAL_TIMEOUT_SEC", 0.03):
+            with client.stream("POST", "/api/chat/stream", json={
+                "message": "Say hello",
+                "history": [],
+            }) as response:
+                assert response.status_code == 200
+                content = "".join(response.iter_text())
+
+        assert "event: error" in content
+        assert "timed out" in content.lower()
+        assert "event: end" in content
+

 class TestHFChatTransport:
     @patch("main.http_requests.post")