# xtrade_bot / openai_functions_and_agents.py
# Josh-Ola's picture
# Upload folder using huggingface_hub
# 65976bc verified
"""
Module that contains all the functions used in the project.
Ensure no sensitive datapoint or endpoint is exposed here,
as the contents in this module are often fed into the LLM.
Instead, put sensitive contents in a different location and
call it in this module.
"""
# Import things that are needed generically
import os
from enum import Enum
import random
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.pydantic_v1 import (
BaseModel,
Field,
)
from langchain.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain.tools.render import format_tool_to_openai_function
from langchain_community.utils.openai_functions import (
convert_pydantic_to_openai_function,
)
from langchain_community.vectorstores.chroma import Chroma
from dotenv import (
load_dotenv,
find_dotenv,
)
import chromadb
from constants import CHROMA_SETTINGS
from _endpoints import (
get_securities_list_final, # implemented, not tested
get_securities_and_prices_final, # implemented, not tested
# get_price_history, # implement when endpoint is working
get_wallet_history_final, # implemented, not tested
get_boards_list_final, # implementing
get_order_history_final, # yet to implement
get_tradeable_commodities_list_final, # yet to implement
put_up_trade, # implemented, not tested thoroughly
)
#--------- Load environment variables ---------#
# Populate os.environ from the nearest `.env` file. A falsy result means nothing
# was loaded; that is reported but deliberately not treated as fatal.
if not load_dotenv(find_dotenv()):
    print(
        "Could not load `.env` file or it is empty. Please check that it exists "
        "and is readable by the current user"
    )

#--------- Read environment variables ---------#
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Number of chunks (``k``) the retriever asks for per search; defaults to 4.
target_source_chunks = int(os.getenv("TARGET_SOURCE_CHUNKS", "4"))
# Directory holding the persisted Chroma database.
persist_directory = os.getenv("PERSIST_DIRECTORY", "db")
# Not referenced by the code shown here — presumably used by an ingestion script.
source_directory = os.getenv("SOURCE_DIRECTORY", "data")

# Maximum number of items the list-returning tools hand back to the model.
LEN_RETURN_LIST = 5
# Embedding function for the Chroma store; instantiated once, at import time.
embeddings_model = OpenAIEmbeddings()
def create_consumable_functions():
    """Build OpenAI function specs for placing buy and sell orders.

    Two pydantic schemas are declared locally and each one is converted with
    ``convert_pydantic_to_openai_function``. Their docstrings and field
    descriptions presumably end up in the generated specs, so that text is
    written for the model rather than for developers.

    Returns:
        list: ``[buy_order_spec, sell_order_spec]``, in that order.
    """

    class PlaceBuyOrder(BaseModel):
        """Use for placing buy order. Don't assume; request any missing parameter \
from the user."""
        commodity: str = Field(description="the name of the commodity")
        quantity: float = Field(description="the quantity to be traded")
        unit_price: float = Field(description="the rate at which the commodity should be traded")

    class PlaceSellOrder(BaseModel):
        """Use for placing sell order. Don't assume; request any missing parameter \
from the user."""
        commodity: str = Field(description="the name of the commodity")
        quantity: float = Field(description="the quantity to be traded")
        unit_price: float = Field(description="the rate at which the commodity should be traded")

    order_schemas = (PlaceBuyOrder, PlaceSellOrder)
    return [convert_pydantic_to_openai_function(schema) for schema in order_schemas]
# Developer note stored as a module-level string; nothing in the code shown here
# reads `TODO`. ConfirmOrderDetails is a first pass at the confirmation tool it
# describes, but that tool is currently commented out of the tool lists.
TODO = """
Add a new tools that confirms if all the parameters needed for buy/sell order are available and
should return a confirmation message like "are you sure you want to place a --- order for <cdty>
at <rate> for <qty>. this tool should ALWAYS be called and affirmed to before calling the buy/sell
tool. handle any missing params in it and send back to the user to request for it
"""
class _SchemaGetCurrentSecuritiesPrice(BaseModel):
    # Argument schema for the GetCurrentSecuritiesPrice tool. Documented with
    # comments rather than a class docstring, since a pydantic docstring would
    # presumably also leak into the generated schema.
    # The list member is typed ``list[str]`` (previously a bare ``list``) so the
    # JSON schema carries ``items: {type: string}``, matching the description
    # below; OpenAI rejects or mis-handles array schemas without ``items``.
    # ``default=None`` makes explicit what pydantic v1 already inferred from the
    # Optional type: the field may be omitted.
    security_code: str | list[str] | None = Field(
        default=None,
        description=(
            "the code of the security. Can be a string or a list of strings. "
            "If None, it fetches a long list of all the commodities and their "
            "respective prices"
        ),
    )
