File size: 11,665 Bytes
42a73b9
 
 
 
 
 
 
be5cac6
42a73b9
 
 
 
 
 
 
67a7b40
 
42a73b9
 
 
 
 
 
 
 
 
 
 
 
 
 
e17df50
42a73b9
 
 
 
 
 
be5cac6
 
42a73b9
 
e17df50
42a73b9
e17df50
 
 
 
42a73b9
e17df50
 
 
42a73b9
 
e17df50
42a73b9
e17df50
 
42a73b9
 
 
 
 
 
 
 
67a7b40
 
42a73b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8c8660
 
 
 
 
67a7b40
 
 
 
 
e8c8660
67a7b40
42a73b9
 
 
 
 
 
 
 
 
 
 
e17df50
e8c8660
42a73b9
 
e8c8660
 
42a73b9
e8c8660
42a73b9
 
 
 
 
 
 
 
 
 
 
 
 
e17df50
42a73b9
 
 
 
67a7b40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e17df50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
from langchain_openai import ChatOpenAI#, OpenAIEmbeddings # No need to pay for using embeddings as well when have free alternatives

# Data
from langchain_community.document_loaders import DirectoryLoader, TextLoader, WebBaseLoader
# from langchain_chroma import Chroma # The documentation uses this one, but it is extremely recent, and the same functionality is available in langchain_community and langchain (which imports community)
from langchain_community.vectorstores import Chroma # This has documentation on-hover, while the indirect import through non-community does not
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings # The free alternative (also the default in docs, with model_name = 'all-MiniLM-L6-v2')
from langchain.text_splitter import RecursiveCharacterTextSplitter # Recursive to better keep related bits contiguous (also recommended in docs: https://python.langchain.com/docs/modules/data_connection/document_transformers/)

# Chains
from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.tools.retriever import create_retriever_tool
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, chain
from langchain_core.pydantic_v1 import BaseModel, Field

# Agents
from langchain import hub
from langchain.agents import create_tool_calling_agent, AgentExecutor

# To manually create inputs to test pipelines
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document

import requests
from bs4 import BeautifulSoup

import os
import shutil
from pathlib import Path
import re

import dotenv
dotenv.load_dotenv()




## Vector stores

# Non-persistent; build from documents
# (one-off ingestion path, kept for reference: load the raw .txt files, split into
#  overlapping chunks and embed them into an in-memory Chroma store)

# scripts = DirectoryLoader('scripts', glob = '*.txt', loader_cls = TextLoader).load()
# for s in scripts: s.page_content = re.sub(r'^[\t ]+', '', s.page_content, flags = re.MULTILINE)  # Spacing to centre text noise
# script_chunks = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 200, separators = ['\n\n\n', '\n\n', '\n']).split_documents(scripts)
# script_db = Chroma.from_documents(script_chunks, SentenceTransformerEmbeddings(model_name = 'all-MiniLM-L6-v2'))

# pages = DirectoryLoader('wookieepedia', glob = '*.txt', loader_cls = TextLoader).load()
# page_chunks = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 200, separators = ['\n\n\n', '\n\n', '\n']).split_documents(pages)
# woo_db = Chroma.from_documents(page_chunks, SentenceTransformerEmbeddings(model_name = 'all-MiniLM-L6-v2'))


# # Load pre-built persistent ones

# NOTE(review): the embedding model here presumably must match the one used when the
# persistent stores were built (same 'all-MiniLM-L6-v2' as in the commented code above) — confirm
script_db = Chroma(embedding_function = SentenceTransformerEmbeddings(model_name = 'all-MiniLM-L6-v2'), persist_directory = str(Path('scripts') / 'db')) # Film-script chunks
woo_db = Chroma(embedding_function = SentenceTransformerEmbeddings(model_name = 'all-MiniLM-L6-v2'), persist_directory = str(Path('wookieepedia') / 'db')) # Wiki-article chunks



# Chains

# Deterministic (temperature 0) chat model shared by all chains and agents below
llm = ChatOpenAI(model = 'gpt-3.5-turbo-0125', temperature = 0)


## Base version (only one retriever)

# System prompt for the answering stage; the retrieved chunks are stuffed into {context}
document_prompt_system_text = '''
You are very knowledgeable about Star Wars and your job is to answer questions about its plot, characters, etc.
Use the context below to produce your answers with as much detail as possible.
If you do not know an answer, say so; do not make up information not in the context.

<context>
{context}
</context>
'''
document_prompt = ChatPromptTemplate.from_messages([
    ('system', document_prompt_system_text),
    MessagesPlaceholder(variable_name = 'chat_history', optional = True),
    ('user', '{input}')
])
# Concatenates ("stuffs") all retrieved documents into {context} above and calls the llm
document_chain = create_stuff_documents_chain(llm, document_prompt)

