# brandfolder_image_upload / chat_engine.py
# Author: mylesai — commit 8df2a69 "Update chat_engine.py" (verified), 6.89 kB
from llama_index.core import (
VectorStoreIndex,
get_response_synthesizer,
GPTListIndex,
PromptHelper,
set_global_service_context,
Settings
)
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.schema import Document
from llama_index.llms.anyscale import Anyscale
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.openai import OpenAI
from llama_index.core.indices.service_context import ServiceContext
import urllib
import nltk
import os
import tiktoken
from nltk.tokenize import sent_tokenize
from llama_index.core.callbacks import CallbackManager, TokenCountingHandler
from llama_index.core import SimpleDirectoryReader
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.core.chat_engine.condense_question import CondenseQuestionChatEngine
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.postprocessor import LongContextReorder
from llama_index.postprocessor.rankgpt_rerank import RankGPTRerank
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.core.node_parser import TokenTextSplitter
from pypdf import PdfReader
import gradio as gr
# API key for the MistralAIEmbedding model built in get_chat_engine(). It is read
# once at import time, so importing this module fails immediately when unset;
# re-raise with an actionable message instead of a bare KeyError('MISTRALAI_API_KEY').
try:
    mistral_api_key = os.environ['MISTRALAI_API_KEY']
except KeyError as err:
    raise KeyError(
        "The MISTRALAI_API_KEY environment variable is not set; it is required "
        "for the Mistral embedding model used by get_chat_engine()."
    ) from err
# Functions
def extract_text_from_pdf(pdf_path):
    """
    Extract all text from a PDF file.

    Args:
        pdf_path (str): The file path to the PDF from which text is to be extracted.

    Returns:
        str: The text of every page in page order, each page followed by a
        newline. Returns an empty string for a PDF with no pages.
    """
    pdf_reader = PdfReader(pdf_path)
    # A single join avoids re-copying the accumulated string on every page
    # (repeated ``+=`` is quadratic in the document length). ``or ""`` is a
    # defensive guard in case a page's text extraction yields None.
    return "".join(
        (page.extract_text() or "") + "\n" for page in pdf_reader.pages
    )
def get_api_type(api_type):
    """
    Build the llama_index LLM client for a provider key.

    Args:
        api_type (str): One of 'openai', 'claude', 'llama', 'mistral'.

    Returns:
        The LLM instance for that key:
        'openai'  -> OpenAI(model='gpt-4o')
        'claude'  -> Anthropic(model='claude-3-opus-20240229')
        'llama'   -> Anyscale(model='meta-llama/Llama-2-70b-chat-hf')
        'mistral' -> Anyscale(model='mistralai/Mixtral-8x7B-Instruct-v0.1', max_tokens=10000)

    Raises:
        NotImplementedError: If ``api_type`` is not one of the supported keys.
    """
    if api_type == 'openai':
        # No temperature or other sampling options are passed; the client's defaults apply.
        return OpenAI(model='gpt-4o')
    if api_type == 'claude':
        return Anthropic(model="claude-3-opus-20240229")
    if api_type == 'llama':
        return Anyscale(model='meta-llama/Llama-2-70b-chat-hf')
    if api_type == 'mistral':
        # Mixtral is served through the Anyscale client here, not Mistral's own API.
        return Anyscale(model='mistralai/Mixtral-8x7B-Instruct-v0.1', max_tokens=10000)
    raise NotImplementedError(
        f"Unsupported api_type {api_type!r}; expected one of "
        "'openai', 'claude', 'llama', 'mistral'."
    )
# Persona and length rules; used as the system message of both engines.
_SYSTEM_PROMPT = """
% You are an expert on developing websites for contractors and explaining your expertise to a general audience.
% If a character or word limit is mentioned in the prompt, ADHERE TO IT.
% For example, if a user wants a summary of a business less than 750 characters, the summary must be less than 750 characters.
"""