class _SchemaBuyOrder(BaseModel):
    # Argument schema for the PlaceBuyOrder tool. All four fields are required
    # (``...``); the descriptions, which the model reads, ask for 0 when the
    # user has not supplied a value.
    security_code: str = Field(
        ...,
        description="the code of the security to be traded on. If not given, return number 0.",
    )
    units: float = Field(
        ...,
        description="the quantity of the security for the buy order. If not given, return number 0.",
    )
    # Previously described "the sell order" (copy-paste from the sell schema),
    # which misdescribed the parameter to the model.
    unit_price: float = Field(
        ...,
        description="the rate or unit price for the buy order. If not given, return number 0.",
    )
    is_order_confirmed: bool = Field(
        ...,
        description="return true if the user has confirmed the order, otherwise false.",
    )
class _SchemaSellOrder(BaseModel):
    # Argument schema for the PlaceSellOrder tool (its ``args_schema``).
    # Every field is required; the descriptions ask the model for 0 when the
    # user has not given a value. Comments are used instead of a docstring
    # because a pydantic class docstring would presumably also enter the schema.
    security_code: str = Field(
        default=...,
        description=(
            "the code of the security to be traded on. If not "
            "given, return number 0."
        ),
    )
    units: float = Field(
        default=...,
        description=(
            "the quantity of the security for the sell order. If "
            "not given, return number 0."
        ),
    )
    unit_price: float = Field(
        default=...,
        description=(
            "the rate or unit price for the sell order. If "
            "not given, return number 0."
        ),
    )
    is_order_confirmed: bool = Field(
        default=...,
        description=(
            "return true if the user has confirmed "
            "the order, otherwise false."
        ),
    )
class _ConfirmOrderSchema(BaseModel):
    # Argument schema for the ConfirmOrderDetails tool, which handles both buy
    # and sell orders. All fields are required.
    # Fixes: the quantity / unit_price descriptions previously said "sell
    # order" although the schema serves both sides; the ``const=False``
    # arguments were dropped because False is already pydantic's behaviour.
    order_type: str = Field(
        description="the order type. should be one of either 'buy' or 'sell'.",
    )
    security_code: str = Field(
        description="the code of the security to be traded on. If not given, return number 0.",
    )
    quantity: float = Field(
        description="the quantity of the security for the order. If not given, return number 0.",
    )
    unit_price: float = Field(
        description="the rate or unit price for the order. If not given, return number 0.",
    )
class _WikipediaSchema(BaseModel):
    # Argument schema for the search_wikipedia tool: one required search query.
    query: str = Field(
        default=...,
        description="the query to search for on wikipedia",
    )
# Closed set of order directions. Each member's value equals its name
# (OrderType.BUY.value == "BUY"); ConfirmOrderDetails compares an upper-cased
# string against these values.
OrderType = Enum(
    "OrderType",
    [("BUY", "BUY"), ("SELL", "SELL")],
    module=__name__,
)
# the tool
@tool
def CheckAccountBalance():
    """Always use this tool to check the account balance of the user. It does not require any parameter to run and it returns the account balance in Naira."""
    # NOTE(review): no endpoint is called. Each balance is a fresh random number
    # from random.uniform(1, 2381218), rounded to 2 decimal places, i.e.
    # placeholder data, although the tool description promises the user's real
    # balance. Wire this to the account endpoint before relying on it.
    balance_keys = ("account_bal", "lien_balance", "portfolio_balance")
    return {key: round(random.uniform(1, 2381218), 2) for key in balance_keys}
