| import os |
| import shutil |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.document_loaders import DirectoryLoader |
| from langchain_openai import OpenAIEmbeddings |
| from langchain.vectorstores.chroma import Chroma |
| from langchain_openai import ChatOpenAI |
| from langchain.prompts import ChatPromptTemplate |
| import gradio as gr |
|
|
|
|
# Directory holding this script; the PDF corpus is read from a "pdfs" folder
# located next to it.
script_directory = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(script_directory, "pdfs")
# Relative path, so the Chroma store is resolved against the current working
# directory of the process.
CHROMA_PATH = "chroma"

# os.environ only accepts str values, so copying os.getenv(...) back into it
# is a no-op when the key is set and raises TypeError when it is missing.
# Check for the key explicitly and report the problem instead.
if not os.getenv("OPENAI_API_KEY"):
    print(
        "Warning: OPENAI_API_KEY is not set; "
        "requests to OpenAI are expected to fail."
    )

# Prompt sent to the chat model; {context} receives the retrieved passages and
# {question} receives the user's query.
PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Answer the question based on the above context: {question}
"""
|
|
def load_documents():
    """Load every file matching ``*.pdf`` directly under ``DATA_PATH``.

    Returns:
        Whatever ``DirectoryLoader.load()`` yields for the matched files.
    """
    return DirectoryLoader(DATA_PATH, glob="*.pdf").load()
|
|
def split_text(documents):
    """Split the loaded documents into overlapping chunks.

    Args:
        documents: The documents to split, as produced by ``load_documents``.

    Returns:
        The chunks produced by ``RecursiveCharacterTextSplitter``.
    """
    # Chunks of up to 300 characters (measured with len) overlapping by 100;
    # add_start_index asks the splitter to record each chunk's start offset.
    splitter_options = {
        "chunk_size": 300,
        "chunk_overlap": 100,
        "length_function": len,
        "add_start_index": True,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    pieces = splitter.split_documents(documents)

    doc_count, piece_count = len(documents), len(pieces)
    print(f"Split {doc_count} documents into {piece_count} chunks.")
    return pieces
|
|
def save_to_chroma(chunks):
    """Rebuild the Chroma store at ``CHROMA_PATH`` from ``chunks``.

    Any existing directory at ``CHROMA_PATH`` is deleted first, so the store
    only ever contains the chunks passed to this call.

    Args:
        chunks: Document chunks to embed and store.
    """
    # Remove the previous store so stale chunks are not mixed with new ones.
    stale_store_present = os.path.exists(CHROMA_PATH)
    if stale_store_present:
        shutil.rmtree(CHROMA_PATH)

    vector_store = Chroma.from_documents(
        chunks,
        OpenAIEmbeddings(),
        persist_directory=CHROMA_PATH,
    )
    vector_store.persist()

    chunk_total = len(chunks)
    print(f"Saved {chunk_total} chunks to {CHROMA_PATH}.")
|
|
|
|
def get_response(query_text):
    """Answer ``query_text`` from passages retrieved out of the Chroma store.

    The four most similar chunks are fetched from the store at
    ``CHROMA_PATH``. If nothing is returned, or the first result's relevance
    score is below 0.7, no model call is made. Otherwise the chunks are
    inserted into ``PROMPT_TEMPLATE`` and sent to the chat model.

    Args:
        query_text: The user's question, as entered in the UI.

    Returns:
        A string of the form ``"Response: ...\\nSources: [...]"`` on success,
        or a short explanatory message when the query is blank or no
        sufficiently relevant passage is found.
    """
    # Reject blank submissions up front rather than sending them to the
    # embedding API.
    if not query_text or not query_text.strip():
        return "Please enter a question."

    embedding_function = OpenAIEmbeddings()
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

    results = db.similarity_search_with_relevance_scores(query_text, k=4)
    # Only the first result's score is checked.
    # NOTE(review): this assumes results come back best-first — confirm.
    if not results or results[0][1] < 0.7:
        message = "Unable to find matching results."
        print(message)
        # Return the message so the UI shows feedback instead of an empty box.
        return message

    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(context=context_text, question=query_text)

    model = ChatOpenAI()
    response_text = model.predict(prompt)

    # dict.fromkeys de-duplicates the source paths while keeping their order.
    sources = list(
        dict.fromkeys(doc.metadata.get("source", None) for doc, _score in results)
    )
    return f"Response: {response_text}\nSources: {sources}"
| |
def prepare():
    """Build the vector store: load the PDFs, split them, and save the chunks."""
    save_to_chroma(split_text(load_documents()))
| |
| |
| |
|
|
# Gradio UI: a single multi-line textbox whose contents are passed to
# get_response, with the returned string shown as plain text.
iface = gr.Interface(
    fn=get_response,
    inputs=gr.components.Textbox(lines=7, label="Enter your text"),
    outputs="text",
    title="UK Insurance Law AI Tool",
)


# Only rebuild the index and start the server when run as a script, so that
# importing this module (e.g. to reuse get_response) has no such side effects.
if __name__ == "__main__":
    prepare()
    iface.launch()
|
|