"""RAG demo over the Enron email corpus: a local quantized Zephyr 7B model
answers questions grounded in a Chroma vector store, served through a
Gradio interface."""
import os

from langchain.prompts import PromptTemplate
from langchain.llms import CTransformers
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.chains import RetrievalQA

import gradio as gr
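
# Quantized GGUF build of Zephyr 7B Beta from the Hugging Face Hub
# (the repo hosts several quantizations; model_file="..." pins a specific one).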
local_llm = "TheBloke/zephyr-7B-beta-GGUF"
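
# Generation settings forwarded to ctransformers.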
config = {
    "context_length": 4096,   # total prompt + completion token window
    "max_new_tokens": 1024,   # cap on tokens generated per answer
    "repetition_penalty": 1.1,
    "temperature": 0.1,       # low temperature keeps answers grounded
    "top_k": 50,
    "top_p": 0.9,
    "stream": True,
    "threads": os.cpu_count() // 2,
}
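
# Zephyr 7B Beta is Mistral-based, hence model_type="mistral";
# lib="avx2" selects the AVX2 CPU build of ctransformers.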
llm_init = CTransformers(model=local_llm, model_type="mistral", lib="avx2", config=config)

prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make one up.

Context: {context}
Question: {question}

Only return the helpful answer below.
Helpful answer:
"""
model_name = "BAAI/bge-large-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}

embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

prompt = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
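
# Load the pre-built Chroma index of Enron emails from disk.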
load_vector_store = Chroma(
    persist_directory="./stores/enron_cosine", embedding_function=embeddings
)
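
# Retrieve only the single closest chunk per query; raise k for more context.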
retriever = load_vector_store.as_retriever(search_kwargs={"k": 1})

print("retriever", retriever)

sample_query = []  # optional example queries shown in the Gradio UI


def get_response(query):
    """Answer one user query with the RetrievalQA chain."""
    print("query:", query)
    chain_type_kwargs = {"prompt": prompt}
    qa = RetrievalQA.from_chain_type(
        llm=llm_init,
        chain_type="stuff",  # stuff retrieved docs straight into the prompt
        retriever=retriever,
        verbose=True,
        chain_type_kwargs=chain_type_kwargs,
        return_source_documents=True,
    )
    response = qa(query)
    print("Response:", response)
    # With return_source_documents=True the chain returns a dict;
    # the text output only needs the answer string.
    return response["result"]
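
# Gradio textbox for the user's question.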
query_input = gr.Text(
    label="Query",
    show_label=True,
    max_lines=2,
    container=False,
    placeholder="Enter your question",
)

gIface = gr.Interface(
    fn=get_response,
    inputs=query_input,
    outputs="text",
    title="Enron emails RAG AI",
    description="RAG demo using Zephyr 7B Beta and LangChain",
    examples=sample_query,
    allow_flagging="never",
)
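
# Start the local Gradio server (blocks; prints the URL to open in a browser).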
gIface.launch()