# rag-test-venkat / Index.py
# (Hugging Face file-viewer page chrome captured with the source; kept as
#  comments so the file remains valid Python:)
# DeepVen's picture
# Upload Index.py
# ee65653
# raw
# history blame contribute delete
# No virus
# 4.53 kB
from fastapi import FastAPI
import os
import phoenix as px
from phoenix.trace.langchain import OpenInferenceTracer, LangChainInstrumentor
from langchain.embeddings import HuggingFaceEmbeddings #for using HugginFace models
from langchain.chains.question_answering import load_qa_chain
from langchain import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.callbacks import StdOutCallbackHandler
#from langchain.retrievers import KNNRetriever
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain import HuggingFaceHub
# from langchain.prompts import PromptTemplate
# from langchain.chains import LLMChain
# from txtai.embeddings import Embeddings
# from txtai.pipeline import Extractor
# import pandas as pd
# import sqlite3
# import os
# NOTE - we configure docs_url to serve the interactive Docs at the root path
# of the app. This way, we can use the docs as a landing page for the app on Spaces.
app = FastAPI(docs_url="/")

# --- Phoenix tracing setup ---
# Launch a local Phoenix server for trace collection; UI is at session.url.
session = px.launch_app()
# If no exporter is specified, the tracer will export to the locally running Phoenix server
tracer = OpenInferenceTracer()
# Instrument LangChain so every chain / LLM / retriever call is traced.
LangChainInstrumentor(tracer).instrument()
print(session.url)

# SECURITY: a real Hugging Face API token was committed to this file. It must be
# considered leaked — revoke it and supply the token via the environment instead.
# `setdefault` keeps the app working as before while letting an externally
# provided HUGGINGFACEHUB_API_TOKEN take precedence over the hardcoded value.
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", "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")

# Embedding cache: back the embedder with a local byte store so repeated
# texts are only embedded once across runs.
store = LocalFileStore("./cache/")
core_embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
embedder = CacheBackedEmbeddings.from_bytes_store(core_embeddings_model, store)

# LLM served remotely through the Hugging Face Inference API.
# NOTE(review): max_length=1000000 far exceeds flan-t5-xxl's context window —
# presumably clamped server-side; confirm the intended value.
llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 1, "max_length": 1000000})
#llm=HuggingFaceHub(repo_id="gpt2", model_kwargs={"temperature":1, "max_length":1000000})
handler = StdOutCallbackHandler()

# Module-level globals populated by initialize_vectorstore() at startup.
vectorstore = None
retriever = None
def initialize_vectorstore():
    """Seed the global FAISS vector store with a sample document.

    Loads and chunks one hardcoded case-study page, builds the FAISS index
    from it, and publishes both the store and a retriever via module globals.
    Must run before the /index/ and /rag endpoints are used.
    """
    global vectorstore
    global retriever
    # Consistency fix: reuse _load_docs, which performs exactly the same
    # WebBaseLoader(...).load() + _text_splitter(...) sequence this function
    # previously duplicated inline.
    webpage_chunks = _load_docs("https://www.tredence.com/case-studies/tredence-helped-a-global-retailer-providing-holistic-campaign-analytics-by-using-the-power-of-gcp")
    # store embeddings in vector store
    vectorstore = FAISS.from_documents(webpage_chunks, embedder)
    print("vector store initialized with sample doc")
    # instantiate a retriever
    retriever = vectorstore.as_retriever()
def _text_splitter(doc):
    """Split documents into overlapping character chunks.

    Uses ~1000-character chunks with a 50-character overlap so context is
    preserved across chunk boundaries.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=50, length_function=len
    )
    return splitter.transform_documents(doc)
def _load_docs(path: str):
    """Fetch the web page at *path* and return it split into chunks."""
    raw_docs = WebBaseLoader(path).load()
    return _text_splitter(raw_docs)
@app.get("/index/")
def get_domain_file_path(file_path: str):
    """Index a web page into the global vector store.

    Parameters
    ----------
    file_path : str
        URL of the page to load, chunk, and embed.

    Returns a confirmation string on success.
    """
    print("file_path " ,file_path)
    # Bug fix: _load_docs already loads AND splits the page; the original code
    # then ran _text_splitter a second time on the already-split chunks.
    webpage_chunks = _load_docs(file_path)
    # store embeddings in vector store (requires initialize_vectorstore() to
    # have run so `vectorstore` is not None)
    vectorstore.add_documents(webpage_chunks)
    return "document loaded to vector store successfully!!"
def _prompt(question):
return f"""Answer following question using only the context below. Say 'Could not find answer with provided context' when question can't be answered.
Question: {question}
Context: """
@app.get("/rag")
def rag( question: str):
    """Answer *question* via retrieval-augmented QA over the global retriever.

    Builds a RetrievalQA chain per request (llm + retriever + stdout callback,
    returning source documents) and runs it on the grounded prompt.
    NOTE(review): the full _prompt() text is used as the retrieval query —
    the instruction boilerplate is embedded alongside the question; confirm
    this is intended.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        callbacks=[handler],
        return_source_documents=True,
    )
    #response = chain("how tredence brought good insight?")
    result = qa_chain(_prompt(question))
    return {"question": question, "answer": result['result']}
# Build the vector store and retriever at import time so the endpoints above
# are usable as soon as the app starts.
initialize_vectorstore()
@app.get("/trace")
def trace():
    """Return all Phoenix trace spans recorded in the active session.

    NOTE(review): the original returned the pandas DataFrame object directly,
    which FastAPI's JSON encoder cannot serialize reliably; it is converted to
    a list of row dicts here. The (previously broken) response shape changes —
    confirm no client depends on the old form.
    """
    df = px.active_session().get_spans_dataframe().fillna('')
    # One JSON object per span row.
    return df.to_dict(orient="records")
# Disabled code kept as a module-level string literal (no runtime effect beyond
# creating an unused string constant): ngrok tunneling for running the app
# outside Hugging Face Spaces.
# SECURITY NOTE(review): this block contains a committed ngrok auth token —
# it should be treated as leaked and revoked.
'''
#import getpass
from pyngrok import ngrok, conf
#print("Enter your authtoken, which can be copied from https://dashboard.ngrok.com/auth")
conf.get_default().auth_token="2WJNWULs5bCOyJnV24WQYJEKod3_YQUbM5EGCp8sgE4aQvzi"
port = 37689
# Open a ngrok tunnel to the HTTP server
conf.get_default().monitor_thread = False
public_url = ngrok.connect(port).public_url
print(" * ngrok tunnel \"{}\" -> \"http://127.0.0.1:{}\"".format(public_url, port))
'''