# TO-DO: Consider renaming to "getcurrentprice" and evaluate the effect on performance
# TO-DO: Add a well-detailed docstring for each tool, stating the params they receive (if any)\
# and what they return (most importantly) - name and description.
@tool
def GetSecuritiesList(
    board_code: str | list | None = None,
    commodity_code: str | list | None = None
):
    """Always use this tool to get the list of tradeable securities on our platform,
and their respective boards without prices"""
    # The optional board / commodity filters (a code, a list of codes, or None)
    # are forwarded unchanged to the endpoint helper.
    response = get_securities_list_final(
        board_code=board_code,
        commodity_code=commodity_code,
    )
    # Return at most LEN_RETURN_LIST entries when the endpoint returns a list;
    # any other response is passed through unchanged. Slicing a list cannot
    # raise, so the former bare ``try/except: return`` was dead code that
    # could only ever have hidden real bugs.
    return response[:LEN_RETURN_LIST] if isinstance(response, list) else response
@tool(args_schema=_SchemaGetCurrentSecuritiesPrice)
def GetCurrentSecuritiesPrice(security_code: str | list | None = None)-> str | list:
    """Always use this tool to check the current price of commodities in Naira. \
ALWAYS use the Naira symbol for the price"""
    # Normalise the requested code(s) before calling the endpoint helper.
    # Previously ``str(security_code).upper()`` was applied unconditionally.
    # That turned None into the literal string "NONE" and a list into its repr
    # (e.g. "['SMAZ', 'DMAZ']"), contradicting the argument schema: None is
    # documented to fetch every security, and a list of codes is documented
    # as valid input.
    # NOTE(review): assumes get_securities_and_prices_final accepts None and
    # lists, just as GetSecuritiesList already passes str | list | None
    # filters to get_securities_list_final. Confirm against _endpoints.
    if isinstance(security_code, str):
        codes = security_code.upper()
    elif isinstance(security_code, (list, tuple)):
        codes = [str(code).upper() for code in security_code]
    else:
        codes = security_code
    response = get_securities_and_prices_final(codes)
    # Return at most LEN_RETURN_LIST entries when the endpoint returns a list;
    # any other response is passed through unchanged (the old bare
    # try/except around this expression could never trigger).
    return response[:LEN_RETURN_LIST] if isinstance(response, list) else response
@tool
def GetPriceHistoryofSecurities():
    """Not implemented"""
    # Placeholder: the price-history endpoint import is commented out and this
    # tool is excluded from both tool lists.
    # NOTE(review): the NotImplemented singleton is meant for binary-operator
    # dunders. Before enabling this tool, consider returning a plain message
    # or raising NotImplementedError instead.
    return NotImplemented
@tool
def GetCustomerWalletHistory():
    """Always use this tool to get the details of previous transactions on the user's wallet or account."""
    response = get_wallet_history_final()
    # Return at most LEN_RETURN_LIST entries when the endpoint returns a list;
    # any other response is passed through unchanged. The former bare
    # ``try/except: return`` wrapped an expression that cannot raise, so it
    # only risked masking real errors.
    return response[:LEN_RETURN_LIST] if isinstance(response, list) else response
@tool
def GetBoardsList():
    """Always use this tool to get the list of tradeable boards"""
    response = get_boards_list_final()
    # Return at most LEN_RETURN_LIST entries when the endpoint returns a list;
    # any other response is passed through unchanged. The former bare
    # ``try/except: return`` wrapped an expression that cannot raise, so it
    # only risked masking real errors.
    return response[:LEN_RETURN_LIST] if isinstance(response, list) else response
@tool
def GetCustomerOrderHistory():
    """Always use this tool to get the previous transactions of the user"""
    response = get_order_history_final()
    # Return at most LEN_RETURN_LIST entries when the endpoint returns a list;
    # any other response is passed through unchanged. The former bare
    # ``try/except: return`` wrapped an expression that cannot raise, so it
    # only risked masking real errors.
    return response[:LEN_RETURN_LIST] if isinstance(response, list) else response
@tool
def GetTradeableCommodities():
    """Always use this tool to get the tradeable commodities on your platform """
    response = get_tradeable_commodities_list_final()
    # Return at most LEN_RETURN_LIST entries when the endpoint returns a list;
    # any other response is passed through unchanged. The former bare
    # ``try/except: return`` wrapped an expression that cannot raise, so it
    # only risked masking real errors.
    return response[:LEN_RETURN_LIST] if isinstance(response, list) else response
