# -*- coding: utf-8 -*-
"""Earnings-call Q&A agent backed by a LlamaIndex vector index stored in W&B.

Module import performs setup with side effects:
- loads ``.env``;
- starts W&B tracing and downloads an artifact;
- configures the OpenAI LLM and embedding model.

Call ``extract_information()`` to build the agent.
"""
# Imports
import asyncio
import os
from typing import Any, List, Optional, Tuple

import llama_index
import openai
import wandb
from dotenv import load_dotenv
from llama_index import (
    GPTVectorStoreIndex,
    ServiceContext,
    load_index_from_storage,
    set_global_handler,
)
from llama_index.agent import OpenAIAgent
from llama_index.embeddings import OpenAIEmbedding
from llama_index.llms import OpenAI
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import VectorIndexRetriever
from llama_index.tools import FunctionTool
from llama_index.vector_stores.types import (
    ExactMatchFilter,
    MetadataFilters,
    MetadataInfo,
    VectorStoreInfo,
)
from pydantic import BaseModel, Field

# Load .env first so credentials defined there are visible to W&B
# (set_global_handler / wandb.init below) and to the OpenAI key lookup.
load_dotenv()

set_global_handler("wandb", run_args={"project": "final-project-v1"})
wandb_callback = llama_index.global_handler

run = wandb.init()
artifact = run.use_artifact(
    "rlpeter70/final-project-v1/earnings-index:v0", type="storage_context"
)
# NOTE(review): artifact_dir is downloaded but never used; extract_information()
# loads a different artifact (the wiki demo index). Confirm which one is intended.
artifact_dir = artifact.download()

openai.api_key = os.environ["OPENAI_API_KEY"]

# Number of nodes the retriever returns for each tool call.
top_k = 3

# Schema of the metadata attached to indexed earnings-call chunks. It is embedded
# in the tool description so the agent knows which filter keys it may use.
vector_store_info = VectorStoreInfo(
    content_info="transcripts of earnings calls",
    metadata_info=[
        MetadataInfo(name="title", type="str", description="Title of the earnings call"),
        MetadataInfo(name="period", type="str", description="Period of the earnings call"),
        MetadataInfo(name="ticker", type="str", description="Ticker of the company"),
        MetadataInfo(name="year", type="str", description="Year of the earnings call"),
        MetadataInfo(name="quarter", type="str", description="Quarter of the earnings call"),
        MetadataInfo(name="path", type="str", description="Path to the earnings call"),
    ],
)


class AutoRetrieveModel(BaseModel):
    """Argument schema the agent fills in when it calls the retrieval tool."""

    query: str = Field(..., description="natural language query string")
    filter_key_list: List[str] = Field(
        ..., description="List of metadata filter field names"
    )
    filter_value_list: List[str] = Field(
        ...,
        description=(
            "List of metadata filter field values (corresponding to names specified in filter_key_list)"
        ),
    )


embed_model = OpenAIEmbedding()
chunk_size = 500
llm = OpenAI(temperature=0, model="gpt-4")
service_context = ServiceContext.from_defaults(
    llm=llm,
    chunk_size=chunk_size,
    embed_model=embed_model,
)
# NOTE(review): this empty index is never queried; extract_information() builds
# its own index from the W&B storage context. Kept for any external importers.
index = GPTVectorStoreIndex.from_documents([], service_context=service_context)


async def extract_information():
    """Build an OpenAI function-calling agent over the earnings-call index.

    Load the persisted index from a W&B artifact. Wrap an auto-retrieval
    function as the agent's single tool. The agent supplies a query plus
    parallel lists of metadata keys and values. Those pairs become
    exact-match filters on the vector search.

    Returns:
        OpenAIAgent: agent that uses the module-level ``llm`` and the retrieval tool.

    Usage (the function is a coroutine, so it must be awaited)::

        agent = asyncio.run(extract_information())
        print(str(agent.chat("Who is the CEO of MSFT.")))
    """
    storage_context = wandb_callback.load_storage_context(
        # NOTE(review): this is the LlamaIndex wiki demo artifact, not the
        # earnings-index artifact downloaded at import time. Confirm intent.
        artifact_url="chrisalexiuk/llamaindex-demo-v1/wiki-index:v0"
    )
    index = load_index_from_storage(storage_context, service_context=service_context)

    def auto_retrieve_fn(
        query: str, filter_key_list: List[str], filter_value_list: List[str]
    ) -> str:
        """Query the index with exact-match metadata filters applied.

        Args:
            query: Natural-language query; an empty value falls back to "Query".
            filter_key_list: Metadata field names to filter on.
            filter_value_list: Values matching ``filter_key_list`` position by position.

        Returns:
            The query engine's response rendered as a string.

        Raises:
            ValueError: If the key and value lists differ in length.
        """
        # zip() would silently drop unmatched entries, so the search would
        # run with fewer filters than the agent asked for.
        if len(filter_key_list) != len(filter_value_list):
            raise ValueError(
                "filter_key_list and filter_value_list must have the same length "
                f"(got {len(filter_key_list)} and {len(filter_value_list)})"
            )
        query = query or "Query"
        exact_match_filters = [
            ExactMatchFilter(key=k, value=v)
            for k, v in zip(filter_key_list, filter_value_list)
        ]
        retriever = VectorIndexRetriever(
            index,
            filters=MetadataFilters(filters=exact_match_filters),
            # The retriever's keyword is `similarity_top_k`. An unknown `top_k`
            # is absorbed by **kwargs and the library default is used instead.
            similarity_top_k=top_k,
        )
        query_engine = RetrieverQueryEngine.from_args(
            retriever, service_context=service_context
        )
        response = query_engine.query(query)
        return str(response)

    # Include the metadata schema so the agent can choose valid filter keys.
    description = (
        "Earnings Bot: use this tool to answer questions from transcripts of "
        "company earnings calls. The vector database schema (valid metadata "
        "filter keys) is given below:\n"
        f"{vector_store_info.json()}\n"
    )
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="earnings-transcripts",
        description=description,
        fn_schema=AutoRetrieveModel,
    )
    agent = OpenAIAgent.from_tools(
        tools=[auto_retrieve_tool],
        # Use the configured function-calling model (gpt-4) rather than
        # whatever the agent would default to.
        llm=llm,
    )
    return agent