| """ |
| This module defines the LLM class, which sends prompts to one of several LLM providers and handles their responses. |
| """ |
|
|
| import json |
| import time |
| from openai import OpenAI |
| import openai |
| import torch |
| import re |
| import anthropic |
| import os |
| import streamlit as st |
| from google.genai import types |
| from google import genai |
|
|
|
|
|
|
class LLM:
    """Send prompts to the configured LLM provider and return a validated JSON reply.

    The provider family is chosen from ``Core.model`` by looking it up in the
    model lists stored in ``Core.config_file``. Every ``get_message_*`` method
    retries until the provider returns a JSON object that contains string
    ``header`` and ``message`` fields within the configured length limits, and
    returns ``None`` once the retries are exhausted.

    Attributes read from ``Core``: ``model``, ``config_file``, ``api_key``,
    ``temperature`` and ``reasoning_model``. Attributes updated on ``Core``:
    ``total_tokens`` (dict with ``prompt_tokens`` / ``completion_tokens``) and
    ``temp_token_counter``.
    """

    # (config_file key holding the provider's model names, model_type tag),
    # checked in this order; the first list containing Core.model wins.
    _MODEL_FAMILIES = (
        ("openai_models", "openai"),
        ("inference_models", "inference"),
        ("google_models", "google"),
        ("claude_models", "claude"),
    )

    # Upper bound for the pause inserted after a rate-limit (HTTP 429) error.
    _MAX_BACKOFF_SECONDS = 8

    # Patterns used by preprocess_and_parse_json, compiled once.
    _THINK_BLOCK = re.compile(r'<think>.*?</think>', re.DOTALL)
    _CODE_FENCE = re.compile(r'```(?:json)?(.*?)```', re.DOTALL | re.IGNORECASE)
    _BRACE_SPAN = re.compile(r'\{.*\}', re.DOTALL)
    _LINE_COMMENT = re.compile(r'//.*?$', re.MULTILINE)
    _BLOCK_COMMENT = re.compile(r'/\*.*?\*/', re.DOTALL)
    _TRAILING_COMMA = re.compile(r',(\s*[}\]])')

    # Drops the BOM and maps typographic quotes to their ASCII equivalents.
    _QUOTE_TABLE = {
        0xFEFF: None,
        0x201C: '"',
        0x201D: '"',
        0x2018: "'",
        0x2019: "'",
    }

    def __init__(self, Core):
        """Store the shared ``Core`` object and resolve the provider for its model."""
        self.Core = Core
        self.model = None
        self.model_type = "openai"
        self.client = None
        self.connect_to_llm()

    def get_credential(self, key):
        """Return the secret named ``key``.

        The environment variable is preferred; when it is unset or empty the
        value is looked up in Streamlit's secrets instead.
        """
        return os.getenv(key) or st.secrets.get(key)

    def get_response(self, prompt, instructions):
        """Dispatch ``prompt`` / ``instructions`` to the handler for ``self.model_type``.

        :return: the parsed response dict, or ``None`` when the handler ran
            out of retries.
        :raises ValueError: if ``self.model_type`` is not a supported family.
        """
        handlers = {
            "openai": self.get_message_openai,
            "inference": self.get_message_inference,
            "claude": self.get_message_claude,
            "google": self.get_message_google,
        }
        handler = handlers.get(self.model_type)
        if handler is None:
            # Raising a plain string is itself a TypeError in Python 3, so a
            # real exception type is required here.
            raise ValueError(f"Invalid model type : {self.model_type}")
        return handler(prompt, instructions)

    def connect_to_llm(self):
        """Resolve ``self.model_type`` from ``Core.model`` and prepare the client.

        Only the Claude family creates a persistent client here; the other
        handlers build their client on each call. When the model appears in
        none of the configured lists, ``self.model_type`` is left unchanged.
        """
        config = self.Core.config_file
        model_name = self.Core.model

        for config_key, model_type in self._MODEL_FAMILIES:
            if model_name in config[config_key]:
                self.model_type = model_type
                if model_type == "claude":
                    self.client = anthropic.Anthropic(
                        api_key=self.get_credential('claude_api_key'),
                    )
                break

        self.model = model_name

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _record_usage(self, prompt_tokens, completion_tokens, total_tokens):
        """Add one call's token counts to the running totals kept on ``Core``."""
        # Providers may report a count as None; treat that as zero instead of
        # failing the whole attempt on ``int += None``.
        self.Core.total_tokens['prompt_tokens'] += prompt_tokens or 0
        self.Core.total_tokens['completion_tokens'] += completion_tokens or 0
        self.Core.temp_token_counter += total_tokens or 0

    def _is_valid_output(self, output, attempt):
        """Return True when ``output`` is an acceptable header/message payload.

        Prints the reason and returns False when a retry is needed.
        ``attempt`` is the 0-based retry index, used only for the log text.
        """
        if not isinstance(output, dict) or 'message' not in output or 'header' not in output:
            print(f"'message' or 'header' is missing in response on attempt {attempt + 1}. Retrying...")
            return False

        header = output["header"]
        message = output["message"]
        if not isinstance(header, str) or not isinstance(message, str):
            # Non-string values (null, numbers, nested objects) cannot be
            # stripped or length-checked; ask the model again.
            print(f"'header' or 'message' is not a string in response on attempt {attempt + 1}. Retrying...")
            return False

        limits = self.Core.config_file
        if len(header.strip()) > limits["header_limit"] or len(
                message.strip()) > limits["message_limit"]:
            print(
                f"'header' or 'message' is more than specified characters in response on attempt {attempt + 1}. Retrying...")
            return False

        return True

    def _chat_completion_json(self, client, prompt, instructions, max_retries, reasoning_effort):
        """Run the retry loop shared by the OpenAI-compatible chat endpoints.

        :param client: an ``OpenAI`` client (possibly pointing at another base URL).
        :param reasoning_effort: sent instead of ``temperature`` when
            ``Core.reasoning_model`` is truthy.
        :return: the validated response dict, or ``None`` after ``max_retries``
            failed attempts.
        """
        for attempt in range(max_retries):
            request = {
                "model": self.Core.model,
                "response_format": {"type": "json_object"},
                "messages": [
                    {"role": "system", "content": instructions},
                    {"role": "user", "content": prompt},
                ],
                "n": 1,
            }
            if self.Core.reasoning_model:
                request["reasoning_effort"] = reasoning_effort
            else:
                request["temperature"] = self.Core.temperature

            try:
                response = client.chat.completions.create(**request)
            except openai.APIConnectionError as e:
                print("The server could not be reached")
                print(e.__cause__)
                continue
            except openai.RateLimitError:
                # Must precede APIStatusError, which is its base class.
                print("A 429 status code was received; we should back off a bit.")
                if attempt + 1 < max_retries:
                    time.sleep(min(2 ** attempt, self._MAX_BACKOFF_SECONDS))
                continue
            except openai.APIStatusError as e:
                print("Another non-200-range status code was received")
                print(e.status_code)
                print(e.response)
                continue

            usage = response.usage
            if usage is not None:
                self._record_usage(
                    usage.prompt_tokens,
                    usage.completion_tokens,
                    usage.total_tokens,
                )

            try:
                output = json.loads(response.choices[0].message.content)
            except (json.JSONDecodeError, TypeError):
                # TypeError covers a reply whose content is None.
                print(f"Invalid JSON from LLM on attempt {attempt + 1}. Retrying...")
                continue

            if self._is_valid_output(output, attempt):
                return output

        print("Max retries exceeded. Returning empty response.")
        return None

    # ------------------------------------------------------------------
    # Provider handlers
    # ------------------------------------------------------------------

    def get_message_inference(self, prompt, instructions, max_retries=6):
        """Send the prompt to the inference.net OpenAI-compatible endpoint.

        :return: the validated response dict, or ``None`` after ``max_retries``
            failed attempts.
        """
        client = OpenAI(
            base_url="https://api.inference.net/v1",
            api_key=self.get_credential('inference_api_key'),
        )
        return self._chat_completion_json(
            client, prompt, instructions, max_retries, reasoning_effort="medium")

    def get_message_google(self, prompt, instructions, max_retries=6):
        """Send the prompt to a Google model through the google-genai client.

        :return: the validated response dict, or ``None`` after ``max_retries``
            failed attempts.
        """
        client = genai.Client(api_key=self.get_credential("Google_API"))

        for attempt in range(max_retries):
            try:
                response = client.models.generate_content(
                    model=self.Core.model,
                    contents=prompt,
                    config=types.GenerateContentConfig(
                        thinking_config=types.ThinkingConfig(thinking_budget=0),
                        system_instruction=instructions,
                        temperature=self.Core.temperature,
                        response_mime_type="application/json"
                    ))

                usage = response.usage_metadata
                if usage is not None:
                    self._record_usage(
                        usage.prompt_token_count,
                        usage.candidates_token_count,
                        usage.total_token_count,
                    )

                output = self.preprocess_and_parse_json(response.text)
                if output is None:
                    # The parser already reported why; without this check the
                    # membership test below would raise on None.
                    print(f"Invalid JSON from LLM on attempt {attempt + 1}. Retrying...")
                    continue

                if self._is_valid_output(output, attempt):
                    return output

            except Exception as e:
                print(f"Error in attempt {attempt + 1}: {e}")

        print("Max retries exceeded. Returning empty response.")
        return None

    def get_message_openai(self, prompt, instructions, max_retries=5):
        """Send the prompt to the OpenAI chat completions API.

        :return: the validated response dict, or ``None`` after ``max_retries``
            failed attempts.
        """
        # Kept from the original: also sets the module-level key.
        openai.api_key = self.Core.api_key
        client = OpenAI(api_key=self.Core.api_key)
        return self._chat_completion_json(
            client, prompt, instructions, max_retries, reasoning_effort="minimal")

    def get_message_ollama(self, prompt, instructions, max_retries=10):
        """Send the prompt through ``self.client.generate`` and parse the reply.

        A failed generate call triggers an attempt to clear the CUDA cache and
        a two second pause before the next try. No token usage is recorded on
        this path.

        :return: the validated response dict, or ``None`` after ``max_retries``
            failed attempts.
        """
        prompt = instructions + "\n \n" + prompt

        for attempt in range(max_retries):
            try:
                response = self.client.generate(model=self.model, prompt=prompt)
            except Exception as e:
                print(f"Error on attempt {attempt + 1}: {e}.")
                try:
                    torch.cuda.empty_cache()
                    print("Cleared GPU cache.")
                except Exception as cache_err:
                    print("Failed to clear GPU cache:", cache_err)
                time.sleep(2)
                continue

            try:
                output = self.preprocess_and_parse_json(response.response)
                if output is None:
                    continue
                if self._is_valid_output(output, attempt):
                    return output
            except Exception as parse_error:
                print("Error processing output:", parse_error)

        print("Max retries exceeded. Returning empty response.")
        return None

    def get_message_claude(self, prompt, instructions, max_retries=6):
        """Send the prompt to Claude through the Anthropic client and parse the reply.

        :param prompt: user prompt; a short suffix asking for JSON is appended.
        :param instructions: system prompt.
        :return: the validated response dict, or ``None`` after ``max_retries``
            failed attempts.
        """
        for attempt in range(max_retries):
            try:
                message = self.client.messages.create(
                    model=self.model,
                    max_tokens=4096,
                    system=instructions,
                    messages=[
                        {"role": "user", "content": prompt + "\nHere is the JSON requested:\n"}
                    ],
                    temperature=self.Core.temperature
                )
                response = message.content[0].text

                usage = message.usage
                self._record_usage(
                    usage.input_tokens,
                    usage.output_tokens,
                    usage.output_tokens + usage.input_tokens,
                )

                # Raises ValueError on unusable text; handled below as a retry.
                output = self.preprocess_and_parse_json_claude(response)
                if output is None:
                    continue
                if self._is_valid_output(output, attempt):
                    return output

            except Exception as parse_error:
                print("Error processing output:", parse_error)

        print("Max retries exceeded. Returning empty response.")
        return None

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    def preprocess_and_parse_json(self, response: str):
        """Extract and parse the JSON payload from a raw model reply.

        ``<think>`` blocks are removed, then the payload is taken from the
        first ``` fence (optionally tagged ``json``) or, failing that, from the
        span between the first ``{`` and the last ``}``. If strict parsing
        fails, a repair pass strips ``//`` and ``/* */`` comments and trailing
        commas and tries once more.

        :return: the parsed JSON value, or ``None`` when ``response`` is not a
            string or cannot be parsed even after the repair pass.
        """
        if not isinstance(response, str):
            return None

        text = self._THINK_BLOCK.sub('', response).strip()

        fence = self._CODE_FENCE.search(text)
        if fence:
            raw = fence.group(1).strip()
        else:
            brace = self._BRACE_SPAN.search(text)
            raw = brace.group(0).strip() if brace else text.strip()

        raw = raw.translate(self._QUOTE_TABLE)

        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            pass

        # Repair pass, only reached when the strict parse failed.
        repaired = self._LINE_COMMENT.sub('', raw)
        repaired = self._BLOCK_COMMENT.sub('', repaired)
        repaired = self._TRAILING_COMMA.sub(r'\1', repaired).strip()

        try:
            return json.loads(repaired)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON: {e}")
            return None

    def preprocess_and_parse_json_claude(self, response: str):
        """Extract the JSON object embedded in a Claude reply.

        The text between the first ``{`` and the last ``}`` is parsed.

        :return: the parsed dict.
        :raises ValueError: if no ``{`` is present, the span is not valid JSON,
            or the parsed value is not a dict.
        """
        if not isinstance(response, str):
            raise ValueError(
                f"Could not extract JSON from Claude response: expected text\nOriginal response: {response}")

        json_start = response.find("{")
        if json_start == -1:
            raise ValueError(
                f"Could not extract JSON from Claude response: substring not found\nOriginal response: {response}")

        json_end = response.rfind("}")
        json_string = response[json_start:json_end + 1]

        # JSONDecodeError is a ValueError subclass, so it is converted here,
        # on its own, to keep the parse-specific message.
        try:
            parsed_response = json.loads(json_string)
        except json.JSONDecodeError as je:
            raise ValueError(f"Failed to parse JSON from string: {json_string}\nError: {je}") from je

        if not isinstance(parsed_response, dict):
            raise ValueError(f"Parsed response is not a dict: {parsed_response}")

        return parsed_response
|
|
|
|
|
|
|
|
|
|
|
|