Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# Imports | |
import asyncio | |
import os | |
import openai | |
import wandb | |
from typing import List, Optional | |
# from pydantic import BaseModel, Field | |
# from langchain.prompts import ChatPromptTemplate | |
# from langchain.pydantic_v1 import BaseModel | |
# from langchain.utils.openai_functions import convert_pydantic_to_openai_function | |
from llama_index.tools import FunctionTool | |
from llama_index.vector_stores.types import ( | |
VectorStoreInfo, | |
MetadataInfo, | |
ExactMatchFilter, | |
MetadataFilters, | |
) | |
from llama_index.agent import OpenAIAgent | |
from llama_index.retrievers import VectorIndexRetriever | |
from llama_index.query_engine import RetrieverQueryEngine | |
from typing import List, Tuple, Any | |
from pydantic import BaseModel, Field | |
from llama_index import load_index_from_storage | |
from llama_index import set_global_handler | |
import llama_index | |
from llama_index.embeddings import OpenAIEmbedding | |
from llama_index import ServiceContext | |
from llama_index.llms import OpenAI | |
from llama_index import GPTVectorStoreIndex | |
# Load a local .env first so every setting read below (the OpenAI key and
# any W&B credentials) can come from it before W&B or OpenAI are touched.
from dotenv import load_dotenv
load_dotenv()

# Fail fast, before any network work, if the OpenAI key is missing.
try:
    openai.api_key = os.environ['OPENAI_API_KEY']
except KeyError:
    raise KeyError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file"
    ) from None

# Route LlamaIndex callback events to Weights & Biases.
set_global_handler("wandb", run_args={"project": "final-project-v1"})
wandb_callback = llama_index.global_handler
# Reuse a run if one is already active (the W&B handler may have started it);
# otherwise start one.
run = wandb.run if wandb.run is not None else wandb.init()
artifact = run.use_artifact('rlpeter70/final-project-v1/earnings-index:v0', type='storage_context')
# NOTE(review): artifact_dir is not read anywhere else in this file, and
# extract_information() loads a different artifact URL — confirm intent.
artifact_dir = artifact.download()
# How many nodes each retrieval call should return.
top_k = 3

# Description of the earnings-call vector store: what the stored text is and
# which metadata fields exist. Every field is declared with type "str".
vector_store_info = VectorStoreInfo(
    content_info="transcripts of earnings calls",
    metadata_info=[
        MetadataInfo(name=field_name, type="str", description=field_description)
        for field_name, field_description in (
            ("title", "Title of the earnings call"),
            ("period", "Period of the earnings call"),
            ("ticker", "Ticker of the company"),
            ("year", "Year of the earnings call"),
            ("quarter", "Quarter of the earnings call"),
            ("path", "Path to the earnings call"),
        )
    ],
)
# Argument schema for the auto-retrieval tool: a query plus two parallel lists
# of metadata filter keys and values. The field names, order and descriptions
# make up the schema handed to the agent. A class docstring is deliberately
# avoided because pydantic would include it in that schema.
class AutoRetrieveModel(BaseModel):
    # Free-text question to run against the index.
    query: str = Field(..., description="natural language query string")
    # Metadata keys to filter on, paired index-wise with filter_value_list.
    filter_key_list: List[str] = Field(
        ...,
        description="List of metadata filter field names",
    )
    # Values matched exactly against the key at the same position.
    filter_value_list: List[str] = Field(
        ...,
        description="List of metadata filter field values "
        "(corresponding to names specified in filter_key_list)",
    )
# Embedding model, chat model and chunking shared by indexing and querying.
embed_model = OpenAIEmbedding()
chunk_size = 500
llm = OpenAI(temperature=0, model="gpt-4")
service_context = ServiceContext.from_defaults(
    embed_model=embed_model,
    llm=llm,
    chunk_size=chunk_size,
)
# NOTE(review): this empty index is not used elsewhere in this file;
# extract_information() builds its own index from storage. Kept in case
# other modules import it.
index = GPTVectorStoreIndex.from_documents([], service_context=service_context)
# Main function to extract information | |
async def extract_information():
    """Build an OpenAI agent that answers questions from earnings-call transcripts.

    Loads the persisted storage context through the W&B callback handler,
    rebuilds the vector index with the module-level ``service_context``, and
    exposes an auto-retrieval function as the agent's single tool.

    Returns:
        An ``OpenAIAgent`` whose only tool is ``earnings-transcripts``.

    Note:
        Declared ``async`` for callers that ``await`` it, but nothing inside
        is awaited; the storage download and index load run synchronously.
    """
    # NOTE(review): module setup downloads
    # 'rlpeter70/final-project-v1/earnings-index:v0', while this loads
    # 'rlpeter70/uncategorized/earnings-index:v0' — confirm which is intended.
    storage_context = wandb_callback.load_storage_context(
        artifact_url="rlpeter70/uncategorized/earnings-index:v0"
        #artifact_url=artifact_dir
    )
    loaded_index = load_index_from_storage(
        storage_context, service_context=service_context
    )

    def auto_retrieve_fn(
        query: str, filter_key_list: List[str], filter_value_list: List[str]
    ):
        """Auto retrieval function.

        Retrieves the ``top_k`` most similar nodes from the vector index,
        restricted by exact-match metadata filters built from the paired key
        and value lists, and returns the synthesized answer as a string.
        """
        query = query or "Query"
        # zip() pairs keys with values positionally; extra items in the longer
        # list are dropped.
        exact_match_filters = [
            ExactMatchFilter(key=key, value=value)
            for key, value in zip(filter_key_list, filter_value_list)
        ]
        # Only attach filters when there are some, so an unfiltered query is
        # sent without a filter object at all.
        filters = (
            MetadataFilters(filters=exact_match_filters)
            if exact_match_filters
            else None
        )
        retriever = VectorIndexRetriever(
            loaded_index,
            # The retriever's parameter is `similarity_top_k`; a `top_k`
            # keyword is not recognised and would leave the default in place.
            similarity_top_k=top_k,
            filters=filters,
        )
        query_engine = RetrieverQueryEngine.from_args(
            retriever, service_context=service_context
        )
        return str(query_engine.query(query))

    # Put the metadata schema in the tool description so the agent knows which
    # filter keys (title, period, ticker, year, quarter, path) it may use.
    tool_description = (
        "Earnings Bot. Looks up information in transcripts of earnings calls. "
        "Optional exact-match metadata filters use the keys in this schema:\n"
        f"{vector_store_info.json()}"
    )
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="earnings-transcripts",
        description=tool_description,
        fn_schema=AutoRetrieveModel,
    )
    # Use the module-level gpt-4 `llm` (a function-calling model) rather than
    # whatever default model the agent would otherwise pick.
    agent = OpenAIAgent.from_tools(
        tools=[auto_retrieve_tool],
        llm=llm,
    )
    return agent
# if __name__ == "__main__": | |
# text = "Who is the CEO of MSFT." | |
# chain = extract_information() | |
# print(str(chain.chat(text))) | |
# async def extract_information_async(message: str): | |
# return str(chain.chat(text)) | |
# async def main(): | |
# res = await extract_information_async(text) | |
# print(res) | |
# asyncio.run(main()) | |