@tool(args_schema=_ConfirmOrderSchema)
def ConfirmOrderDetails(
order_type: OrderType,
security_code: str ,
quantity: float ,
unit_price: float,
) -> str:
"""Confirm order details before carrying it out. Always use this tool before \
calling the "PlaceBuyOrder" or "PlaceSellOrder" tools. If the user confirms the \
order details, then call the right tool; if not, DON'T call the tool.\
Place a buy order. Buy orders are often referred to as "buy", "bid", or "long"."""
match order_type.upper():
case OrderType.BUY.value:
return f"Are you sure you want to place a `{order_type.upper()}` order for {quantity} \
quantities of {security_code} at {unit_price} per unit?"
case OrderType.SELL.value:
return f"Are you sure you want to place a `{order_type.upper()}` order for {quantity} \
quantities of {security_code} at {unit_price} per unit?"
case _default:
return f"Order type of `{order_type.upper()}` is not valid."
def _place_order(
    order_type: str,
    security_code: str,
    units: float,
    unit_price: float,
    is_order_confirmed: bool,
) -> str:
    """Validate, confirm and submit a buy or sell order.

    Shared implementation of the ``PlaceBuyOrder`` and ``PlaceSellOrder``
    tools, which differ only in the side they pass.

    Args:
        order_type: ``"buy"`` or ``"sell"``. Forwarded to ``put_up_trade`` and
            upper-cased in the messages shown to the user.
        security_code: Code of the security to trade. Upper-cased before it
            is sent to ``put_up_trade``.
        units: Quantity to trade.
        unit_price: Price per unit, sent as ``order_price``.
        is_order_confirmed: True once the user has confirmed the order.

    Returns:
        A user-facing message: a request for missing details, a confirmation
        question, a success notice, or a generic error message.
    """
    import logging

    side = order_type.upper()
    try:
        # The argument schemas tell the model to send 0 for any value the
        # user has not given. So a blank/"0" code or a non-positive quantity
        # or price means "missing": ask for it rather than confirming or
        # submitting an empty order.
        code = str(security_code).strip()
        missing = []
        if code in ("", "0", "0.0"):
            missing.append("the security code")
        if units <= 0:
            missing.append("the number of units")
        if unit_price <= 0:
            missing.append("the unit price")
        if missing:
            return f"Please provide {', '.join(missing)} to place the `{side}` order."

        # Not yet confirmed by the user: echo the details back for confirmation.
        if not is_order_confirmed:
            return (
                f"Are you sure you want to place a `{side}` order for {units} units of {security_code} "
                f"at {unit_price} per unit?"
            )

        response = put_up_trade(
            security_code=code.upper(),
            units=units,
            order_type=order_type,
            market_order_type="market_order",
            order_price=unit_price,
        )
        # NOTE(review): `response` is not inspected yet, so success is reported
        # even if the trade API rejected the order. Re-enable this check once
        # the endpoint is stable:
        # if str(response["responseCode"]).strip() != "100":
        #     return f"Sorry, an error occurred: {response['errors']}!"
        # return str(response["message"]).replace("[", "`").replace("]", "`")
        return f"`{side}` order for {units} unit of {security_code} at {unit_price} placed successfully."
    except Exception:
        # Tool boundary: keep the full traceback in the logs (print() dropped
        # it) and give the user a generic message instead of internals.
        logging.getLogger(__name__).exception("Failed to place %s order", side)
        return (
            "Sorry, an error occurred while carrying out the request. I have notified "
            "the developers and they will fix it soon."
        )


@tool(args_schema=_SchemaBuyOrder)
def PlaceBuyOrder(
    security_code: str,
    units: float,
    unit_price: float,
    is_order_confirmed: bool,
):
    """Place a buy order. Buy orders are often referred to as "buy", "bid", or "long"."""
    # TO-DO: Add checker to be sure the security code being passed is tradeable.
    return _place_order("buy", security_code, units, unit_price, is_order_confirmed)