# Answer-style rules placed after the retrieved context in both engines' prompts.
_ANSWER_INSTRUCTIONS = """% Instructions:
Answer in a friendly manner.
Do not answer any questions that have no relevance to the context provided.
Do not include any instructions in your response.
Do not mention the context provided in your answer.
ANSWER WITHOUT MENTIONING THE PROVIDED DOCUMENTS
YOU ARE NOT PERMITTED TO GIVE PAGE NUMBERS IN YOUR ANSWER UNDER ANY CIRCUMSTANCE
"""

# User turn of the query engine's text-QA template; llama_index fills in
# {query_str} (the question) and {context_str} (the retrieved chunks).
_QA_USER_PROMPT = """
% You are an expert on developing websites for contractors and explaining your expertise to a general audience.
% Goal: Given the Context below, give a detailed and thorough answer to the following question without mentioning where you found the answer: {query_str}
% Context:
```{context_str}```
""" + _ANSWER_INSTRUCTIONS

# Context template for the condense-plus-context chat engine. Only {context_str}
# is filled in here; the user's question arrives as the latest chat message.
_CHAT_CONTEXT_PROMPT = """
% Goal: Given the Context below, give a detailed and thorough answer to the user's question without mentioning where you found the answer.
% Context:
```{context_str}```
""" + _ANSWER_INSTRUCTIONS


def get_chat_engine(files, progress=gr.Progress()):
    """
    Index the given files and build a chat engine and a query engine over them.

    Args:
        files: File paths handed to ``SimpleDirectoryReader(input_files=...)``.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is the
            pattern Gradio uses to inject a tracker into event handlers.

    Returns:
        tuple: ``(chat_engine, query_engine, "LLM Created")``.

    Side effects:
        Sets the global ``Settings.llm`` (OpenAI gpt-4o via get_api_type) and
        ``Settings.embed_model`` (Mistral 'mistral-embed').
    """
    progress(0, desc="Uploading Documents...")
    Settings.llm = get_api_type('openai')
    Settings.embed_model = MistralAIEmbedding(model_name='mistral-embed', api_key=mistral_api_key)
    documents = SimpleDirectoryReader(input_files=files).load_data()

    # Fixed-size token chunks with a small overlap between neighbouring chunks.
    splitter = TokenTextSplitter(
        chunk_size=1024,
        chunk_overlap=20,
        separator=" ",
    )
    progress(0.3, desc="Creating index...")
    nodes = splitter.get_nodes_from_documents(documents)
    index = VectorStoreIndex(nodes)

    text_qa_template = ChatPromptTemplate([
        ChatMessage(role=MessageRole.SYSTEM, content=_SYSTEM_PROMPT),
        ChatMessage(role=MessageRole.USER, content=_QA_USER_PROMPT),
    ])

    # Both engines retrieve 15 nodes. LongContextReorder runs first, then RankGPT
    # keeps the 5 most relevant in its own order.
    # NOTE(review): because RankGPT runs last, its ranking decides the final
    # order; move LongContextReorder after it if the long-context layout is wanted.
    # MetadataReplacementPostProcessor(target_metadata_key="window") was removed:
    # the TokenTextSplitter nodes above carry no "window" metadata, so it was a no-op.
    # postprocessor = SimilarityPostprocessor(similarity_cutoff=0.7)
    node_postprocessors = [
        LongContextReorder(),
        RankGPTRerank(top_n=5, llm=OpenAI(model="gpt-3.5-turbo")),
    ]

    progress(0.5, desc="Creating LLM...")
    # The condense_plus_context engine does not accept `text_qa_prompt` or
    # `streaming` (unknown kwargs are swallowed, so the custom prompt used to be
    # silently dropped). It takes system_prompt/context_prompt instead; stream
    # responses with chat_engine.stream_chat().
    chat_engine = index.as_chat_engine(
        'condense_plus_context',
        system_prompt=_SYSTEM_PROMPT,
        context_prompt=_CHAT_CONTEXT_PROMPT,
        node_postprocessors=node_postprocessors,
        similarity_top_k=15,
    )
    query_engine = index.as_query_engine(
        text_qa_template=text_qa_template,
        node_postprocessors=node_postprocessors,
        similarity_top_k=15,
    )
    progress(1, desc="LLM Created")
    return chat_engine, query_engine, "LLM Created"