WilliamGazeley committed
Commit c124df1 · 1 Parent(s): 0583c4b

Initial untested rag code

Files changed (11):
  1. .gitignore +7 -0
  2. app.py +17 -17
  3. config.py +13 -0
  4. functioncall.py +163 -0
  5. functions.py +314 -0
  6. prompt_assets/few_shot.json +8 -0
  7. prompt_assets/sys_prompt.yml +43 -0
  8. prompter.py +76 -0
  9. schema.py +23 -0
  10. utils.py +149 -0
  11. validator.py +132 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+.env
+
+# Python
+__pycache__/
+
+# vLLM
+inference_logs/
app.py CHANGED
@@ -1,12 +1,14 @@
 import os
 import huggingface_hub
 import streamlit as st
+from config import config
 from vllm import LLM, SamplingParams
+from functioncall import ModelInference
 
 sys_msg = """You are an expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
 #Objective:
 Answer questions accurately and truthfully given your current knowledge. You do not have access to up-to-date current market data; this will be available in the future. Answer the question directly.
-Style and tone:
+#Style and tone:
 Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
 #Audience:
 The questions will be asked by top technology executives and CFOs of large fintech companies and successful startups.
@@ -15,24 +17,18 @@ Direct answer to question, concise yet insightful."""
 
 @st.cache_resource(show_spinner="Loading model..")
 def init_llm():
-    huggingface_hub.login(token=os.getenv("HF_TOKEN"))
-    llm = LLM(model="InvestmentResearchAI/LLM-ADE-dev")
-    tok = llm.get_tokenizer()
-    tok.eos_token = '<|im_end|>'  # Override to use turns
+    huggingface_hub.login(token=os.getenv("HF_TOKEN"), new_session=False)
+    llm = ModelInference(chat_template='chatml')
     return llm
 
 def get_response(prompt):
     try:
-        convo = [
-            {"role": "system", "content": sys_msg},
-            {"role": "user", "content": prompt},
-        ]
-        llm = init_llm()
-        prompts = [llm.get_tokenizer().apply_chat_template(convo, tokenize=False)]
-        sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=500, stop_token_ids=[128009])
-        outputs = llm.generate(prompts, sampling_params)
-        for output in outputs:
-            return output.outputs[0].text
+        return llm.generate_function_call(
+            prompt,
+            config.chat_template,
+            config.num_fewshot,
+            config.max_depth
+        )
     except Exception as e:
         return f"An error occurred: {str(e)}"
 
@@ -52,6 +48,10 @@ def main():
 
     llm = init_llm()
 
-if __name__ == "__main__":
-    main()
+def main_headless():
+    while True:
+        input_text = input("Enter your text here: ")
+        print(get_response(input_text))
 
+if __name__ == "__main__":
+    main_headless()
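
Note: with this change, python app.py enters the headless input loop (main_headless) instead of the Streamlit flow; the web UI is still launched the usual way. A quick check of both paths (assuming a populated .env and a GPU for vLLM):

    $ streamlit run app.py    # web UI
    $ python app.py           # headless loop, prompts "Enter your text here: "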
config.py ADDED
@@ -0,0 +1,13 @@
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+class Config(BaseSettings):
+    hf_token: str = Field(...)
+    model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
+
+    chat_template: str = Field("chatml", description="Chat template for prompt formatting")
+    num_fewshot: int | None = Field(None, description="Option to use json mode examples")
+    load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
+    max_depth: int = Field(5, description="Maximum number of recursive iterations")
+
+config = Config(_env_file=".env")
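
Note: Config reads from .env (and the process environment); hf_token is the only required field. A minimal sketch of the .env file this expects: pydantic-settings matches field names case-insensitively, and the token value below is a placeholder.

    HF_TOKEN=hf_xxxxxxxxxxxxxxxx
    # optional overrides (defaults shown)
    # MODEL=InvestmentResearchAI/LLM-ADE-dev
    # CHAT_TEMPLATE=chatml
    # MAX_DEPTH=5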
functioncall.py ADDED
@@ -0,0 +1,163 @@
+import argparse
+import torch
+import json
+from config import config
+from vllm import LLM, SamplingParams
+
+from transformers import BitsAndBytesConfig
+
+import functions
+from prompter import PromptManager
+from validator import validate_function_call_schema
+
+from utils import (
+    inference_logger,
+    get_assistant_message,
+    get_chat_template,
+    validate_and_extract_tool_calls
+)
+
+class ModelInference:
+    def __init__(self, chat_template: str, load_in_4bit: str = "False"):
+        self.prompter = PromptManager()
+        self.bnb_config = None
+
+        if load_in_4bit == "True":  # Never use this
+            self.bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+            )
+        self.model = LLM(model=config.model)
+
+        self.tokenizer = self.model.get_tokenizer()
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.padding_side = "left"
+
+        if self.tokenizer.chat_template is None:
+            print("No chat template defined, getting chat_template...")
+            self.tokenizer.chat_template = get_chat_template(chat_template)
+
+        # vLLM's LLM object has no .config/.generation_config (those are
+        # transformers attributes), so only tokenizer info is logged here
+        inference_logger.info(self.tokenizer.special_tokens_map)
+
+    def process_completion_and_validate(self, completion, chat_template):
+        assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
+
+        if assistant_message:
+            validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)
+
+            if validation:
+                inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
+                return tool_calls, assistant_message, error_message
+            else:
+                tool_calls = None
+                return tool_calls, assistant_message, error_message
+        else:
+            inference_logger.warning("Assistant message is None")
+            raise ValueError("Assistant message is None")
+
+    def execute_function_call(self, tool_call):
+        function_name = tool_call.get("name")
+        function_to_call = getattr(functions, function_name, None)
+        function_args = tool_call.get("arguments", {})
+
+        inference_logger.info(f"Invoking function call {function_name} ...")
+        function_response = function_to_call(*function_args.values())
+        # JSON-formatted string fed back to the model as the tool response
+        results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
+        return results_dict
+
+    def run_inference(self, prompt):
+        # `prompt` is a list of chat messages; render it with the chat template first
+        inputs = self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)
+        sampling_params = SamplingParams(
+            temperature=0.8,
+            top_p=0.95,
+            repetition_penalty=1.1,
+            max_tokens=500,
+            stop_token_ids=[128009])
+
+        outputs = self.model.generate([inputs], sampling_params)
+        for output in outputs:
+            return output.outputs[0].text
+
+    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
+        try:
+            depth = 0
+            user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
+            chat = [{"role": "user", "content": user_message}]
+            tools = functions.get_openai_tools()
+            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
+            completion = self.run_inference(prompt)
+
+            def recursive_loop(prompt, completion, depth):
+                nonlocal max_depth
+                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
+                prompt.append({"role": "assistant", "content": assistant_message})
+
+                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
+                if tool_calls:
+                    inference_logger.info(f"Assistant Message:\n{assistant_message}")
+
+                    for tool_call in tool_calls:
+                        validation, message = validate_function_call_schema(tool_call, tools)
+                        if validation:
+                            try:
+                                function_response = self.execute_function_call(tool_call)
+                                tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
+                                inference_logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
+                            except Exception as e:
+                                inference_logger.info(f"Could not execute function: {e}")
+                                tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
+                        else:
+                            inference_logger.info(message)
+                            tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
+                    prompt.append({"role": "tool", "content": tool_message})
+
+                    depth += 1
+                    if depth >= max_depth:
+                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
+                        return assistant_message
+
+                    completion = self.run_inference(prompt)
+                    return recursive_loop(prompt, completion, depth)
+                elif error_message:
+                    inference_logger.info(f"Assistant Message:\n{assistant_message}")
+                    tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
+                    prompt.append({"role": "tool", "content": tool_message})
+
+                    depth += 1
+                    if depth >= max_depth:
+                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
+                        return assistant_message
+
+                    completion = self.run_inference(prompt)
+                    return recursive_loop(prompt, completion, depth)
+                else:
+                    inference_logger.info(f"Assistant Message:\n{assistant_message}")
+                    return assistant_message
+
+            # return the final assistant message so callers (e.g. app.get_response) receive it
+            return recursive_loop(prompt, completion, depth)
+
+        except Exception as e:
+            inference_logger.error(f"Exception occurred: {e}")
+            raise e
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run recursive function calling loop")
+    parser.add_argument("--model_path", type=str, help="Path to the model folder")
+    parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
+    parser.add_argument("--num_fewshot", type=int, default=None, help="Option to use json mode examples")
+    parser.add_argument("--load_in_4bit", type=str, default="False", help="Option to load in 4bit with bitsandbytes")
+    parser.add_argument("--query", type=str, default="I need the current stock price of Tesla (TSLA)")
+    parser.add_argument("--max_depth", type=int, default=5, help="Maximum number of recursive iterations")
+    args = parser.parse_args()
+
+    # The model path is read from config.model; let the CLI flag override it
+    if args.model_path:
+        config.model = args.model_path
+
+    inference = ModelInference(args.chat_template, args.load_in_4bit)
+
+    # Run the model evaluator
+    inference.generate_function_call(args.query, args.chat_template, args.num_fewshot, args.max_depth)
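
Note: a minimal programmatic usage sketch of ModelInference outside the CLI (assuming a populated .env; not part of this commit):

    from config import config
    from functioncall import ModelInference

    inference = ModelInference(chat_template=config.chat_template)
    answer = inference.generate_function_call(
        "I need the current stock price of Tesla (TSLA)",
        config.chat_template,
        config.num_fewshot,
        config.max_depth,
    )
    print(answer)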
functions.py ADDED
@@ -0,0 +1,314 @@
+import re
+import inspect
+import requests
+import pandas as pd
+import yfinance as yf
+import concurrent.futures
+
+from typing import List
+from bs4 import BeautifulSoup
+from utils import inference_logger
+from langchain.tools import tool
+from langchain_core.utils.function_calling import convert_to_openai_tool
+
+@tool
+def code_interpreter(code_markdown: str) -> dict | str:
+    """
+    Execute the provided Python code string on the terminal using exec.
+
+    The string should contain valid, executable and pure Python code in markdown syntax.
+    Code should also import any required Python packages.
+
+    Args:
+        code_markdown (str): The Python code with markdown syntax to be executed.
+            For example: ```python\n<code-string>\n```
+
+    Returns:
+        dict | str: A dictionary containing variables declared and values returned by function calls,
+            or an error message if an exception occurred.
+
+    Note:
+        Use this function with caution, as executing arbitrary code can pose security risks.
+    """
+    try:
+        # Extracting code from Markdown code block
+        code_lines = code_markdown.split('\n')[1:-1]
+        code_without_markdown = '\n'.join(code_lines)
+
+        # Create a new namespace for code execution
+        exec_namespace = {}
+
+        # Execute the code in the new namespace
+        exec(code_without_markdown, exec_namespace)
+
+        # Collect variables and function call results
+        result_dict = {}
+        for name, value in exec_namespace.items():
+            if callable(value):
+                try:
+                    result_dict[name] = value()
+                except TypeError:
+                    # If the function requires arguments, attempt to call it with arguments from the namespace
+                    arg_names = inspect.getfullargspec(value).args
+                    args = {arg_name: exec_namespace.get(arg_name) for arg_name in arg_names}
+                    result_dict[name] = value(**args)
+            elif not name.startswith('_'):  # Exclude variables starting with '_'
+                result_dict[name] = value
+
+        return result_dict
+
+    except Exception as e:
+        error_message = f"An error occurred: {e}"
+        inference_logger.error(error_message)
+        return error_message
+
+@tool
+def google_search_and_scrape(query: str) -> list:
+    """
+    Performs a Google search for the given query, retrieves the top search result URLs,
+    and scrapes the text content and table data from those pages in parallel.
+
+    Args:
+        query (str): The search query.
+    Returns:
+        list: A list of dictionaries containing the URL, text content, and table data for each scraped page.
+    """
+    num_results = 2
+    url = 'https://www.google.com/search'
+    params = {'q': query, 'num': num_results}
+    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'}
+
+    inference_logger.info(f"Performing google search with query: {query}\nplease wait...")
+    response = requests.get(url, params=params, headers=headers)
+    soup = BeautifulSoup(response.text, 'html.parser')
+    urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')]
+
+    inference_logger.info("Scraping text from urls, please wait...")
+    [inference_logger.info(url) for url in urls]
+    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+        futures = [executor.submit(lambda url: (url, requests.get(url, headers=headers).text if isinstance(url, str) else None), url) for url in urls[:num_results] if isinstance(url, str)]
+        results = []
+        for future in concurrent.futures.as_completed(futures):
+            url, html = future.result()
+            soup = BeautifulSoup(html, 'html.parser')
+            paragraphs = [p.text.strip() for p in soup.find_all('p') if p.text.strip()]
+            text_content = ' '.join(paragraphs)
+            text_content = re.sub(r'\s+', ' ', text_content)
+            table_data = [[cell.get_text(strip=True) for cell in row.find_all('td')] for table in soup.find_all('table') for row in table.find_all('tr')]
+            if text_content or table_data:
+                results.append({'url': url, 'content': text_content, 'tables': table_data})
+    return results
+
+@tool
+def get_current_stock_price(symbol: str) -> float:
+    """
+    Get the current stock price for a given symbol.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        float: The current stock price, or None if an error occurs.
+    """
+    try:
+        stock = yf.Ticker(symbol)
+        # Use "regularMarketPrice" for regular market hours, or "currentPrice" for pre/post market
+        current_price = stock.info.get("regularMarketPrice", stock.info.get("currentPrice"))
+        return current_price if current_price else None
+    except Exception as e:
+        print(f"Error fetching current price for {symbol}: {e}")
+        return None
+
+@tool
+def get_stock_fundamentals(symbol: str) -> dict:
+    """
+    Get fundamental data for a given stock symbol using yfinance API.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        dict: A dictionary containing fundamental data.
+            Keys:
+                - 'symbol': The stock symbol.
+                - 'company_name': The long name of the company.
+                - 'sector': The sector to which the company belongs.
+                - 'industry': The industry to which the company belongs.
+                - 'market_cap': The market capitalization of the company.
+                - 'pe_ratio': The forward price-to-earnings ratio.
+                - 'pb_ratio': The price-to-book ratio.
+                - 'dividend_yield': The dividend yield.
+                - 'eps': The trailing earnings per share.
+                - 'beta': The beta value of the stock.
+                - '52_week_high': The 52-week high price of the stock.
+                - '52_week_low': The 52-week low price of the stock.
+    """
+    try:
+        stock = yf.Ticker(symbol)
+        info = stock.info
+        fundamentals = {
+            'symbol': symbol,
+            'company_name': info.get('longName', ''),
+            'sector': info.get('sector', ''),
+            'industry': info.get('industry', ''),
+            'market_cap': info.get('marketCap', None),
+            'pe_ratio': info.get('forwardPE', None),
+            'pb_ratio': info.get('priceToBook', None),
+            'dividend_yield': info.get('dividendYield', None),
+            'eps': info.get('trailingEps', None),
+            'beta': info.get('beta', None),
+            '52_week_high': info.get('fiftyTwoWeekHigh', None),
+            '52_week_low': info.get('fiftyTwoWeekLow', None)
+        }
+        return fundamentals
+    except Exception as e:
+        print(f"Error getting fundamentals for {symbol}: {e}")
+        return {}
+
+@tool
+def get_financial_statements(symbol: str) -> dict:
+    """
+    Get financial statements for a given stock symbol.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        dict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement).
+    """
+    try:
+        stock = yf.Ticker(symbol)
+        financials = stock.financials
+        return financials
+    except Exception as e:
+        print(f"Error fetching financial statements for {symbol}: {e}")
+        return {}
+
+@tool
+def get_key_financial_ratios(symbol: str) -> dict:
+    """
+    Get key financial ratios for a given stock symbol.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        dict: Dictionary containing key financial ratios.
+    """
+    try:
+        stock = yf.Ticker(symbol)
+        key_ratios = stock.info
+        return key_ratios
+    except Exception as e:
+        print(f"Error fetching key financial ratios for {symbol}: {e}")
+        return {}
+
+@tool
+def get_analyst_recommendations(symbol: str) -> pd.DataFrame:
+    """
+    Get analyst recommendations for a given stock symbol.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        pd.DataFrame: DataFrame containing analyst recommendations.
+    """
+    try:
+        stock = yf.Ticker(symbol)
+        recommendations = stock.recommendations
+        return recommendations
+    except Exception as e:
+        print(f"Error fetching analyst recommendations for {symbol}: {e}")
+        return pd.DataFrame()
+
+@tool
+def get_dividend_data(symbol: str) -> pd.DataFrame:
+    """
+    Get dividend data for a given stock symbol.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        pd.DataFrame: DataFrame containing dividend data.
+    """
+    try:
+        stock = yf.Ticker(symbol)
+        dividends = stock.dividends
+        return dividends
+    except Exception as e:
+        print(f"Error fetching dividend data for {symbol}: {e}")
+        return pd.DataFrame()
+
+@tool
+def get_company_news(symbol: str) -> pd.DataFrame:
+    """
+    Get company news and press releases for a given stock symbol.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        pd.DataFrame: DataFrame containing company news and press releases.
+    """
+    try:
+        news = yf.Ticker(symbol).news
+        return news
+    except Exception as e:
+        print(f"Error fetching company news for {symbol}: {e}")
+        return pd.DataFrame()
+
+@tool
+def get_technical_indicators(symbol: str) -> pd.DataFrame:
+    """
+    Get technical indicators for a given stock symbol.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        pd.DataFrame: DataFrame containing technical indicators.
+    """
+    try:
+        indicators = yf.Ticker(symbol).history(period="max")
+        return indicators
+    except Exception as e:
+        print(f"Error fetching technical indicators for {symbol}: {e}")
+        return pd.DataFrame()
+
+@tool
+def get_company_profile(symbol: str) -> dict:
+    """
+    Get company profile and overview for a given stock symbol.
+
+    Args:
+        symbol (str): The stock symbol.
+
+    Returns:
+        dict: Dictionary containing company profile and overview.
+    """
+    try:
+        profile = yf.Ticker(symbol).info
+        return profile
+    except Exception as e:
+        print(f"Error fetching company profile for {symbol}: {e}")
+        return {}
+
+def get_openai_tools() -> List[dict]:
+    functions = [
+        code_interpreter,
+        google_search_and_scrape,
+        get_current_stock_price,
+        get_company_news,
+        get_company_profile,
+        get_stock_fundamentals,
+        get_financial_statements,
+        get_key_financial_ratios,
+        get_analyst_recommendations,
+        get_dividend_data,
+        get_technical_indicators
+    ]
+
+    tools = [convert_to_openai_tool(f) for f in functions]
+    return tools
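
Note: convert_to_openai_tool derives an OpenAI-style signature from each @tool's type hints and docstring; for get_current_stock_price the result is roughly (abridged):

    {
      "type": "function",
      "function": {
        "name": "get_current_stock_price",
        "description": "Get the current stock price for a given symbol. ...",
        "parameters": {
          "type": "object",
          "properties": {"symbol": {"type": "string"}},
          "required": ["symbol"]
        }
      }
    }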
prompt_assets/few_shot.json ADDED
@@ -0,0 +1,8 @@
+[
+    {
+        "example": "```\nSYSTEM: You are a helpful assistant who has access to functions. Use them if required\n<tools>[\n    {\n        \"name\": \"calculate_distance\",\n        \"description\": \"Calculate the distance between two locations\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"origin\": {\n                    \"type\": \"string\",\n                    \"description\": \"The starting location\"\n                },\n                \"destination\": {\n                    \"type\": \"string\",\n                    \"description\": \"The destination location\"\n                },\n                \"mode\": {\n                    \"type\": \"string\",\n                    \"description\": \"The mode of transportation\"\n                }\n            },\n            \"required\": [\n                \"origin\",\n                \"destination\",\n                \"mode\"\n            ]\n        }\n    },\n    {\n        \"name\": \"generate_password\",\n        \"description\": \"Generate a random password\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"length\": {\n                    \"type\": \"integer\",\n                    \"description\": \"The length of the password\"\n                }\n            },\n            \"required\": [\n                \"length\"\n            ]\n        }\n    }\n]\n\n</tools>\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"origin\": \"New York\",\n    \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n</tool_call>\n```\n"
+    },
+    {
+        "example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n<tools>[\n    {\n        \"name\": \"calculate_distance\",\n        \"description\": \"Calculate the distance between two locations\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"origin\": {\n                    \"type\": \"string\",\n                    \"description\": \"The starting location\"\n                },\n                \"destination\": {\n                    \"type\": \"string\",\n                    \"description\": \"The destination location\"\n                },\n                \"mode\": {\n                    \"type\": \"string\",\n                    \"description\": \"The mode of transportation\"\n                }\n            },\n            \"required\": [\n                \"origin\",\n                \"destination\",\n                \"mode\"\n            ]\n        }\n    },\n    {\n        \"name\": \"generate_password\",\n        \"description\": \"Generate a random password\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"length\": {\n                    \"type\": \"integer\",\n                    \"description\": \"The length of the password\"\n                }\n            },\n            \"required\": [\n                \"length\"\n            ]\n        }\n    }\n]\n\n</tools>\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n</tool_call>\n```"
+    }
+]
prompt_assets/sys_prompt.yml ADDED
@@ -0,0 +1,43 @@
+Role: |
+  You are an expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
+  You are a function calling AI agent with self-recursion.
+  You can call only one function at a time and analyse the data you get from function responses.
+  You are provided with function signatures within <tools></tools> XML tags.
+  The current date is: {date}.
+Objective: |
+  You may use agentic frameworks for reasoning and planning to help with the user query.
+  Please call a function and wait for function results to be provided to you in the next iteration.
+  Don't make assumptions about what values to plug into function arguments.
+  Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
+  Don't make assumptions about tool results if <tool_response> XML tags are not present since the function hasn't been executed yet.
+  Analyze the data once you get the results and call another function.
+  At each iteration please continue adding your analysis to the previous summary.
+  Your final response should directly answer the user query with an analysis or summary of the results of function calls.
+Tools: |
+  Here are the available tools:
+  <tools> {tools} </tools>
+  If the provided function signatures don't include the function you must call, you may write executable python code in markdown syntax and call the code_interpreter() function as follows:
+  <tool_call>
+  {{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
+  </tool_call>
+  Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
+Examples: |
+  Here are some example usages of functions:
+  {examples}
+Schema: |
+  Use the following pydantic model json schema for each tool call you will make:
+  {schema}
+Instructions: |
+  At the very first turn you don't have <tool_results> so you shouldn't make up the results.
+  Please keep a running summary with analysis of previous function results and summaries from previous iterations.
+  Do not stop calling functions until the task has been accomplished or you've reached the max iteration of 10.
+  Calling multiple functions at once can overload the system and increase cost so call one function at a time please.
+  If you plan to continue with analysis, always call another function.
+  For each function call return a valid json object (using double quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
+  <tool_call>
+  {{"arguments": <args-dict>, "name": <function-name>}}
+  </tool_call>
+Style and tone: |
+  Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
+Audience: |
+  The questions will be asked by top technology executives and CFOs of large fintech companies and successful startups.
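
Note: under these instructions, a well-formed assistant turn looks like this (hypothetical example using one of the tools from functions.py):

    <tool_call>
    {"arguments": {"symbol": "TSLA"}, "name": "get_current_stock_price"}
    </tool_call>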
prompter.py ADDED
@@ -0,0 +1,76 @@
+import datetime
+from pydantic import BaseModel
+from typing import Dict
+from schema import FunctionCall
+from utils import (
+    get_fewshot_examples
+)
+import yaml
+import json
+import os
+
+class PromptSchema(BaseModel):
+    Role: str
+    Objective: str
+    Tools: str
+    Examples: str
+    Schema: str
+    Instructions: str
+
+class PromptManager:
+    def __init__(self):
+        self.script_dir = os.path.dirname(os.path.abspath(__file__))
+
+    def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> str:
+        formatted_prompt = ""
+        for field, value in prompt_schema.dict().items():
+            if field == "Examples" and variables.get("examples") is None:
+                continue
+            formatted_value = value.format(**variables)
+            if field == "Instructions":
+                formatted_prompt += f"{formatted_value}"
+            else:
+                formatted_value = formatted_value.replace("\n", " ")
+                formatted_prompt += f"{formatted_value}"
+        return formatted_prompt
+
+    def read_yaml_file(self, file_path: str) -> PromptSchema:
+        with open(file_path, 'r') as file:
+            yaml_content = yaml.safe_load(file)
+
+        prompt_schema = PromptSchema(
+            Role=yaml_content.get('Role', ''),
+            Objective=yaml_content.get('Objective', ''),
+            Tools=yaml_content.get('Tools', ''),
+            Examples=yaml_content.get('Examples', ''),
+            Schema=yaml_content.get('Schema', ''),
+            Instructions=yaml_content.get('Instructions', ''),
+        )
+        return prompt_schema
+
+    def generate_prompt(self, user_prompt, tools, num_fewshot=None):
+        prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt.yml')
+        prompt_schema = self.read_yaml_file(prompt_path)
+
+        if num_fewshot is not None:
+            examples = get_fewshot_examples(num_fewshot)
+        else:
+            examples = None
+
+        schema_json = json.loads(FunctionCall.schema_json())
+
+        variables = {
+            "date": datetime.date.today(),
+            "tools": tools,
+            "examples": examples,
+            "schema": schema_json
+        }
+        sys_prompt = self.format_yaml_prompt(prompt_schema, variables)
+
+        prompt = [
+            {'content': sys_prompt, 'role': 'system'}
+        ]
+        prompt.extend(user_prompt)
+        return prompt
schema.py ADDED
@@ -0,0 +1,23 @@
+from pydantic import BaseModel
+from typing import List, Dict, Literal, Optional
+
+class FunctionCall(BaseModel):
+    arguments: dict
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: str
+    """The name of the function to call."""
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, object]] = None
+
+class FunctionSignature(BaseModel):
+    function: FunctionDefinition
+    type: Literal["function"]
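
Note: FunctionCall.schema_json() is what prompter.py substitutes into the {schema} slot; with this model it produces roughly:

    {
      "title": "FunctionCall",
      "type": "object",
      "properties": {
        "arguments": {"title": "Arguments", "type": "object"},
        "name": {"title": "Name", "type": "string"}
      },
      "required": ["arguments", "name"]
    }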
utils.py ADDED
@@ -0,0 +1,149 @@
+import ast
+import os
+import re
+import json
+import logging
+import datetime
+import xml.etree.ElementTree as ET
+
+from logging.handlers import RotatingFileHandler
+
+logging.basicConfig(
+    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%d:%H:%M:%S",
+    level=logging.INFO,
+)
+script_dir = os.path.dirname(os.path.abspath(__file__))
+now = datetime.datetime.now()
+log_folder = os.path.join(script_dir, "inference_logs")
+os.makedirs(log_folder, exist_ok=True)
+log_file_path = os.path.join(
+    log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
+)
+# Use RotatingFileHandler from the logging.handlers module
+file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
+file_handler.setLevel(logging.INFO)
+
+formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
+file_handler.setFormatter(formatter)
+
+inference_logger = logging.getLogger("function-calling-inference")
+inference_logger.addHandler(file_handler)
+
+def get_fewshot_examples(num_fewshot):
+    """return a list of few shot examples"""
+    example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
+    with open(example_path, 'r') as file:
+        examples = json.load(file)  # Use json.load with the file object, not the file path
+    if num_fewshot > len(examples):
+        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
+    return examples[:num_fewshot]
+
+def get_chat_template(chat_template):
+    """read chat template from jinja file"""
+    template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")
+
+    if not os.path.exists(template_path):
+        inference_logger.error(f"Template file not found: {chat_template}")
+        return None
+    try:
+        with open(template_path, 'r') as file:
+            template = file.read()
+        return template
+    except Exception as e:
+        print(f"Error loading template: {e}")
+        return None
+
+def get_assistant_message(completion, chat_template, eos_token):
+    """define and match pattern to find the assistant message"""
+    completion = completion.strip()
+
+    if chat_template == "zephyr":
+        assistant_pattern = re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL)
+    elif chat_template == "chatml":
+        assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
+    elif chat_template == "vicuna":
+        assistant_pattern = re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL)
+    else:
+        raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")
+
+    assistant_match = assistant_pattern.search(completion)
+    if assistant_match:
+        assistant_content = assistant_match.group(1).strip()
+        if chat_template == "vicuna":
+            eos_token = f"</s>{eos_token}"
+        return assistant_content.replace(eos_token, "")
+    else:
+        assistant_content = None
+        inference_logger.info("No match found for the assistant pattern")
+        return assistant_content
+
+def validate_and_extract_tool_calls(assistant_content):
+    validation_result = False
+    tool_calls = []
+    error_message = None
+
+    try:
+        # wrap content in root element
+        xml_root_element = f"<root>{assistant_content}</root>"
+        root = ET.fromstring(xml_root_element)
+
+        # extract JSON data
+        for element in root.findall(".//tool_call"):
+            json_data = None
+            try:
+                json_text = element.text.strip()
+
+                try:
+                    # Prioritize json.loads for better error handling
+                    json_data = json.loads(json_text)
+                except json.JSONDecodeError as json_err:
+                    try:
+                        # Fallback to ast.literal_eval if json.loads fails
+                        json_data = ast.literal_eval(json_text)
+                    except (SyntaxError, ValueError) as eval_err:
+                        error_message = (
+                            f"JSON parsing failed with both json.loads and ast.literal_eval:\n"
+                            f"- JSON Decode Error: {json_err}\n"
+                            f"- Fallback Syntax/Value Error: {eval_err}\n"
+                            f"- Problematic JSON text: {json_text}"
+                        )
+                        inference_logger.error(error_message)
+                        continue
+            except Exception as e:
+                error_message = f"Cannot strip text: {e}"
+                inference_logger.error(error_message)
+
+            if json_data is not None:
+                tool_calls.append(json_data)
+                validation_result = True
+
+    except ET.ParseError as err:
+        error_message = f"XML Parse Error: {err}"
+        inference_logger.error(f"XML Parse Error: {err}")
+
+    # Return default values if no valid data is extracted
+    return validation_result, tool_calls, error_message
+
+def extract_json_from_markdown(text):
+    """
+    Extracts the JSON string from the given text using a regular expression pattern.
+
+    Args:
+        text (str): The input text containing the JSON string.
+
+    Returns:
+        dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
+    """
+    json_pattern = r'```json\r?\n(.*?)\r?\n```'
+    match = re.search(json_pattern, text, re.DOTALL)
+    if match:
+        json_string = match.group(1)
+        try:
+            data = json.loads(json_string)
+            return data
+        except json.JSONDecodeError as e:
+            print(f"Error decoding JSON string: {e}")
+    else:
+        print("JSON string not found in the text.")
+    return None
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import json
3
+ from jsonschema import validate
4
+ from pydantic import ValidationError
5
+ from utils import inference_logger, extract_json_from_markdown
6
+ from schema import FunctionCall, FunctionSignature
7
+
8
+ def validate_function_call_schema(call, signatures):
9
+ try:
10
+ call_data = FunctionCall(**call)
11
+ except ValidationError as e:
12
+ return False, str(e)
13
+
14
+ for signature in signatures:
15
+ try:
16
+ signature_data = FunctionSignature(**signature)
17
+ if signature_data.function.name == call_data.name:
18
+ # Validate types in function arguments
19
+ for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
20
+ if arg_name in call_data.arguments:
21
+ call_arg_value = call_data.arguments[arg_name]
22
+ if call_arg_value:
23
+ try:
24
+ validate_argument_type(arg_name, call_arg_value, arg_schema)
25
+ except Exception as arg_validation_error:
26
+ return False, str(arg_validation_error)
27
+
28
+ # Check if all required arguments are present
29
+ required_arguments = signature_data.function.parameters.get('required', [])
30
+ result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
31
+ if not result:
32
+ return False, f"Missing required arguments: {missing_arguments}"
33
+
34
+ return True, None
35
+ except Exception as e:
36
+ # Handle validation errors for the function signature
37
+ return False, str(e)
38
+
39
+ # No matching function signature found
40
+ return False, f"No matching function signature found for function: {call_data.name}"
41
+
42
+ def check_required_arguments(call_arguments, required_arguments):
43
+ missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
44
+ return not bool(missing_arguments), missing_arguments
45
+
46
+ def validate_enum_value(arg_name, arg_value, enum_values):
47
+ if arg_value not in enum_values:
48
+ raise Exception(
49
+ f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
50
+ )
51
+
52
+ def validate_argument_type(arg_name, arg_value, arg_schema):
53
+ arg_type = arg_schema.get('type', None)
54
+ if arg_type:
55
+ if arg_type == 'string' and 'enum' in arg_schema:
56
+ enum_values = arg_schema['enum']
57
+ if None not in enum_values and enum_values != []:
58
+ try:
59
+ validate_enum_value(arg_name, arg_value, enum_values)
60
+ except Exception as e:
61
+ # Propagate the validation error message
62
+ raise Exception(f"Error validating function call: {e}")
63
+
64
+ python_type = get_python_type(arg_type)
65
+ if not isinstance(arg_value, python_type):
66
+ raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
67
+
68
+ def get_python_type(json_type):
69
+ type_mapping = {
70
+ 'string': str,
71
+ 'number': (int, float),
72
+ 'integer': int,
73
+ 'boolean': bool,
74
+ 'array': list,
75
+ 'object': dict,
76
+ 'null': type(None),
77
+ }
78
+ return type_mapping[json_type]
79
+
80
+ def validate_json_data(json_object, json_schema):
81
+ valid = False
82
+ error_message = None
83
+ result_json = None
84
+
85
+ try:
86
+ # Attempt to load JSON using json.loads
87
+ try:
88
+ result_json = json.loads(json_object)
89
+ except json.decoder.JSONDecodeError:
90
+ # If json.loads fails, try ast.literal_eval
91
+ try:
92
+ result_json = ast.literal_eval(json_object)
93
+ except (SyntaxError, ValueError) as e:
94
+ try:
95
+ result_json = extract_json_from_markdown(json_object)
96
+ except Exception as e:
97
+ error_message = f"JSON decoding error: {e}"
98
+ inference_logger.info(f"Validation failed for JSON data: {error_message}")
99
+ return valid, result_json, error_message
100
+
101
+ # Return early if both json.loads and ast.literal_eval fail
102
+ if result_json is None:
103
+ error_message = "Failed to decode JSON data"
104
+ inference_logger.info(f"Validation failed for JSON data: {error_message}")
105
+ return valid, result_json, error_message
106
+
107
+ # Validate each item in the list against schema if it's a list
108
+ if isinstance(result_json, list):
109
+ for index, item in enumerate(result_json):
110
+ try:
111
+ validate(instance=item, schema=json_schema)
112
+ inference_logger.info(f"Item {index+1} is valid against the schema.")
113
+ except ValidationError as e:
114
+ error_message = f"Validation failed for item {index+1}: {e}"
115
+ break
116
+ else:
117
+ # Default to validation without list
118
+ try:
119
+ validate(instance=result_json, schema=json_schema)
120
+ except ValidationError as e:
121
+ error_message = f"Validation failed: {e}"
122
+
123
+ except Exception as e:
124
+ error_message = f"Error occurred: {e}"
125
+
126
+ if error_message is None:
127
+ valid = True
128
+ inference_logger.info("JSON data is valid against the schema.")
129
+ else:
130
+ inference_logger.info(f"Validation failed for JSON data: {error_message}")
131
+
132
+ return valid, result_json, error_message
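
Note: a usage sketch tying the validator to the signatures from functions.get_openai_tools() (hypothetical call dict):

    from functions import get_openai_tools
    from validator import validate_function_call_schema

    tools = get_openai_tools()
    ok, err = validate_function_call_schema(
        {"name": "get_current_stock_price", "arguments": {"symbol": "TSLA"}},
        tools,
    )
    # ok is True when the call matches a signature; otherwise err describes the mismatch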