Spaces:
Runtime error
Runtime error
| import os | |
| import shutil | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import DirectoryLoader | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain.vectorstores.chroma import Chroma | |
| from langchain_openai import ChatOpenAI | |
| from langchain.prompts import ChatPromptTemplate | |
| import gradio as gr | |
# Directory containing this script; the PDF corpus is resolved relative to it
# so the app does not depend on the current working directory.
script_directory = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(script_directory, "pdfs")

# Directory of the persisted Chroma index (relative to the working directory).
CHROMA_PATH = "chroma"

# The previous self-assignment `os.environ[k] = os.getenv(k)` was a no-op when
# the key was set and raised TypeError at import time when it was not, because
# os.environ only accepts str values. Report the missing key instead of
# crashing before the app can start.
if not os.getenv("OPENAI_API_KEY"):
    print("Warning: OPENAI_API_KEY is not set; requests to OpenAI are expected to fail.")

# Prompt sent to the chat model; {context} and {question} are filled in by
# get_response().
PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Answer the question based on the above context: {question}
"""
def load_documents():
    """Load the files under DATA_PATH that match the ``*.pdf`` glob.

    Returns:
        Whatever ``DirectoryLoader.load()`` produces for those files.
    """
    return DirectoryLoader(DATA_PATH, glob="*.pdf").load()
def split_text(documents):
    """Split the loaded documents into overlapping chunks.

    Args:
        documents: The documents returned by load_documents().

    Returns:
        The chunks produced by the splitter. A summary line with the
        document and chunk counts is printed as a side effect.
    """
    # Chunk length is measured with len(); consecutive chunks overlap by 100
    # of their (at most) 300 units, and the start index of each chunk is
    # requested from the splitter.
    splitter_options = {
        "chunk_size": 300,
        "chunk_overlap": 100,
        "length_function": len,
        "add_start_index": True,
    }
    chunks = RecursiveCharacterTextSplitter(**splitter_options).split_documents(documents)
    print(f"Split {len(documents)} documents into {len(chunks)} chunks.")
    return chunks
def save_to_chroma(chunks):
    """Rebuild the Chroma index at CHROMA_PATH from the given chunks.

    Any existing index directory is deleted first, so the stored data always
    reflects exactly the chunks passed in.

    Args:
        chunks: The document chunks produced by split_text().
    """
    # Start from a clean slate: drop whatever an earlier run left behind.
    index_already_exists = os.path.exists(CHROMA_PATH)
    if index_already_exists:
        shutil.rmtree(CHROMA_PATH)

    # Embed the chunks and write the fresh index to disk.
    store = Chroma.from_documents(
        documents=chunks,
        embedding=OpenAIEmbeddings(),
        persist_directory=CHROMA_PATH,
    )
    store.persist()

    print(f"Saved {len(chunks)} chunks to {CHROMA_PATH}.")
def get_response(query_text):
    """Answer a question using the chunks stored in the Chroma index.

    The query is matched against the persisted index, the matching chunks are
    pasted into PROMPT_TEMPLATE, and the filled-in prompt is sent to the chat
    model.

    Args:
        query_text: The user's question, as typed into the Gradio textbox.

    Returns:
        A string for display in the UI: either the model's answer followed by
        the list of source identifiers, or a short message explaining why no
        answer was produced.
    """
    # Reject empty submissions before any OpenAI client is created, so a
    # blank textbox never triggers an embedding request.
    if not query_text or not query_text.strip():
        return "Please enter a question."

    # Prepare the DB.
    embedding_function = OpenAIEmbeddings()
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

    # Each result is a (document, relevance score) pair. The threshold check
    # looks only at the first pair, which presumably is the best match --
    # verify against the vector store's ordering guarantee.
    results = db.similarity_search_with_relevance_scores(query_text, k=4)
    if not results or results[0][1] < 0.7:
        message = "Unable to find matching results."
        print(message)
        # Return the message as well as logging it; returning None left the
        # UI output blank.
        return message

    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(context=context_text, question=query_text)

    # NOTE(review): newer LangChain releases deprecate `predict` in favour of
    # `invoke`; confirm the pinned version before switching.
    model = ChatOpenAI()
    response_text = model.predict(prompt)

    # dict.fromkeys removes duplicate sources while keeping first-seen order.
    sources = list(
        dict.fromkeys(doc.metadata.get("source", None) for doc, _score in results)
    )
    return f"Response: {response_text}\nSources: {sources}"
def prepare():
    """Build the vector index: load the PDFs, split them, and store the chunks."""
    save_to_chroma(split_text(load_documents()))
# Single multi-line textbox in, plain text out; get_response does the work.
question_box = gr.components.Textbox(lines=7, label="Enter your text")
iface = gr.Interface(
    fn=get_response,
    inputs=question_box,
    outputs="text",
    title="UK Insurance Law AI Tool",
)

# Rebuild the vector index on every start, then serve the UI.
prepare()
iface.launch()