|
|
|
|
|
|
|
|
|
import json |
|
import gradio as gr |
|
import uuid |
|
import os |
|
import pandas as pd |
|
|
|
from openai import OpenAI |
|
|
|
from langchain_community.embeddings.sentence_transformer import ( |
|
SentenceTransformerEmbeddings |
|
) |
|
from langchain_community.vectorstores import Chroma |
|
|
|
from huggingface_hub import CommitScheduler |
|
from dotenv import load_dotenv |
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
# Load variables from a local .env file (if one exists) into the process
# environment so the API key can be supplied without hard-coding it.
load_dotenv()

# Fail fast with an actionable message when the key is missing. The previous
# `os.environ[...] = os.getenv(...)` self-assignment was a no-op when the key
# was set and raised an opaque TypeError when it was not.
_anyscale_api_key = os.getenv("anyscale_api_key")
if not _anyscale_api_key:
    raise RuntimeError(
        "Environment variable 'anyscale_api_key' is not set; "
        "define it in the environment or in a .env file."
    )

# OpenAI-compatible client pointed at the Anyscale Endpoints base URL.
client = OpenAI(
    base_url="https://api.endpoints.anyscale.com/v1",
    api_key=_anyscale_api_key,
)
|
|
|
|
|
# Embedding function handed to the vector store below.
embedding_model = SentenceTransformerEmbeddings(
    model_name='thenlper/gte-large',
)

# Name of the Chroma collection that holds the report chunks.
collection_name = 'Dataset-10k'

# Vector store backed by the on-disk directory ./report_db1.
reportdb = Chroma(
    embedding_function=embedding_model,
    persist_directory='./report_db1',
    collection_name=collection_name,
)
|
|
|
|
|
|
|
# Each process writes to its own uniquely named JSON-lines file so that
# concurrent instances never append to the same log.
log_folder = Path("logs")
log_file = log_folder / f"data_{uuid.uuid4()}.json"

# Background scheduler that pushes the contents of `log_folder` to the
# "data" directory of the Hugging Face dataset repo. `every` is the commit
# interval (minutes, per the huggingface_hub documentation).
scheduler = CommitScheduler(
    repo_id="RAG-10K-log",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2,
)
|
|
|
|
|
|
|
# System prompt for the Q&A chat completion. It tells the model how the user
# message is laid out (###Context / ###Question markers) and restricts it to
# answering from the supplied context only.
qna_system_message = """
You are an assistant to a financial services firm who answers user queries on annual reports.
User input will have the context required by you to answer user questions.
This context will begin with the token: ###Context.
The context contains references to specific portions of a document relevant to the user query.

User questions will begin with the token: ###Question.

Please answer only using the context provided in the input. Do not mention anything about the context in your final answer.

If the answer is not found in the context, respond "I don't know".
"""
|
|
|
|
|
|
|
|
|
|
|
# Template for the user message. `{context}` receives the retrieved document
# text and `{question}` the user's query; fill it with `str.format`.
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question mentioned below.
{context}

###Question
{question}
"""
|
|
|
|
|
|
|
# Chat model used to answer questions; override with ANYSCALE_MODEL_NAME.
# NOTE(review): the original code referenced an undefined `model_name`; the
# default below is an assumption -- confirm the intended model id.
QNA_MODEL_NAME = os.getenv(
    "ANYSCALE_MODEL_NAME", "mlabonne/NeuralHermes-2.5-Mistral-7B"
)


def predict(user_input, company):
    """Answer a question about one company's 10-K report using RAG.

    Retrieves the five chunks most similar to the query from the vector
    store, restricted to the selected company's filing, sends them with the
    question to the chat model, appends the exchange to the JSON-lines log
    file and returns the model's answer.

    Args:
        user_input: The question typed by the user.
        company: The company identifier chosen in the UI; it is substituted
            into the "dataset/<company>-10-k-2023.pdf" source filter.

    Returns:
        The model's answer as a string, a short prompt to the user when an
        input is missing, or an error message if the model call failed.
    """
    # Gradio passes None when no radio option is selected; the source path
    # cannot be built without a company, so stop before touching the store.
    if not company:
        return "Please select a company."
    if not user_input or not user_input.strip():
        return "Please enter a query."

    # Restrict the similarity search to chunks of the selected filing.
    source_path = f"dataset/{company}-10-k-2023.pdf"
    relevant_document_chunks = reportdb.similarity_search(
        user_input, k=5, filter={"source": source_path}
    )
    context_for_query = ". ".join(
        chunk.page_content for chunk in relevant_document_chunks
    )

    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {
            'role': 'user',
            'content': qna_user_message_template.format(
                context=context_for_query,
                question=user_input,
            ),
        },
    ]

    try:
        response = client.chat.completions.create(
            model=QNA_MODEL_NAME,
            messages=prompt,
            temperature=0,
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        prediction = f'Sorry, I encountered the following error: \n {e}'

    # Hold the scheduler lock while appending so a background commit never
    # uploads a half-written record.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction,
                }
            ))
            f.write("\n")

    return prediction
|
|
|
|
|
# Input widgets: a free-text query box and a single-choice company selector.
user_input = gr.Textbox(label='Query')

company_input = gr.Radio(
    choices=['aws', 'google', 'IBM', 'Meta', 'msft'],
    label='company',
)

# Output widget that displays the answer returned by the model.
model_output = gr.Label(label='Response')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Wire the prediction function to the UI widgets. `outputs` must be the
# output component (`model_output`); the original passed `prediction`, which
# is not defined at module level and raised NameError at start-up.
demo = gr.Interface(
    fn=predict,
    inputs=[user_input, company_input],
    outputs=model_output,
    title="RAG on 10k-reports",
    description="This API allows you to query on annual reports",
    concurrency_limit=16,
)

# Enable request queuing, then start the web server.
demo.queue()
demo.launch()
|
|