@tool(args_schema=_SchemaSellOrder)
def PlaceSellOrder(
    security_code: str,
    units: float,
    unit_price: float,
    is_order_confirmed: bool,
):
    """Place a sell order. Sell orders are often referred to as "sell", "offer", or "short"."""
    # TO-DO: Add checker to be sure the security code being passed is tradeable.
    return _place_order("sell", security_code, units, unit_price, is_order_confirmed)
# def PlaceSellOrder(
# security_code: str ,
# units: float ,
# unit_price: float,
# ):
# """Place a sell order. Sell orders are often referred to as "sell", "offer", or "short"."""
# TO-DO: Add checker to be sure the security code being passed is tradeable.
# All_Commodities = ["SMAZ", "DMAZ"]
# if commodity.upper() not in All_Commodities:
# raise ValueError(
# "The commodity is not in the list of possible commodities to be traded on"
# )
# response = put_up_trade(
# security_code= security_code.upper(),
# units= units,
# order_type= "sell",
# market_order_type= "market_order",
# order_price= unit_price,
# )
# # print(response)
# try:
# if str(response["responseCode"]).strip() != "100":
# return f"Sorry, an error occurred: {response['errors']}!"
# return str(response["message"]).replace("[", "`").replace("]", "`")
# except Exception as e:
# print(f"An error occured: {e}")
# return "Sorry, an error occured while carrying out the request. I have notified the developers \
# and they will fix it soon."
# def PlaceSellOrder(
# commodity: str = "commodity",
# quantity: float =4,
# unit_price: float=3,
# ) -> str:
# """Place a sell order. Sell orders are often referred to as "sell", "offer", or "short"."""
# total_price = quantity * unit_price
# return f"Sell order for {commodity} placed successfully and total price is {total_price}."
# @tool
# def GetTradeableCommodities():
# """Used to get the tradeable commodities and their respective boards."""
# return [
# {
# "commodity": "Maize",
# "boards": ["otc", "virtual", "spot", "deliverable"]
# },
# {
# "commodity": "Cashew nuts",
# "boards": ["spot", "deliverable", "otc", "virtual"]
# },
# {
# "commodity": "ginger",
# "boards": ["otc", "spot", "deliverable", "virtual"]
# },
# {
# "commodity": "sesame",
# "boards": ["spot", "deliverable", "otc", "virtual"]
# },
# {
# "commodity": "sorghum",
# "boards": ["deliverable", "otc", "spot", "fixed income"]
# },
# {
# "commodity": "soya beans",
# "boards": ["spot", "deliverable", "otc", "virtual"]
# },
# {
# "commodity": "paddy rice",
# "boards": ["deliverable", "spot", "otc", "virtual"]
# },
# {
# "commodity": "millet",
# "boards": ["spot", "deliverable", "otc", "cash settled"]
# },
# {
# "commodity": "wheat",
# "boards": ["spot", "deliverable", "otc", "virtual"]
# },
# {
# "commodity": "hibiscus",
# "boards": ["spot", "deliverable", "otc", "virtual"]
# },
# ]
@tool(args_schema=_WikipediaSchema)
def search_wikipedia(query: str) -> str:
    """Run Wikipedia search and get page summaries."""
    # Imported inside the function, so the `wikipedia` package is only
    # required when this (currently disabled) tool actually runs.
    import wikipedia

    def _summary_for(title):
        # Best-effort lookup: any failure while fetching the page or reading
        # its summary just skips this title.
        try:
            page = wikipedia.page(title=title, auto_suggest=False)
            return f"Page: {title}\nSummary: {page.summary}"
        except Exception:
            return None

    top_titles = wikipedia.search(query)[:3]
    found = [text for text in map(_summary_for, top_titles) if text is not None]
    if found:
        return "\n\n".join(found)
    return "No good wikipedia search result was found"