# Rewrites the conversation into a keyword-style query suited to the scripts store
script_retriever_prompt = ChatPromptTemplate.from_messages([
    MessagesPlaceholder(variable_name = 'chat_history'),
    ('user', '{input}'),
    ('user', '''Given the above conversation, generate a search query to look up relevant information in a database containing the full scripts from the Star Wars films (i.e. just dialogue and brief scene descriptions).
     The query need not be a proper sentence, but a list of keywords likely to be in dialogue or scene descriptions''')
])
script_retriever_chain = create_history_aware_retriever(llm, script_db.as_retriever(), script_retriever_prompt) # Essentially just: prompt | llm | StrOutputParser() | retriever

# Same idea for the wiki store, but asking for a simple entity-name query instead
woo_retriever_prompt = ChatPromptTemplate.from_messages([
    MessagesPlaceholder(variable_name = 'chat_history'),
    ('user', '{input}'),
    ('user', 'Given the above conversation, generate a search query to find a relevant page in the Star Wars fandom wiki; the query should be something simple, such as the name of a character, place, event, item, etc.')
])
woo_retriever_chain = create_history_aware_retriever(llm, woo_db.as_retriever(), woo_retriever_prompt) # Essentially just: prompt | llm | StrOutputParser() | retriever

# The base chain can only take one retriever; the wiki one is currently active
# full_chain = create_retrieval_chain(script_retriever_chain, document_chain)
full_chain = create_retrieval_chain(woo_retriever_chain, document_chain)



# Earlier experiment: a standalone query-simplification chain, superseded by the retriever prompts above
# simplify_query_prompt = ChatPromptTemplate.from_messages([
#     ('system', 'Given the above conversation, generate a search query to find a relevant page in the Star Wars fandom wiki; the query should be something simple, at most 4 words, such as the name of a character, place, event, item, etc.'),
#     MessagesPlaceholder('chat_history', optional = True), # Using this form since not clear how to have optional = True in the tuple form
#     ('human', '{query}')
# ])

# simplify_query_chain = simplify_query_prompt | llm | StrOutputParser() # To extract just the message


## Agent version

# Each retriever is wrapped as a tool so the agent itself decides which store to
# query and with what query; the tool description is the llm's only guidance.
# NOTE(review): both descriptions claim to be "the first choice" — confirm this is intentional
script_tool = create_retriever_tool(
    script_db.as_retriever(search_kwargs = dict(k = 4)),
    'search_film_scripts',
    '''Search the Star Wars film scripts. This tool should be the first choice for Star Wars related questions.
    Queries passed to this tool should be lists of keywords likely to be in dialogue or scene descriptions, and should not include film titles.'''
)

woo_tool = create_retriever_tool(
    woo_db.as_retriever(search_kwargs = dict(k = 4)),
    'search_wookieepedia',
    'Search the Star Wars fandom wiki. This tool should be the first choice for Star Wars related questions.'
    # This tool should be used for queries about details of a particular character, location, event, weapon, etc., and the query should be something simple, such as the name of a character, place, event, item, etc.'''
)
tools = [script_tool, woo_tool]

agent_system_text = '''
You are a helpful agent who is very knowledgeable about Star Wars and your job is to answer questions about its plot, characters, etc.
Use the context provided in the exchanges to come to produce your answers with as much detail as possible.
If you do not know an answer, say so; do not make up information.
'''
agent_prompt = ChatPromptTemplate.from_messages([
    ('system', agent_system_text),
    MessagesPlaceholder('chat_history', optional = True), # Using this form since not clear how to have optional = True in the tuple form
    ('human', '{input}'),
    ('placeholder', '{agent_scratchpad}') # Required for chat history and the agent's intermediate processing values
])
agent = create_tool_calling_agent(llm, tools, agent_prompt)

# verbose = True prints each tool call and observation, useful while experimenting
agent_executor = AgentExecutor(agent = agent, tools = tools, verbose = True)



## Non-agent chain-logic version

# Determine which retriever is best and generate an appropriate query for it
class DirectedQuery(BaseModel):
    '''Determine whether a query is best answered by looking at scripts rather than articles'''

    # The Field descriptions below are part of the structured-output schema sent to the llm,
    # so they double as instructions for how to fill each field
    query: str = Field(
        ...,
        description = '''The query to either search film scripts or wiki articles.
        A film script query should include character names and relevant keywords of what they are saying in the a scene which is likely to contain the required information.
        A wiki articles search should instead be at most 4 words, simply being the name of a character or location or event whose page is likely to contain the required information.''',
    )
    source: str = Field(
        ...,
        description = 'Either "wiki" or "scripts", indicating which source the query should be passed to.',
    )
