huggingface init
Browse files- README.md +4 -3
- kth_qa/config.py +133 -0
- kth_qa/courses.json +0 -0
- kth_qa/ingest.py +70 -0
- kth_qa/ingest_pinecone.py +59 -0
- kth_qa/magic/conversational.py +85 -0
- kth_qa/magic/prompts.py +22 -0
- kth_qa/magic/self_query_retriever.py +128 -0
- kth_qa/magic/vectordb.py +13 -0
- kth_qa/main.py +100 -0
- kth_qa/schema.py +11 -0
- kth_qa/static/styles.css +102 -0
- kth_qa/templates/index.html +60 -0
- kth_qa/test_response.json +18 -0
- kth_qa/utils.py +29 -0
README.md
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
---
|
2 |
title: KTH QA
|
3 |
emoji: 🤖
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
-
|
|
|
7 |
pinned: true
|
8 |
---
|
|
|
1 |
---
|
2 |
title: KTH QA
|
3 |
emoji: 🤖
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: gray
|
6 |
+
sdk: docker
|
7 |
+
app_port: 7860
|
8 |
pinned: true
|
9 |
---
|
kth_qa/config.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
logger = logging.getLogger()
|
3 |
+
|
4 |
+
import openai
|
5 |
+
from pydantic import BaseSettings
|
6 |
+
|
7 |
+
from langchain.chat_models import ChatOpenAI
|
8 |
+
from langchain.chains import RetrievalQAWithSourcesChain
|
9 |
+
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
10 |
+
from langchain.chains import SequentialChain
|
11 |
+
from langchain.llms import OpenAI
|
12 |
+
from langchain.chains import LLMCheckerChain
|
13 |
+
from langchain.chains.query_constructor.base import AttributeInfo
|
14 |
+
from langchain.vectorstores import Pinecone
|
15 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
16 |
+
|
17 |
+
from magic.prompts import PROMPT, EXAMPLE_PROMPT
|
18 |
+
from magic.vectordb import VectorIndex
|
19 |
+
from magic.self_query_retriever import SelfQueryRetriever
|
20 |
+
|
21 |
+
from kth_qa.utils import get_courses
|
22 |
+
|
23 |
+
import pinecone
|
24 |
+
|
25 |
+
|
26 |
+
class Settings(BaseSettings):
|
27 |
+
OPENAI_API_KEY: str = 'OPENAI_API_KEY'
|
28 |
+
OPENAI_CHAT_MODEL: str = 'gpt-3.5-turbo'
|
29 |
+
PINECONE_API_KEY: str = 'PINECONE_API_KEY'
|
30 |
+
PINECONE_INDEX_NAME: str = 'kth-qa'
|
31 |
+
PINECONE_ENV: str = 'us-west1-gcp-free'
|
32 |
+
class Config:
|
33 |
+
env_file = '.env'
|
34 |
+
|
35 |
+
def set_openai_key(key):
|
36 |
+
"""Sets OpenAI key."""
|
37 |
+
openai.api_key = key
|
38 |
+
|
39 |
+
class State:
|
40 |
+
settings: Settings
|
41 |
+
store: Pinecone | VectorIndex
|
42 |
+
chain: RetrievalQAWithSourcesChain
|
43 |
+
courses: list
|
44 |
+
|
45 |
+
def __init__(self):
|
46 |
+
self.settings = Settings()
|
47 |
+
|
48 |
+
self.courses = get_courses()
|
49 |
+
|
50 |
+
# OPENAI
|
51 |
+
set_openai_key(self.settings.OPENAI_API_KEY)
|
52 |
+
|
53 |
+
# LOCAL VECTORSTORE
|
54 |
+
# self.store: VectorIndex = VectorIndex()
|
55 |
+
|
56 |
+
# PINECONE VECTORSTORE
|
57 |
+
embeddings = OpenAIEmbeddings()
|
58 |
+
pinecone.init(api_key=self.settings.PINECONE_API_KEY, environment=self.settings.PINECONE_ENV)
|
59 |
+
self.store: Pinecone = Pinecone.from_existing_index(self.settings.PINECONE_INDEX_NAME, embeddings, "text")
|
60 |
+
logger.info(f"Pinecone store initialized")
|
61 |
+
|
62 |
+
# CHAINS
|
63 |
+
doc_chain = self._load_doc_chain()
|
64 |
+
qa_chain = self._load_qa_chain(doc_chain, self_query=True)
|
65 |
+
|
66 |
+
# JUST QA
|
67 |
+
self.chain = qa_chain
|
68 |
+
|
69 |
+
# SEQ CHAIN with QA and CHECKER
|
70 |
+
# checker_chain = self._load_checker_chain()
|
71 |
+
# self.chain = self._load_seq_chain([qa_chain, checker_chain])
|
72 |
+
|
73 |
+
def _load_seq_chain(self, chains):
|
74 |
+
sequential_chain = SequentialChain(
|
75 |
+
chains=chains,
|
76 |
+
input_variables=["question"],
|
77 |
+
output_variables=["answer"],
|
78 |
+
verbose=True)
|
79 |
+
return sequential_chain
|
80 |
+
|
81 |
+
def _load_checker_chain(self):
|
82 |
+
llm = OpenAI(temperature=0)
|
83 |
+
checker_chain = LLMCheckerChain(llm=llm, verbose=True, input_key="answer", output_key="result")
|
84 |
+
return checker_chain
|
85 |
+
|
86 |
+
def _load_doc_chain(self):
|
87 |
+
doc_chain = load_qa_with_sources_chain(
|
88 |
+
ChatOpenAI(temperature=0, max_tokens=256, model=self.settings.OPENAI_CHAT_MODEL, request_timeout=120),
|
89 |
+
chain_type="stuff",
|
90 |
+
document_variable_name="context",
|
91 |
+
prompt=PROMPT,
|
92 |
+
document_prompt=EXAMPLE_PROMPT
|
93 |
+
)
|
94 |
+
return doc_chain
|
95 |
+
|
96 |
+
def _load_qa_chain(self, doc_chain, self_query=False):
|
97 |
+
"""Load QA chain with retriever.
|
98 |
+
If self_query is True, the retriever will be a SelfQueryRetriever,
|
99 |
+
which will extract a metadata filter from question, and add to the vectorstore query.
|
100 |
+
"""
|
101 |
+
if self_query:
|
102 |
+
metadata_field_info=[
|
103 |
+
AttributeInfo(
|
104 |
+
name="course",
|
105 |
+
description="A course code for a course",
|
106 |
+
type="string"
|
107 |
+
)]
|
108 |
+
document_content_description = "Brief description of a course"
|
109 |
+
llm = OpenAI(temperature=0, model_name='text-davinci-002')
|
110 |
+
retriever = SelfQueryRetriever.from_llm(llm, self.store, document_content_description,
|
111 |
+
metadata_field_info, verbose=True)
|
112 |
+
qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
|
113 |
+
retriever=retriever,
|
114 |
+
return_source_documents=False)
|
115 |
+
else:
|
116 |
+
qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
|
117 |
+
retriever=self.store.as_retriever(),
|
118 |
+
return_source_documents=False)
|
119 |
+
return qa_chain
|
120 |
+
|
121 |
+
def course_exists(self, course: str):
|
122 |
+
course = course.upper()
|
123 |
+
exists = course in self.courses
|
124 |
+
if exists:
|
125 |
+
logger.info(f'Course {course} exists')
|
126 |
+
return True
|
127 |
+
else:
|
128 |
+
logger.info(f'Course {course} does not exist')
|
129 |
+
return False
|
130 |
+
|
131 |
+
if __name__ == '__main__':
|
132 |
+
state = State()
|
133 |
+
print(state.settings)
|
kth_qa/courses.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
kth_qa/ingest.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from utils import get_courses
|
4 |
+
logger = logging.getLogger()
|
5 |
+
|
6 |
+
import os
|
7 |
+
from langchain.docstore.document import Document
|
8 |
+
from langchain.text_splitter import NLTKTextSplitter
|
9 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
10 |
+
from langchain.vectorstores import Chroma
|
11 |
+
from langchain.callbacks import get_openai_callback
|
12 |
+
|
13 |
+
PERSIST_DIR = 'db'
|
14 |
+
FILE_DIR = 'files'
|
15 |
+
KURS_URL = "https://www.kth.se/student/kurser/kurs/{course_code}?l={language}"
|
16 |
+
DEFAULT_LANGUAGE = "en"
|
17 |
+
|
18 |
+
def ingest():
|
19 |
+
# make sure pwd is kth_qa
|
20 |
+
with get_openai_callback() as cb:
|
21 |
+
pwd = os.getcwd()
|
22 |
+
if pwd.split('/')[-1] != 'kth_qa':
|
23 |
+
logger.error(f"pwd is not kth_qa, but {pwd}. Please run from kth_qa directory.")
|
24 |
+
return
|
25 |
+
|
26 |
+
CHUNK_SIZE = 1000
|
27 |
+
|
28 |
+
embedding = OpenAIEmbeddings(chunk_size=CHUNK_SIZE)
|
29 |
+
|
30 |
+
text_splitter = NLTKTextSplitter.from_tiktoken_encoder(
|
31 |
+
chunk_size=CHUNK_SIZE,
|
32 |
+
chunk_overlap=100,
|
33 |
+
)
|
34 |
+
|
35 |
+
file_folder_name = f'files/{DEFAULT_LANGUAGE}'
|
36 |
+
file_folder = os.listdir(file_folder_name)
|
37 |
+
all_langdocs = []
|
38 |
+
for file in file_folder:
|
39 |
+
raw_docs = []
|
40 |
+
with open(f'{file_folder_name}/{file}', 'r') as f:
|
41 |
+
text = f.read()
|
42 |
+
filename = file.split('.')[0]
|
43 |
+
course_code, language = filename.split('?l=')
|
44 |
+
doc = Document(page_content=text, metadata={"source": course_code})
|
45 |
+
raw_docs.append(doc)
|
46 |
+
logger.debug(f"loaded file {file}")
|
47 |
+
|
48 |
+
langdocs = text_splitter.split_documents(raw_docs)
|
49 |
+
logger.debug(f"split documents into {len(langdocs)} chunks")
|
50 |
+
all_langdocs.extend(langdocs)
|
51 |
+
|
52 |
+
# add course title to page content in each document
|
53 |
+
logger.info(f"split all documents into {len(all_langdocs)} chunks")
|
54 |
+
|
55 |
+
logger.info(f"creating vector index in Chroma...")
|
56 |
+
vectordb = Chroma.from_documents(documents=all_langdocs,
|
57 |
+
embedding=embedding,
|
58 |
+
persist_directory=PERSIST_DIR)
|
59 |
+
logger.info(f"created vector index")
|
60 |
+
vectordb.persist()
|
61 |
+
logger.info(f"persisted vector index")
|
62 |
+
vectordb = None
|
63 |
+
logger.info(f"Done!")
|
64 |
+
|
65 |
+
logger.info(f"Total cost of openai api calls: {cb.total_cost}")
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
logging.basicConfig(level=logging.INFO)
|
69 |
+
logger.setLevel(logging.INFO)
|
70 |
+
ingest()
|
kth_qa/ingest_pinecone.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
logger = logging.getLogger()
|
3 |
+
|
4 |
+
import os
|
5 |
+
from langchain.docstore.document import Document
|
6 |
+
from langchain.text_splitter import NLTKTextSplitter
|
7 |
+
from langchain.callbacks import get_openai_callback
|
8 |
+
|
9 |
+
from kth_qa.config import State
|
10 |
+
|
11 |
+
FILE_DIR = 'files'
|
12 |
+
KURS_URL = "https://www.kth.se/student/kurser/kurs/{course_code}?l={language}"
|
13 |
+
DEFAULT_LANGUAGE = "en"
|
14 |
+
CHUNK_SIZE = 1000
|
15 |
+
|
16 |
+
def ingest(state: State):
|
17 |
+
with get_openai_callback() as cb:
|
18 |
+
# make sure pwd is kth_qa
|
19 |
+
pwd = os.getcwd()
|
20 |
+
if pwd.split('/')[-1] != 'kth_qa':
|
21 |
+
logger.error(f"pwd is not kth_qa, but {pwd}. Please run from kth_qa directory.")
|
22 |
+
return
|
23 |
+
|
24 |
+
text_splitter = NLTKTextSplitter.from_tiktoken_encoder(
|
25 |
+
chunk_size=CHUNK_SIZE,
|
26 |
+
chunk_overlap=100,
|
27 |
+
)
|
28 |
+
|
29 |
+
file_folder_name = f'files/{DEFAULT_LANGUAGE}'
|
30 |
+
file_folder = os.listdir(file_folder_name)
|
31 |
+
all_langdocs = []
|
32 |
+
for file in file_folder:
|
33 |
+
raw_docs = []
|
34 |
+
with open(f'{file_folder_name}/{file}', 'r') as f:
|
35 |
+
text = f.read()
|
36 |
+
filename = file.split('.')[0]
|
37 |
+
course_code, language = filename.split('?l=')
|
38 |
+
doc = Document(page_content=text, metadata={"course": course_code})
|
39 |
+
raw_docs.append(doc)
|
40 |
+
logger.debug(f"loaded file {file}")
|
41 |
+
|
42 |
+
langdocs = text_splitter.split_documents(raw_docs)
|
43 |
+
logger.debug(f"split documents into {len(langdocs)} chunks")
|
44 |
+
all_langdocs.extend(langdocs)
|
45 |
+
|
46 |
+
logger.info(f"split all documents into {len(all_langdocs)} chunks")
|
47 |
+
|
48 |
+
logger.info(f"Adding documents to pinecone...")
|
49 |
+
state.store.add_documents(all_langdocs)
|
50 |
+
logger.info(f"...done!")
|
51 |
+
|
52 |
+
logger.info(f"Total cost of openai api calls: {cb.total_cost}")
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
logging.basicConfig(level=logging.INFO)
|
56 |
+
logger.setLevel(logging.INFO)
|
57 |
+
|
58 |
+
state = State()
|
59 |
+
ingest(state)
|
kth_qa/magic/conversational.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import re
|
3 |
+
import logging
|
4 |
+
|
5 |
+
from kth_qa.schema import Answer, Question
|
6 |
+
logger = logging.getLogger()
|
7 |
+
import re
|
8 |
+
from ingest import KURS_URL, DEFAULT_LANGUAGE
|
9 |
+
from langchain.callbacks import get_openai_callback
|
10 |
+
|
11 |
+
from config import State
|
12 |
+
|
13 |
+
COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?" # e.g. DD1315
|
14 |
+
|
15 |
+
def blocking_chain(chain, request):
|
16 |
+
return chain(request, return_only_outputs=True)
|
17 |
+
|
18 |
+
async def question_handler(question: Question, state: State) -> Answer:
|
19 |
+
question = question.question
|
20 |
+
logger.info(f"Q: {question}")
|
21 |
+
|
22 |
+
cost = 0
|
23 |
+
with get_openai_callback() as cb:
|
24 |
+
result = await asyncio.to_thread(blocking_chain, state.chain, {"question": question})
|
25 |
+
cost = cb.total_cost
|
26 |
+
logger.debug(f"result: {result}")
|
27 |
+
|
28 |
+
answer = result['answer']
|
29 |
+
logger.info(f"A: {answer}")
|
30 |
+
|
31 |
+
if answer.startswith("I cannot help"):
|
32 |
+
answer = "I'm sorry, " + answer
|
33 |
+
return Answer(**{"answer": answer, "url": ""})
|
34 |
+
|
35 |
+
sources = result.get('sources')
|
36 |
+
logger.info(f"Sources: {sources}")
|
37 |
+
if sources:
|
38 |
+
sources = re.findall(COURSE_PATTERN, sources)
|
39 |
+
else:
|
40 |
+
answer, sources = split_sources(answer)
|
41 |
+
|
42 |
+
courses = [source.upper() for source in sources if state.course_exists(source)] # filter out courses that don't exist
|
43 |
+
courses = set(courses)
|
44 |
+
logger.info(f"unique courses: {courses}")
|
45 |
+
|
46 |
+
urls = [KURS_URL.format(course_code=course, language=DEFAULT_LANGUAGE) for course in courses] # format into urls
|
47 |
+
logger.info(f"urls: {urls}")
|
48 |
+
|
49 |
+
answer = answer.strip().removesuffix("(").strip()
|
50 |
+
|
51 |
+
if (not answer or len(answer) < 3) and urls:
|
52 |
+
answer = "Something went wrong, but I found a link."
|
53 |
+
|
54 |
+
logging.info(f"Cost of query: ${'{0:.2g}'.format(cost)}")
|
55 |
+
|
56 |
+
return Answer(answer=answer, urls=urls if urls else [])
|
57 |
+
|
58 |
+
def split_sources(answer: str):
|
59 |
+
patterns = [
|
60 |
+
"Sources",
|
61 |
+
"Source",
|
62 |
+
"References",
|
63 |
+
"Reference",
|
64 |
+
"sources",
|
65 |
+
"source",
|
66 |
+
"SOURCE"
|
67 |
+
]
|
68 |
+
for pattern in patterns:
|
69 |
+
if pattern in answer:
|
70 |
+
all_answers = answer.split(pattern)
|
71 |
+
if len(all_answers) == 2:
|
72 |
+
ans, sources = all_answers
|
73 |
+
courses = re.findall(COURSE_PATTERN, sources)
|
74 |
+
elif len(all_answers) > 2:
|
75 |
+
ans = ""
|
76 |
+
courses = []
|
77 |
+
for i, a in enumerate(all_answers):
|
78 |
+
if i % 2 == 0:
|
79 |
+
ans += a
|
80 |
+
else:
|
81 |
+
courses = re.findall(COURSE_PATTERN, a)
|
82 |
+
courses.extend(courses)
|
83 |
+
return ans, courses
|
84 |
+
|
85 |
+
return answer, []
|
kth_qa/magic/prompts.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from langchain import PromptTemplate
|
3 |
+
|
4 |
+
|
5 |
+
EXAMPLE_PROMPT = PromptTemplate(
|
6 |
+
template=">Course Description\n{page_content}\n----------\nSource: {course}",
|
7 |
+
input_variables=["page_content", "course"],
|
8 |
+
)
|
9 |
+
template ="""
|
10 |
+
You are a study counselor for KTH.
|
11 |
+
Given the following extracted parts of course descriptions and a question, create a short final answer with references ("SOURCES").
|
12 |
+
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
|
13 |
+
Just answer the question and don't add any extra information.
|
14 |
+
Only return sources that contain the answer to the question.
|
15 |
+
ALWAYS return a "SOURCES" part in your answer.
|
16 |
+
|
17 |
+
QUESTION: {question}
|
18 |
+
=========
|
19 |
+
{context}
|
20 |
+
=========
|
21 |
+
FINAL ANSWER:"""
|
22 |
+
PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
|
kth_qa/magic/self_query_retriever.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Retriever that generates and executes structured queries over its own data source.
|
2 |
+
|
3 |
+
This code is adapted from the original implementation in the LangChain repo,
|
4 |
+
but has been modified to work with the KTH QA system.
|
5 |
+
|
6 |
+
"""
|
7 |
+
|
8 |
+
import re
|
9 |
+
from typing import Any, Dict, List, Optional, Type, cast
|
10 |
+
|
11 |
+
from pydantic import BaseModel, Field, root_validator
|
12 |
+
|
13 |
+
from langchain import LLMChain
|
14 |
+
from langchain.base_language import BaseLanguageModel
|
15 |
+
from langchain.chains.query_constructor.base import load_query_constructor_chain
|
16 |
+
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
17 |
+
from langchain.chains.query_constructor.schema import AttributeInfo
|
18 |
+
from langchain.retrievers.self_query.pinecone import PineconeTranslator
|
19 |
+
from langchain.schema import BaseRetriever, Document
|
20 |
+
from langchain.vectorstores import Pinecone, VectorStore
|
21 |
+
|
22 |
+
COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?" # e.g. DD1315
|
23 |
+
|
24 |
+
|
25 |
+
def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
|
26 |
+
"""Get the translator class corresponding to the vector store class."""
|
27 |
+
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
|
28 |
+
Pinecone: PineconeTranslator
|
29 |
+
}
|
30 |
+
if vectorstore_cls not in BUILTIN_TRANSLATORS:
|
31 |
+
raise ValueError(
|
32 |
+
f"Self query retriever with Vector Store type {vectorstore_cls}"
|
33 |
+
f" not supported."
|
34 |
+
)
|
35 |
+
return BUILTIN_TRANSLATORS[vectorstore_cls]()
|
36 |
+
|
37 |
+
|
38 |
+
class SelfQueryRetriever(BaseRetriever, BaseModel):
|
39 |
+
"""Retriever that wraps around a vector store and uses an LLM to generate
|
40 |
+
the vector store queries."""
|
41 |
+
|
42 |
+
vectorstore: VectorStore
|
43 |
+
"""The underlying vector store from which documents will be retrieved."""
|
44 |
+
llm_chain: LLMChain
|
45 |
+
"""The LLMChain for generating the vector store queries."""
|
46 |
+
search_type: str = "similarity"
|
47 |
+
"""The search type to perform on the vector store."""
|
48 |
+
search_kwargs: dict = Field(default_factory=dict)
|
49 |
+
"""Keyword arguments to pass in to the vector store search."""
|
50 |
+
structured_query_translator: Visitor
|
51 |
+
"""Translator for turning internal query language into vectorstore search params."""
|
52 |
+
verbose: bool = False
|
53 |
+
|
54 |
+
class Config:
|
55 |
+
"""Configuration for this pydantic object."""
|
56 |
+
|
57 |
+
arbitrary_types_allowed = True
|
58 |
+
|
59 |
+
@root_validator(pre=True)
|
60 |
+
def validate_translator(cls, values: Dict) -> Dict:
|
61 |
+
"""Validate translator."""
|
62 |
+
if "structured_query_translator" not in values:
|
63 |
+
vectorstore_cls = values["vectorstore"].__class__
|
64 |
+
values["structured_query_translator"] = _get_builtin_translator(
|
65 |
+
vectorstore_cls
|
66 |
+
)
|
67 |
+
return values
|
68 |
+
|
69 |
+
def get_relevant_documents(self, query: str) -> List[Document]:
|
70 |
+
"""Get documents relevant for a query.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
query: string to find relevant documents for
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
List of relevant documents
|
77 |
+
"""
|
78 |
+
if re.findall(COURSE_PATTERN, query):
|
79 |
+
inputs = self.llm_chain.prep_inputs(query)
|
80 |
+
structured_query = cast(
|
81 |
+
StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs)
|
82 |
+
)
|
83 |
+
if self.verbose:
|
84 |
+
print("Found course pattern in query, using structured query:")
|
85 |
+
print(structured_query)
|
86 |
+
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
|
87 |
+
structured_query
|
88 |
+
)
|
89 |
+
search_kwargs = {**self.search_kwargs, **new_kwargs}
|
90 |
+
else:
|
91 |
+
search_kwargs = self.search_kwargs
|
92 |
+
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
|
93 |
+
return docs
|
94 |
+
|
95 |
+
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
@classmethod
|
99 |
+
def from_llm(
|
100 |
+
cls,
|
101 |
+
llm: BaseLanguageModel,
|
102 |
+
vectorstore: VectorStore,
|
103 |
+
document_contents: str,
|
104 |
+
metadata_field_info: List[AttributeInfo],
|
105 |
+
structured_query_translator: Optional[Visitor] = None,
|
106 |
+
chain_kwargs: Optional[Dict] = None,
|
107 |
+
**kwargs: Any,
|
108 |
+
) -> "SelfQueryRetriever":
|
109 |
+
if structured_query_translator is None:
|
110 |
+
structured_query_translator = _get_builtin_translator(vectorstore.__class__)
|
111 |
+
chain_kwargs = chain_kwargs or {}
|
112 |
+
if "allowed_comparators" not in chain_kwargs:
|
113 |
+
chain_kwargs[
|
114 |
+
"allowed_comparators"
|
115 |
+
] = structured_query_translator.allowed_comparators
|
116 |
+
if "allowed_operators" not in chain_kwargs:
|
117 |
+
chain_kwargs[
|
118 |
+
"allowed_operators"
|
119 |
+
] = structured_query_translator.allowed_operators
|
120 |
+
llm_chain = load_query_constructor_chain(
|
121 |
+
llm, document_contents, metadata_field_info, **chain_kwargs
|
122 |
+
)
|
123 |
+
return cls(
|
124 |
+
llm_chain=llm_chain,
|
125 |
+
vectorstore=vectorstore,
|
126 |
+
structured_query_translator=structured_query_translator,
|
127 |
+
**kwargs,
|
128 |
+
)
|
kth_qa/magic/vectordb.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
logger = logging.getLogger()
|
3 |
+
import os
|
4 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
5 |
+
from langchain.vectorstores import Chroma
|
6 |
+
from ingest import PERSIST_DIR
|
7 |
+
|
8 |
+
embedding = OpenAIEmbeddings()
|
9 |
+
class VectorIndex(Chroma):
|
10 |
+
def __init__(self):
|
11 |
+
if len(os.listdir(PERSIST_DIR)) < 2: # check if there are files in the directory
|
12 |
+
logger.error(f"VectorIndex: No files in {PERSIST_DIR}, have you run ingest.py?")
|
13 |
+
super().__init__(persist_directory=PERSIST_DIR, embedding_function=embedding)
|
kth_qa/main.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from magic.conversational import question_handler
|
5 |
+
from schema import Answer
|
6 |
+
|
7 |
+
logger = logging.getLogger()
|
8 |
+
logging.basicConfig(encoding='utf-8', level=logging.INFO)
|
9 |
+
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from fastapi import FastAPI, Request
|
13 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
15 |
+
from fastapi.templating import Jinja2Templates
|
16 |
+
from fastapi.staticfiles import StaticFiles
|
17 |
+
from starlette.routing import WebSocketRoute
|
18 |
+
import uvicorn
|
19 |
+
|
20 |
+
from schema import Question
|
21 |
+
from config import State
|
22 |
+
import arel
|
23 |
+
|
24 |
+
# --- Setup ---
|
25 |
+
|
26 |
+
# hot reload
|
27 |
+
async def reload_data():
|
28 |
+
print("Reloading server data...")
|
29 |
+
|
30 |
+
BASE_PATH = Path(__file__).resolve().parent
|
31 |
+
static_path = str(BASE_PATH / "static")
|
32 |
+
template_path = str(BASE_PATH / "templates")
|
33 |
+
|
34 |
+
hotreload = arel.HotReload(
|
35 |
+
paths=[
|
36 |
+
arel.Path(static_path),
|
37 |
+
arel.Path(template_path),
|
38 |
+
],
|
39 |
+
)
|
40 |
+
|
41 |
+
state = State()
|
42 |
+
|
43 |
+
app = FastAPI(
|
44 |
+
routes=[WebSocketRoute("/hot-reload", hotreload, name="hot-reload")],
|
45 |
+
on_startup=[hotreload.startup],
|
46 |
+
on_shutdown=[hotreload.shutdown],
|
47 |
+
)
|
48 |
+
|
49 |
+
# templates
|
50 |
+
app.mount("/static", StaticFiles(directory="static"), name="static")
|
51 |
+
BASE_PATH = Path(__file__).resolve().parent
|
52 |
+
templates = Jinja2Templates(directory=template_path)
|
53 |
+
templates.env.globals["DEBUG"] = True
|
54 |
+
templates.env.globals["hotreload"] = hotreload
|
55 |
+
|
56 |
+
# CORS
|
57 |
+
origins = [
|
58 |
+
"http://localhost",
|
59 |
+
"http://localhost:5001",
|
60 |
+
]
|
61 |
+
|
62 |
+
app.add_middleware(
|
63 |
+
CORSMiddleware,
|
64 |
+
allow_origins=origins,
|
65 |
+
allow_credentials=True,
|
66 |
+
allow_methods=["*"],
|
67 |
+
allow_headers=["*"],
|
68 |
+
)
|
69 |
+
|
70 |
+
# test questions
|
71 |
+
with open("test_response.json", "r") as f:
|
72 |
+
test_questions = json.load(f)
|
73 |
+
|
74 |
+
# --- Routes ---
|
75 |
+
|
76 |
+
@app.get("/", response_class=HTMLResponse)
|
77 |
+
def index(request: Request):
|
78 |
+
return templates.TemplateResponse(
|
79 |
+
"index.html",
|
80 |
+
{"request": request}
|
81 |
+
)
|
82 |
+
|
83 |
+
@app.post("/api/ask", response_class=JSONResponse)
|
84 |
+
async def ask(question: Question):
|
85 |
+
question_str = question.question
|
86 |
+
if question_str in test_questions:
|
87 |
+
return test_questions[question_str]
|
88 |
+
|
89 |
+
answer = None
|
90 |
+
try:
|
91 |
+
answer: Answer = await question_handler(question, state)
|
92 |
+
except Exception as e:
|
93 |
+
logger.exception(e)
|
94 |
+
if not answer:
|
95 |
+
return JSONResponse(status_code=404, content={"answer": "Something went wrong."})
|
96 |
+
return answer.dict(include={"answer", "urls"})
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
uvicorn.run("kth_qa:app", host="localhost", port=5001, reload=True, reload_excludes=['files/', 'logs/'], reload_dirs=['/templates', '/static'])
|
kth_qa/schema.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class Question(BaseModel):
|
7 |
+
question: str
|
8 |
+
|
9 |
+
class Answer(BaseModel):
|
10 |
+
answer: str
|
11 |
+
urls: List[str]
|
kth_qa/static/styles.css
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@import "https://fonts.googleapis.com/css?family=Poppins:300,400,500,600,700";
|
2 |
+
body {
|
3 |
+
font-family: 'Poppins', sans-serif;
|
4 |
+
background: #fafafa;
|
5 |
+
display: flex;
|
6 |
+
flex-direction: column;
|
7 |
+
align-items: center;
|
8 |
+
margin-left: 20%;
|
9 |
+
margin-right: 20%;
|
10 |
+
}
|
11 |
+
|
12 |
+
* {
|
13 |
+
font-family: 'Poppins', sans-serif;
|
14 |
+
}
|
15 |
+
|
16 |
+
.container {
|
17 |
+
display: flex;
|
18 |
+
flex-direction: column;
|
19 |
+
align-items: center;
|
20 |
+
justify-content: center;
|
21 |
+
justify-items: center;
|
22 |
+
margin-top: 5%;
|
23 |
+
margin-bottom: 5%;
|
24 |
+
width: 100%;
|
25 |
+
}
|
26 |
+
|
27 |
+
h1 {
|
28 |
+
font-size: 3em;
|
29 |
+
text-align: center;
|
30 |
+
color: blue;
|
31 |
+
}
|
32 |
+
|
33 |
+
form {
|
34 |
+
display: flex;
|
35 |
+
flex-direction: row;
|
36 |
+
margin-top: auto;
|
37 |
+
font-family: 'Poppins', sans-serif;
|
38 |
+
}
|
39 |
+
|
40 |
+
button {
|
41 |
+
margin: 0.5em;
|
42 |
+
padding: 0.5em;
|
43 |
+
background-color: blue;
|
44 |
+
color: white;
|
45 |
+
border: none;
|
46 |
+
border-radius: 0.5em;
|
47 |
+
padding-left: 1em;
|
48 |
+
padding-right: 1em;
|
49 |
+
font-size: 1.1em;
|
50 |
+
font-weight: 300;
|
51 |
+
line-height: 1.7em;
|
52 |
+
transition: all 0.3s;
|
53 |
+
cursor: pointer;
|
54 |
+
align-self: self-end;
|
55 |
+
}
|
56 |
+
|
57 |
+
button:disabled {
|
58 |
+
background-color: #ccc;
|
59 |
+
color: #666;
|
60 |
+
cursor: not-allowed;
|
61 |
+
}
|
62 |
+
|
63 |
+
button:hover {
|
64 |
+
background-color: #ace;
|
65 |
+
color: #fff;
|
66 |
+
}
|
67 |
+
|
68 |
+
input {
|
69 |
+
margin: 0.5em;
|
70 |
+
padding: 0.5em;
|
71 |
+
width: 35em;
|
72 |
+
}
|
73 |
+
|
74 |
+
p {
|
75 |
+
font-size: 1.1em;
|
76 |
+
font-weight: 300;
|
77 |
+
line-height: 1.7em;
|
78 |
+
color: #999;
|
79 |
+
max-width: 30em;
|
80 |
+
}
|
81 |
+
|
82 |
+
a {
|
83 |
+
color: blue;
|
84 |
+
text-decoration: none;
|
85 |
+
transition: all 0.3s;
|
86 |
+
}
|
87 |
+
a:hover,
|
88 |
+
a:focus {
|
89 |
+
color: cornflowerblue;
|
90 |
+
text-decoration: none;
|
91 |
+
transition: all 0.3s;
|
92 |
+
}
|
93 |
+
|
94 |
+
#content {
|
95 |
+
width: 100%;
|
96 |
+
padding: 20px;
|
97 |
+
min-height: 100vh;
|
98 |
+
transition: all 0.3s;
|
99 |
+
position: absolute;
|
100 |
+
top: 0;
|
101 |
+
right: 0;
|
102 |
+
}
|
kth_qa/templates/index.html
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<html>
|
2 |
+
|
3 |
+
<head>
|
4 |
+
<title>KTH Q&A</title>
|
5 |
+
<link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
|
6 |
+
</head>
|
7 |
+
|
8 |
+
<body>
|
9 |
+
<div class="container">
|
10 |
+
<h1>KTH Q&A</h1>
|
11 |
+
<form name="form" method="post">
|
12 |
+
<input type="text" name="question" />
|
13 |
+
<button type="submit" id="ask">Ask</button>
|
14 |
+
</form>
|
15 |
+
<p id="answer"></p>
|
16 |
+
<p id="readmore"></p>
|
17 |
+
<ul id="urls"></ul>
|
18 |
+
</div>
|
19 |
+
<script>
|
20 |
+
const form = document.querySelector('form');
|
21 |
+
form.addEventListener('submit', async (event) => {
|
22 |
+
document.getElementById('ask').disabled = true;
|
23 |
+
event.preventDefault();
|
24 |
+
const formData = new FormData(form);
|
25 |
+
const question = formData.get('question');
|
26 |
+
const response = await fetch('/api/ask', {
|
27 |
+
method: 'POST',
|
28 |
+
body: JSON.stringify({ question }),
|
29 |
+
headers: {
|
30 |
+
'content-type': 'application/json'
|
31 |
+
}
|
32 |
+
});
|
33 |
+
const data = await response.json();
|
34 |
+
console.log(data);
|
35 |
+
document.querySelector('#answer').textContent = data.answer;
|
36 |
+
if (data.urls && data.urls.length > 0) {
|
37 |
+
document.getElementById('readmore').textContent = 'You might find related info at: ';
|
38 |
+
const urls = document.getElementById('urls');
|
39 |
+
urls.innerHTML = '';
|
40 |
+
data.urls.forEach(url => {
|
41 |
+
const li = document.createElement('li');
|
42 |
+
const a = document.createElement('a');
|
43 |
+
a.href = url;
|
44 |
+
a.textContent = url;
|
45 |
+
li.appendChild(a);
|
46 |
+
urls.appendChild(li);
|
47 |
+
});
|
48 |
+
} else {
|
49 |
+
document.getElementById('readmore').textContent = '';
|
50 |
+
document.getElementById('urls').innerHTML = '';
|
51 |
+
}
|
52 |
+
document.getElementById('ask').disabled = false;
|
53 |
+
});
|
54 |
+
</script>
|
55 |
+
{% if DEBUG %}
|
56 |
+
{{ hotreload.script(url_for('hot-reload')) | safe }}
|
57 |
+
{% endif %}
|
58 |
+
</body>
|
59 |
+
|
60 |
+
</html>
|
kth_qa/test_response.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"test": {
|
3 |
+
"answer": "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi ultrices nibh id molestie posuere. Aliquam mollis, nibh ut ultrices pretium, sem sem lacinia mi, consectetur interdum orci sapien et justo. In tempor sed eros quis tincidunt. Curabitur nunc ante, lobortis quis nunc eu, molestie consequat lacus. Pellentesque molestie interdum pellentesque. Phasellus non cursus risus. Suspendisse dictum tempor scelerisque. Vivamus et consectetur ante. Fusce bibendum augue mauris, sed scelerisque quam pharetra at. Vivamus scelerisque tristique elit eu commodo. Nam eu ante dui. Proin vestibulum quam id nisl commodo, quis malesuada erat sollicitudin. Sed lorem nisi, pellentesque eget pulvinar vel, faucibus a mi. Nulla quis velit porttitor, consequat lorem ut, egestas mi.",
|
4 |
+
"urls": ["https://kth.se"]
|
5 |
+
},
|
6 |
+
"no url": {
|
7 |
+
"answer": "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi ultrices nibh id molestie posuere. Aliquam mollis, nibh ut ultrices pretium, sem sem lacinia mi, consectetur interdum orci sapien et justo. In tempor sed eros quis tincidunt. Curabitur nunc ante, lobortis quis nunc eu, molestie consequat lacus. Pellentesque molestie interdum pellentesque. Phasellus non cursus risus. Suspendisse dictum tempor scelerisque. Vivamus et consectetur ante. Fusce bibendum augue mauris, sed scelerisque quam pharetra at. Vivamus scelerisque tristique elit eu commodo. Nam eu ante dui. Proin vestibulum quam id nisl commodo, quis malesuada erat sollicitudin. Sed lorem nisi, pellentesque eget pulvinar vel, faucibus a mi. Nulla quis velit porttitor, consequat lorem ut, egestas mi."
|
8 |
+
},
|
9 |
+
"empty url": {
|
10 |
+
"answer": "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi ultrices nibh id molestie posuere. Aliquam mollis, nibh ut ultrices pretium, sem sem lacinia mi, consectetur interdum orci sapien et justo. In tempor sed eros quis tincidunt. Curabitur nunc ante, lobortis quis nunc eu, molestie consequat lacus. Pellentesque molestie interdum pellentesque. Phasellus non cursus risus. Suspendisse dictum tempor scelerisque. Vivamus et consectetur ante. Fusce bibendum augue mauris, sed scelerisque quam pharetra at. Vivamus scelerisque tristique elit eu commodo. Nam eu ante dui. Proin vestibulum quam id nisl commodo, quis malesuada erat sollicitudin. Sed lorem nisi, pellentesque eget pulvinar vel, faucibus a mi. Nulla quis velit porttitor, consequat lorem ut, egestas mi.",
|
11 |
+
"urls": []
|
12 |
+
},
|
13 |
+
"many urls": {
|
14 |
+
"answer": "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi ultrices nibh id molestie posuere. Aliquam mollis, nibh ut ultrices pretium, sem sem lacinia mi, consectetur interdum orci sapien et justo. In tempor sed eros quis tincidunt. Curabitur nunc ante, lobortis quis nunc eu, molestie consequat lacus. Pellentesque molestie interdum pellentesque. Phasellus non cursus risus. Suspendisse dictum tempor scelerisque. Vivamus et consectetur ante. Fusce bibendum augue mauris, sed scelerisque quam pharetra at. Vivamus scelerisque tristique elit eu commodo. Nam eu ante dui. Proin vestibulum quam id nisl commodo, quis malesuada erat sollicitudin. Sed lorem nisi, pellentesque eget pulvinar vel, faucibus a mi. Nulla quis velit porttitor, consequat lorem ut, egestas mi.",
|
15 |
+
"url": "",
|
16 |
+
"urls": ["https://kth.se", "https://kth.se", "https://kth.se"]
|
17 |
+
}
|
18 |
+
}
|
kth_qa/utils.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
def touch_folder(folder):
|
5 |
+
if not os.path.exists(folder):
|
6 |
+
os.makedirs(folder)
|
7 |
+
|
8 |
+
def get_courses():
|
9 |
+
try:
|
10 |
+
with open('kth_qa/courses.json', 'r') as f:
|
11 |
+
data = json.load(f)
|
12 |
+
except FileNotFoundError:
|
13 |
+
try:
|
14 |
+
with open('courses.json', 'r') as f:
|
15 |
+
data = json.load(f)
|
16 |
+
except FileNotFoundError:
|
17 |
+
raise FileNotFoundError('courses.json not found')
|
18 |
+
courses = data.get('courses')
|
19 |
+
return courses
|
20 |
+
|
21 |
+
if __name__ == '__main__':
|
22 |
+
courses = get_courses()
|
23 |
+
print(len(courses))
|
24 |
+
new_courses = {}
|
25 |
+
for c in courses.keys():
|
26 |
+
if c[:2] in ['ME', 'DA', 'DM', 'DT', 'DH', 'MF', "EI"]:
|
27 |
+
new_courses[c] = courses[c]
|
28 |
+
|
29 |
+
print(len(new_courses))
|