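"""Gradio RAG demo over the Enron email corpus.

Retrieval uses a persisted Chroma store built with BAAI/bge-large-en embeddings;
answers are generated by a quantized Zephyr 7B Beta (GGUF) model loaded through
ctransformers and wired together with LangChain's RetrievalQA chain.
"""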
import os
from langchain.prompts import PromptTemplate
from langchain.llms import CTransformers
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.chains import RetrievalQA
import gradio as gr
# Mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive/')
# !ls /content/drive/My\ Drive/stores/enron_cosine/
# The LLM used to generate answers from the content retrieved from the vector store.
local_llm = "TheBloke/zephyr-7B-beta-GGUF"
# Generation settings passed to the ctransformers backend.
config = {
    "context_length": 4096,   # maximum sequence length handled by the model
    "max_new_tokens": 1024,   # upper bound on the length of the generated answer
    "repetition_penalty": 1.1,
    "temperature": 0.1,
    "top_k": 50,
    "top_p": 0.9,
    "stream": True,
    "threads": max(os.cpu_count() // 2, 1),  # use half of the CPU cores, at least one
}
# lib="avx2" selects the CPU (AVX2) build of the ctransformers backend.
llm_init = CTransformers(model=local_llm, model_type="mistral", lib="avx2", config=config)
prompt_template = """Use the following piece of information to answers the question asked by the user.
Don't try to make up the answer if you don't know the answer, simply say I don't know.
Context: {context}
Question: {question}
Only helpful answer below.
Helpful answer:
"""
# The model used to create the embeddings of the user queries.
model_name = "BAAI/bge-large-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}
embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
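# The persisted Chroma store loaded below is assumed to have been built offline.
# A minimal sketch of that indexing step (the loader, paths and chunking parameters
# are illustrative assumptions, not part of this app):
#
# from langchain.document_loaders import DirectoryLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
#
# docs = DirectoryLoader("./enron_emails", glob="**/*.txt").load()
# chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
# store = Chroma.from_documents(
#     chunks,
#     embeddings,
#     persist_directory="./stores/enron_cosine",
#     collection_metadata={"hnsw:space": "cosine"},  # cosine distance, matching the store name
# )
# store.persist()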
prompt = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
load_vector_store = Chroma(
    persist_directory="./stores/enron_cosine", embedding_function=embeddings
)
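# Retrieve only the single most similar chunk for each query (k=1).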
retriever = load_vector_store.as_retriever(search_kwargs={"k": 1})
print("retrieval", retriever)
#query = "In what context is mentioned Natural Gas Storage Overview?"
#semantic_search = retriever.get_relevant_documents(query)
#print(semantic_search)
# Build the RetrievalQA chain once at startup instead of on every request.
chain_type_kwargs = {"prompt": prompt}
qa = RetrievalQA.from_chain_type(
    llm=llm_init,
    chain_type="stuff",
    retriever=retriever,
    verbose=True,
    chain_type_kwargs=chain_type_kwargs,
    return_source_documents=True,
)
sample_query = []
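# A seed example could be offered in the UI, e.g.:
# sample_query = ["In what context is mentioned Natural Gas Storage Overview?"]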
def get_response(query):
    print("Query:", query)
    response = qa(query)
    print("Response:", response)
    # RetrievalQA returns a dict with "result" and "source_documents";
    # only the generated answer is shown in the UI.
    return response["result"]
query_input = gr.Text(
    label="Query",
    show_label=True,
    max_lines=2,
    container=False,
    placeholder="Enter your question",
)
gIface = gr.Interface(
    fn=get_response,
    inputs=query_input,
    outputs="text",
    title="Enron emails RAG AI",
    description="RAG demo using Zephyr 7B Beta and LangChain",
    examples=sample_query,
    allow_flagging="never",
)
gIface.launch()