#-------Create Retriever Tool-------#
def retriever_tool():
    """Build the LangChain retriever tool backed by the persisted Chroma store.

    Opens a ``chromadb.PersistentClient`` at ``persist_directory`` with
    ``CHROMA_SETTINGS``, wraps it in a LangChain ``Chroma`` store that uses
    ``embeddings_model``, and exposes it as a retriever configured with
    ``k=target_source_chunks``. A new client and store are created on every
    call.

    Returns:
        The tool built by ``create_retriever_tool``, named
        ``"AFEX_general_enquiry"``.
    """
    client = chromadb.PersistentClient(settings=CHROMA_SETTINGS, path=persist_directory)
    vector_store = Chroma(
        client=client,
        client_settings=CHROMA_SETTINGS,
        embedding_function=embeddings_model,
        persist_directory=persist_directory,
    )
    search_kwargs = {"k": target_source_chunks}
    return create_retriever_tool(
        retriever=vector_store.as_retriever(search_kwargs=search_kwargs),
        name="AFEX_general_enquiry",
        description=(
            "Search for information about AFEX and market histories and trends. "
            "To be used for any question or enquiry about AFEX or trading! "
            "Ensure the search query passed into this tool is as descriptive as possible "
            "to aid in retrieving relevant information for the query."
        ),
    )
def consumable_functions(return_tool: bool = False):
    """Collect the tools currently exposed to the agent.

    Args:
        return_tool: If True, return the tool objects themselves. Otherwise
            (the default), return each one converted with
            ``format_tool_to_openai_function``.

    Returns:
        list: The tools, or their function specs, in a fixed order with the
        retriever tool first.
    """
    # Disabled tools are kept here as comments.
    enabled = [
        retriever_tool(),
        CheckAccountBalance,
        GetSecuritiesList,
        GetCurrentSecuritiesPrice,
        # GetPriceHistoryofSecurities,
        GetCustomerWalletHistory,
        GetBoardsList,
        GetCustomerOrderHistory,
        GetTradeableCommodities,
        # search_wikipedia,
        # ConfirmOrderDetails,
        PlaceBuyOrder,
        PlaceSellOrder,
    ]
    if not return_tool:
        # `item`, not `tool`, so the imported @tool decorator is not shadowed.
        return [format_tool_to_openai_function(item) for item in enabled]
    return enabled
def consumable_tools(return_descripton: bool = False):
    """Map tool names to tool objects, or to short human-readable descriptions.

    Args:
        return_descripton: (sic; the name is kept for backward compatibility)
            When True, return ``{name: description}``. Otherwise return
            ``{name: tool}``, building a fresh retriever tool via
            ``retriever_tool()``.

    Returns:
        dict: Both variants have the same keys in the same order.

    Fixes: the description map previously omitted "GetTradeableCommodities"
    (present in the tool map), and described "GetSecuritiesList" as getting
    tradeable commodities, although that tool lists securities and their
    boards.
    """
    if return_descripton:
        return {
            "retriever_tool": "used for searching and retrieving information about AFEX and trading",
            "CheckAccountBalance": "used for checking the account balance of the user",
            "GetSecuritiesList": "used for getting the list of tradeable securities and their boards",
            "GetCurrentSecuritiesPrice": "used for checking the current price of securities",
            # "GetPriceHistoryofSecurities": "used for getting the price history of commodities",
            "GetCustomerWalletHistory": "used to get the history of the customer's wallet",
            "GetBoardsList": "used to get the available boards list",
            "GetCustomerOrderHistory": "used to get the customer's order history",
            "GetTradeableCommodities": "used for getting the tradeable commodities",
            # "search_wikipedia": "used for searching and retrieving information from Wikipedia",
            # "ConfirmOrderDetails": "used for confirming order details before placing the order",
            "PlaceBuyOrder": "used for placing `buy`, `bid` or `long` orders",
            "PlaceSellOrder": "used for placing `sell`, `offer` or `short` orders",
        }
    return {
        "retriever_tool": retriever_tool(),
        "CheckAccountBalance": CheckAccountBalance,
        "GetSecuritiesList": GetSecuritiesList,
        "GetCurrentSecuritiesPrice": GetCurrentSecuritiesPrice,
        # "GetPriceHistoryofSecurities": GetPriceHistoryofSecurities,
        "GetCustomerWalletHistory": GetCustomerWalletHistory,
        "GetBoardsList": GetBoardsList,
        "GetCustomerOrderHistory": GetCustomerOrderHistory,
        "GetTradeableCommodities": GetTradeableCommodities,
        # "search_wikipedia": search_wikipedia,
        # "ConfirmOrderDetails": ConfirmOrderDetails,
        "PlaceBuyOrder": PlaceBuyOrder,
        "PlaceSellOrder": PlaceSellOrder,
    }