query_analyser_prompt = ChatPromptTemplate.from_messages([
        ('system', 'You have the ability to issue search queries of one of two kinds to get information to help answer questions.'),
        ('human', '{question}'),
])
structured_llm = llm.with_structured_output(DirectedQuery) # Responses are parsed into DirectedQuery instances
query_generator = dict(question = RunnablePassthrough()) | query_analyser_prompt | structured_llm

# Keys must match the values DirectedQuery.source is described to take ("wiki"/"scripts")
retrievers = dict(wiki = woo_db.as_retriever(search_kwargs = dict(k = 4)), scripts = script_db.as_retriever(search_kwargs = dict(k = 4)))

@chain
def compound_retriever(question):
    '''Route the question to whichever retriever the llm deems best, using the
    rewritten query it generated for that source.
    '''
    response = query_generator.invoke(question)
    retriever = retrievers[response.source] # NOTE(review): KeyError if the llm answers outside {"wiki", "scripts"} — unhandled
    return retriever.invoke(response.query)

compound_chain = create_retrieval_chain(compound_retriever, document_chain)



## Wookieepedia functions

def first_wookieepedia_result(query: str) -> str:
    '''Get the url of the first result when searching Wookieepedia for a query
    (best for simple names as queries, ideally generated by the llm for something like
    "Produce a input consisting of the name of the most important element in the query so that its article can be looked up").

    Raises:
        ValueError: if the search page contains no result links for the query.
    '''
    from urllib.parse import quote_plus # Local import; proper URL-encoding (handles '&', '?', '#', non-ASCII), unlike the previous '+'.join of whitespace splits

    search_results = requests.get(f'https://starwars.fandom.com/wiki/Special:Search?query={quote_plus(query)}')
    soup = BeautifulSoup(search_results.content, 'html.parser')
    first_res = soup.find('a', class_ = 'unified-search__result__link')
    if first_res is None: # Previously this fell through to subscripting None, raising an opaque TypeError
        raise ValueError(f'No Wookieepedia search results for query: {query!r}')
    return first_res['href']


def get_wookieepedia_page_content(query: str, previous_sources: set[str]) -> Document | None:
    '''Return cleaned content from a Wookieepedia page provided it was not already sourced.

    Args:
        query: search query passed to first_wookieepedia_result to locate the page.
        previous_sources: URLs already ingested; a page whose URL is in this set returns None.

    Returns:
        A Document whose metadata['source'] is the page URL, or None if already sourced.
    '''
    url = first_wookieepedia_result(query)

    if url in previous_sources: return None
    else:
        response = requests.get(url)
        soup = BeautifulSoup(response.content, 'html.parser')
        doc = soup.find('div', id = 'content').get_text()

        # Cleaning
        doc = doc.split('\n\n\n\n\n\n\n\n\n\n\n\n\n\n')[-1] # The (multiple) preambles are separated by these many newlines; no harm done if not present
        doc = re.sub(r'\[\d*\]', '', doc) # References (and section title's "[]" suffixes) are noise; raw string fixes the invalid '\[' / '\d' escape sequences
        doc = doc.split('\nAppearances\n')[0] # Keep only content before these sections
        doc = doc.split('\nSources\n')[0] # Technically no need to check this if successfully cut on appearances, but no harm done
        doc = re.sub(r'Contents\n\n(?:[\d\.]+ [^\n]+\n+)+', '', doc) # Remove table of contents; raw string as above

        return Document(page_content = doc, metadata = dict(source = url))

def get_wookieepedia_context(original_query: str, simple_query: str, wdb: Chroma) -> list[Document]:
    '''Return context chunks for original_query from the Wookieepedia vector store,
    first trying (best-effort) to enrich the store with the page found for simple_query.

    Args:
        original_query: the full question, used for the final similarity search.
        simple_query: simplified entity-style query used to look up a new wiki page.
        wdb: the Wookieepedia Chroma store, updated in place with any new page chunks.

    Returns:
        Up to 10 similar chunks, or [] if fetching/adding the new page failed.
    '''
    try:
        doc = get_wookieepedia_page_content(simple_query, previous_sources = set(md.get('source') for md in wdb.get()['metadatas']))
        if doc is not None:
            new_chunks = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 200).split_documents([doc])
            wdb.add_documents(new_chunks)
            # Bug fix: Document metadata is an attribute, not subscriptable — the old
            # doc['metadata']['source'] raised TypeError after every successful add,
            # which the bare except silently turned into an empty return
            print(f"Added new chunks (for '{simple_query}' -> {doc.metadata['source']}) to the Wookieepedia database.")
    except Exception: return [] # Narrowed from bare except: don't swallow KeyboardInterrupt/SystemExit
    return wdb.similarity_search(original_query, k = 10)