# mediagent/core/llm.py
"""
Production-grade LLM client wrapper for MediAgent.
Handles text and multimodal (vision) completions against local Qwen model.
Implements retry logic, error handling, response parsing, and OpenAI-compatible API calls.
"""
import json
import logging
import re
import time
from typing import Any, Callable, Dict, List, Optional, Union
import openai
logger = logging.getLogger(__name__)
class LLMClient:
"""
Lightweight, framework-agnostic LLM client wrapping OpenAI Python SDK.
Designed for local inference endpoints (vLLM, Ollama, TensorRT-LLM)
running at http://localhost:8000/v1 with model path "/model".
"""
DEFAULT_BASE_URL = "http://localhost:8000/v1"
DEFAULT_MODEL = "/model"
DEFAULT_API_KEY = "none"
def __init__(
self,
base_url: str = DEFAULT_BASE_URL,
model: str = DEFAULT_MODEL,
max_retries: int = 3,
timeout: float = 90.0,
temperature: float = 0.0
):
self.model = model
self.max_retries = max_retries
self.default_temperature = temperature
self.timeout = timeout
self.client = openai.OpenAI(
base_url=base_url,
api_key=self.DEFAULT_API_KEY,
timeout=timeout
)
logger.info(f"LLMClient initialized | Model: {self.model} | Endpoint: {base_url}")
# ─────────────────────────────────────────────────────────────────────────
# CORE GENERATION METHODS
# ─────────────────────────────────────────────────────────────────────────
def generate_text(
self,
prompt: str,
system_prompt: str = "",
temperature: Optional[float] = None,
force_json: bool = False,
max_tokens: Optional[int] = None,
extra_body: Optional[Dict] = None,
) -> Dict[str, Any]:
"""
Send a text-only completion request to the LLM.
Returns standardized response dict with content, usage, success flag, and error.
"""
messages = self._build_messages(system_prompt, prompt)
response_format = {"type": "json_object"} if force_json else None
return self._execute_with_retry(
messages=messages,
temperature=temperature,
response_format=response_format,
call_type="TEXT",
max_tokens=max_tokens,
extra_body=extra_body,
)
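    # Example (illustrative only): request strict JSON output and parse it with
    # the extraction helper below; `llm` is an assumed LLMClient instance.
    #
    #   result = llm.generate_text(
    #       prompt="Return the patient's vitals as a JSON object.",
    #       force_json=True,
    #   )
    #   data = LLMClient.extract_json_from_response(result["content"])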
def generate_text_streaming(
self,
prompt: str,
system_prompt: str = "",
temperature: Optional[float] = None,
        on_token: Optional[Callable[[str], None]] = None,
) -> Dict[str, Any]:
"""
Text completion with optional token-level streaming callback.
When on_token is provided, calls on_token(chunk: str) for every token chunk
as it arrives from the model. Returns the full response dict at the end.
Falls back to standard generate_text if streaming fails.
"""
if on_token is None:
return self.generate_text(prompt, system_prompt, temperature)
messages = self._build_messages(system_prompt, prompt)
temp = temperature if temperature is not None else self.default_temperature
try:
stream = self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=temp,
stream=True,
)
full_content = ""
for chunk in stream:
delta = (chunk.choices[0].delta.content or "") if chunk.choices else ""
if delta:
full_content += delta
try:
on_token(delta)
except Exception:
pass # callback errors must not break generation
logger.debug("Streaming TEXT generation completed | chars=%d", len(full_content))
return {
"success": True,
"content": full_content,
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
"model": self.model,
"error": None,
}
except Exception as e:
logger.warning("Streaming failed (%s), falling back to standard call", e)
return self.generate_text(prompt, system_prompt, temperature)
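    # Example (illustrative only): stream tokens to stdout as they arrive, using
    # print as the on_token callback; `llm` is an assumed LLMClient instance.
    #
    #   llm.generate_text_streaming(
    #       prompt="Explain the triage decision step by step.",
    #       on_token=lambda chunk: print(chunk, end="", flush=True),
    #   )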
def generate_vision(
self,
base64_image: str,
prompt: str,
system_prompt: str = "",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""
Send a multimodal completion request with a base64 encoded medical image.
Automatically detects image MIME type and formats per OpenAI vision spec.
"""
img_url = self._format_image_url(base64_image)
user_content = [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": img_url}}
]
messages = self._build_messages(system_prompt, user_content)
return self._execute_with_retry(
messages=messages,
temperature=temperature,
response_format=None,
call_type="VISION",
max_tokens=max_tokens
)
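    # Example (illustrative only): encode an image file from disk and request a
    # description; "scan.jpg" is a hypothetical path, base64 is stdlib.
    #
    #   import base64
    #   with open("scan.jpg", "rb") as f:
    #       b64 = base64.b64encode(f.read()).decode("utf-8")
    #   result = llm.generate_vision(b64, prompt="Describe the key findings.")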
# ─────────────────────────────────────────────────────────────────────────
# INTERNAL HELPERS
# ─────────────────────────────────────────────────────────────────────────
def _build_messages(
self,
system_prompt: str,
user_content: Union[str, List[Dict]]
) -> List[Dict]:
"""Construct OpenAI-compatible message array."""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
        # Plain strings and multimodal content lists are both valid user content.
        messages.append({"role": "user", "content": user_content})
return messages
def _format_image_url(self, base64_data: str) -> str:
"""Normalize base64 image data into OpenAI vision-compatible URL format."""
        # Pass through payloads that already carry a data-URL prefix.
        if base64_data.startswith("data:image/"):
            return base64_data
        # Default to JPEG if no MIME prefix is present
        return f"data:image/jpeg;base64,{base64_data}"
def _attempt_call(
self,
messages: List[Dict],
temperature: Optional[float],
response_format: Optional[Dict],
max_tokens: Optional[int] = None,
extra_body: Optional[Dict] = None,
) -> Dict[str, Any]:
"""Execute a single API call with the OpenAI client."""
kwargs = {
"model": self.model,
"messages": messages,
"temperature": temperature if temperature is not None else self.default_temperature,
}
        if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if response_format:
kwargs["response_format"] = response_format
if extra_body:
kwargs["extra_body"] = extra_body
response = self.client.chat.completions.create(**kwargs)
content = response.choices[0].message.content or ""
usage = response.usage
return {
"success": True,
"content": content,
"raw_response": response,
"usage": {
"prompt_tokens": usage.prompt_tokens if usage else 0,
"completion_tokens": usage.completion_tokens if usage else 0,
"total_tokens": usage.total_tokens if usage else 0,
},
"model": response.model,
"error": None
}
def _execute_with_retry(
self,
messages: List[Dict],
temperature: Optional[float],
response_format: Optional[Dict],
call_type: str,
max_tokens: Optional[int] = None,
extra_body: Optional[Dict] = None,
) -> Dict[str, Any]:
"""Retry wrapper with exponential backoff for robust local inference."""
last_error = None
for attempt in range(1, self.max_retries + 1):
try:
                result = self._attempt_call(messages, temperature, response_format, max_tokens, extra_body)
                # _attempt_call either returns a success payload or raises.
                logger.debug(f"{call_type} generation successful on attempt {attempt}")
                return result
except Exception as e:
last_error = str(e)
logger.warning(f"{call_type} generation failed on attempt {attempt}/{self.max_retries}: {e}")
if attempt < self.max_retries:
                    # Short fixed backoff for local inference; no need for exponential waits
backoff = 1.0
logger.info(f"Retrying in {backoff}s...")
time.sleep(backoff)
logger.error(f"{call_type} generation failed permanently after {self.max_retries} attempts.")
return {
"success": False,
"content": "",
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
"model": self.model,
"error": last_error or f"{call_type} endpoint unreachable or max retries exceeded."
}
# ─────────────────────────────────────────────────────────────────────────
# RESPONSE PARSING UTILITIES
# ─────────────────────────────────────────────────────────────────────────
    @staticmethod
    def extract_json_from_response(content: str) -> Optional[Any]:
        """
        Safely extract JSON from LLM output: strip markdown code fences,
        try a direct decode, then fall back to scanning for an embedded
        JSON object or array.
        """
        if not content:
            return None
        # Strip markdown code fences if present
        cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", content.strip(), flags=re.MULTILINE)
        try:
            # First try direct JSON decode
            return json.loads(cleaned)
        except json.JSONDecodeError:
            logger.debug("Direct JSON extraction failed. Attempting fallback parsing...")
            return LLMClient._fallback_json_parse(cleaned)
    @staticmethod
    def _fallback_json_parse(text: str) -> Optional[Any]:
        """
        Fallback: scan for the first balanced JSON object or array in the text.
        Handles cases where the LLM adds conversational padding around the JSON.
        """
        # Try objects first, then arrays.
        for open_char, close_char in (("{", "}"), ("[", "]")):
            result = LLMClient._scan_balanced_json(text, open_char, close_char)
            if result is not None:
                return result
        return None
    @staticmethod
    def _scan_balanced_json(text: str, open_char: str, close_char: str) -> Optional[Any]:
        """
        Return the first balanced open_char...close_char span that parses as JSON.
        Delimiters inside JSON strings are not tracked; a candidate containing
        them simply fails to parse and scanning continues.
        """
        depth = 0
        start_idx = None
        for i, char in enumerate(text):
            if char == open_char:
                if depth == 0:
                    start_idx = i
                depth += 1
            elif char == close_char and depth > 0:  # ignore stray closers before any opener
                depth -= 1
                if depth == 0 and start_idx is not None:
                    try:
                        return json.loads(text[start_idx:i + 1])
                    except json.JSONDecodeError:
                        continue
        return None
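# ─────────────────────────────────────────────────────────────────────────
# USAGE SKETCH
# ─────────────────────────────────────────────────────────────────────────
# A minimal smoke test, assuming an OpenAI-compatible server is already
# serving a model at the default endpoint; the prompt is an arbitrary example.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    llm = LLMClient()
    result = llm.generate_text(
        prompt='Reply with exactly this JSON object: {"status": "ok"}',
        force_json=True,
    )
    if result["success"]:
        print(LLMClient.extract_json_from_response(result["content"]))
    else:
        print(f"Request failed: {result['error']}")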