from langchain_openai import ChatOpenAI #, OpenAIEmbeddings # No need to pay for embeddings as well when free alternatives are available
# Data
from langchain_community.document_loaders import DirectoryLoader, TextLoader, WebBaseLoader
# from langchain_chroma import Chroma # The documentation uses this one, but it is extremely recent, and the same functionality is available in langchain_community and langchain (which imports community)
from langchain_community.vectorstores import Chroma # This has documentation on-hover, while the indirect import through non-community does not
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings # The free alternative (also the default in docs, with model_name = 'all-MiniLM-L6-v2')
from langchain.text_splitter import RecursiveCharacterTextSplitter # Recursive to better keep related bits contiguous (also recommended in docs: https://python.langchain.com/docs/modules/data_connection/document_transformers/)
# Chains
from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.tools.retriever import create_retriever_tool
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, chain
from langchain_core.pydantic_v1 import BaseModel, Field
# Agents
from langchain import hub
from langchain.agents import create_tool_calling_agent, AgentExecutor
# To manually create inputs to test pipelines
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document
import requests
from bs4 import BeautifulSoup
import os
import shutil
from pathlib import Path
import re
import dotenv
dotenv.load_dotenv()
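# Assumes a .env file providing OPENAI_API_KEY, which ChatOpenAI reads from the environment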
## Vector stores
# Non-persistent; build from documents
# scripts = DirectoryLoader('scripts', glob = '*.txt', loader_cls = TextLoader).load()
# for s in scripts: s.page_content = re.sub(r'^[\t ]+', '', s.page_content, flags = re.MULTILINE) # Leading whitespace (used to centre text) is noise
# script_chunks = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 200, separators = ['\n\n\n', '\n\n', '\n']).split_documents(scripts)
# script_db = Chroma.from_documents(script_chunks, SentenceTransformerEmbeddings(model_name = 'all-MiniLM-L6-v2'))
# pages = DirectoryLoader('wookieepedia', glob = '*.txt', loader_cls = TextLoader).load()
# page_chunks = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 200, separators = ['\n\n\n', '\n\n', '\n']).split_documents(pages)
# woo_db = Chroma.from_documents(page_chunks, SentenceTransformerEmbeddings(model_name = 'all-MiniLM-L6-v2'))
# Load pre-built persistent ones
script_db = Chroma(embedding_function = SentenceTransformerEmbeddings(model_name = 'all-MiniLM-L6-v2'), persist_directory = str(Path('scripts') / 'db'))
woo_db = Chroma(embedding_function = SentenceTransformerEmbeddings(model_name = 'all-MiniLM-L6-v2'), persist_directory = str(Path('wookieepedia') / 'db'))
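# Quick sanity check that the persistent stores loaded (sketch; the query texts are illustrative):
# script_db.similarity_search('Obi-Wan lightsaber duel', k = 2)
# woo_db.similarity_search('Death Star', k = 2)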
# Chains
llm = ChatOpenAI(model = 'gpt-3.5-turbo-0125', temperature = 0)
## Base version (only one retriever)
document_prompt_system_text = '''
You are very knowledgeable about Star Wars and your job is to answer questions about its plot, characters, etc.
Use the context below to produce your answers with as much detail as possible.
If you do not know an answer, say so; do not make up information not in the context.
<context>
{context}
</context>
'''
document_prompt = ChatPromptTemplate.from_messages([
    ('system', document_prompt_system_text),
    MessagesPlaceholder(variable_name = 'chat_history', optional = True),
    ('user', '{input}')
])
document_chain = create_stuff_documents_chain(llm, document_prompt)
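# document_chain can be tested in isolation with a hand-made context (sketch; the content is illustrative):
# document_chain.invoke(dict(
#     input = "Who is Luke's father?",
#     context = [Document(page_content = "Darth Vader reveals that he is Luke's father.")]
# ))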
script_retriever_prompt = ChatPromptTemplate.from_messages([
    MessagesPlaceholder(variable_name = 'chat_history'),
    ('user', '{input}'),
    ('user', '''Given the above conversation, generate a search query to look up relevant information in a database containing the full scripts from the Star Wars films (i.e. just dialogue and brief scene descriptions).
The query need not be a proper sentence; it should be a list of keywords likely to appear in dialogue or scene descriptions''')
])
script_retriever_chain = create_history_aware_retriever(llm, script_db.as_retriever(), script_retriever_prompt) # Essentially just: prompt | llm | StrOutputParser() | retriever
woo_retriever_prompt = ChatPromptTemplate.from_messages([
    MessagesPlaceholder(variable_name = 'chat_history'),
    ('user', '{input}'),
    ('user', 'Given the above conversation, generate a search query to find a relevant page in the Star Wars fandom wiki; the query should be something simple, such as the name of a character, place, event, item, etc.')
])
woo_retriever_chain = create_history_aware_retriever(llm, woo_db.as_retriever(), woo_retriever_prompt) # Essentially just: prompt | llm | StrOutputParser() | retriever
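# Hand-rolled near-equivalent of the above (sketch; the real helper also branches to skip the
# rewrite and pass 'input' straight to the retriever when chat_history is empty):
# woo_retriever_chain = woo_retriever_prompt | llm | StrOutputParser() | woo_db.as_retriever()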
# full_chain = create_retrieval_chain(script_retriever_chain, document_chain)
full_chain = create_retrieval_chain(woo_retriever_chain, document_chain)
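# Example invocation with a manually created history (sketch; the messages are illustrative):
# full_chain.invoke(dict(
#     chat_history = [HumanMessage(content = 'Who built C-3PO?'), AIMessage(content = 'Anakin Skywalker built C-3PO.')],
#     input = 'Where did he build him?'
# )) # Returns a dict with 'input', 'chat_history', 'context' (the retrieved Documents) and 'answer'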
# simplify_query_prompt = ChatPromptTemplate.from_messages([
#     ('system', 'Given the above conversation, generate a search query to find a relevant page in the Star Wars fandom wiki; the query should be something simple, at most 4 words, such as the name of a character, place, event, item, etc.'),
#     MessagesPlaceholder('chat_history', optional = True), # Using this form since not clear how to have optional = True in the tuple form
#     ('human', '{query}')
# ])
# simplify_query_chain = simplify_query_prompt | llm | StrOutputParser() # To extract just the message
## Agent version
script_tool = create_retriever_tool(
    script_db.as_retriever(search_kwargs = dict(k = 4)),
    'search_film_scripts',
    '''Search the Star Wars film scripts. This tool should be the first choice for Star Wars related questions.
Queries passed to this tool should be lists of keywords likely to be in dialogue or scene descriptions, and should not include film titles.'''
)
woo_tool = create_retriever_tool(
    woo_db.as_retriever(search_kwargs = dict(k = 4)),
    'search_wookieepedia',
    'Search the Star Wars fandom wiki. This tool should be the first choice for Star Wars related questions.'
    # This tool should be used for queries about details of a particular character, location, event, weapon, etc., and the query should be something simple, such as the name of a character, place, event, item, etc.
)
tools = [script_tool, woo_tool]
agent_system_text = '''
You are a helpful agent who is very knowledgeable about Star Wars and your job is to answer questions about its plot, characters, etc.
Use the context provided in the exchanges that follow to produce your answers with as much detail as possible.
If you do not know an answer, say so; do not make up information.
'''
agent_prompt = ChatPromptTemplate.from_messages([
    ('system', agent_system_text),
    MessagesPlaceholder('chat_history', optional = True), # Using this form since not clear how to have optional = True in the tuple form
    ('human', '{input}'),
    ('placeholder', '{agent_scratchpad}') # Required for the agent's intermediate tool calls and their results
])
agent = create_tool_calling_agent(llm, tools, agent_prompt)
agent_executor = AgentExecutor(agent = agent, tools = tools, verbose = True)
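# Example invocation (sketch); the executor returns a dict whose 'output' is the final answer:
# agent_executor.invoke(dict(input = "What colour is Mace Windu's lightsaber?", chat_history = []))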
## Non-agent chain-logic version
# Determine which retriever is best and generate an appropriate query for it
class DirectedQuery(BaseModel):
    '''Determine whether a query is best answered by looking at scripts rather than articles'''
    query: str = Field(
        ...,
        description = '''The query to either search film scripts or wiki articles.
A film script query should include character names and relevant keywords of what they are saying in a scene which is likely to contain the required information.
A wiki article search should instead be at most 4 words, simply being the name of a character or location or event whose page is likely to contain the required information.''',
    )
    source: str = Field(
        ...,
        description = 'Either "wiki" or "scripts", indicating which source the query should be passed to.',
    )
query_analyser_prompt = ChatPromptTemplate.from_messages([
    ('system', 'You have the ability to issue search queries of one of two kinds to get information to help answer questions.'),
    ('human', '{question}'),
])
structured_llm = llm.with_structured_output(DirectedQuery)
query_generator = dict(question = RunnablePassthrough()) | query_analyser_prompt | structured_llm
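# e.g. (illustrative; the actual fields depend on the llm's output):
# query_generator.invoke('What did Han say to Greedo in the cantina?') # -> DirectedQuery(query='Han Greedo cantina', source='scripts')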
retrievers = dict(wiki = woo_db.as_retriever(search_kwargs = dict(k = 4)), scripts = script_db.as_retriever(search_kwargs = dict(k = 4)))
@chain # Wrap the function as a Runnable: create_retrieval_chain composes it with .with_config, which plain functions lack
def compound_retriever(inputs):
    response = query_generator.invoke(inputs['input']) # create_retrieval_chain passes the full input dict to the retriever
    retriever = retrievers[response.source]
    return retriever.invoke(response.query)
compound_chain = create_retrieval_chain(compound_retriever, document_chain)
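# compound_chain.invoke(dict(input = 'Describe the Battle of Yavin')) # Sketch; 'context' comes from whichever retriever the structured query selects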
## Wookieepedia functions
def first_wookieepedia_result(query: str) -> str:
    '''Get the url of the first result when searching Wookieepedia for a query
    (best for simple names as queries, ideally generated by the llm from a prompt like
    "Produce an input consisting of the name of the most important element in the query so that its article can be looked up")
    '''
    search_results = requests.get(f'https://starwars.fandom.com/wiki/Special:Search?query={"+".join(query.split(" "))}')
    soup = BeautifulSoup(search_results.content, 'html.parser')
    first_res = soup.find('a', class_ = 'unified-search__result__link')
    return first_res['href']
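# e.g. first_wookieepedia_result('Mos Eisley') # -> the URL of the top search hit (illustrative; depends on the site's live search results)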
def get_wookieepedia_page_content(query: str, previous_sources: set[str]) -> Document | None:
    '''Return cleaned content from a Wookieepedia page provided it was not already sourced
    '''
    url = first_wookieepedia_result(query)
    if url in previous_sources: return None
    else:
        response = requests.get(url)
        soup = BeautifulSoup(response.content, 'html.parser')
        doc = soup.find('div', id = 'content').get_text()
        # Cleaning
        doc = doc.split('\n\n\n\n\n\n\n\n\n\n\n\n\n\n')[-1] # The (multiple) preambles are separated by these many newlines; no harm done if not present
        doc = re.sub(r'\[\d*\]', '', doc) # References (and section titles' "[]" suffixes) are noise
        doc = doc.split('\nAppearances\n')[0] # Keep only content before these sections
        doc = doc.split('\nSources\n')[0] # Technically no need to check this if successfully cut on appearances, but no harm done
        doc = re.sub(r'Contents\n\n(?:[\d\.]+ [^\n]+\n+)+', '', doc) # Remove table of contents
        return Document(page_content = doc, metadata = dict(source = url))
def get_wookieepedia_context(original_query: str, simple_query: str, wdb: Chroma) -> list[Document]:
    try:
        doc = get_wookieepedia_page_content(simple_query, previous_sources = set(md.get('source') for md in wdb.get()['metadatas']))
        if doc is not None:
            new_chunks = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 200).split_documents([doc])
            wdb.add_documents(new_chunks)
            print(f"Added new chunks (for '{simple_query}' -> {doc.metadata['source']}) to the Wookieepedia database.")
    except Exception: return []
    return wdb.similarity_search(original_query, k = 10)
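# Example (sketch): answer-time context fetch, adding the page's chunks to the store only if its URL is new:
# get_wookieepedia_context('Where was Padmé Amidala born?', 'Padmé Amidala', woo_db)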