erseux committed
Commit
8cb8290
1 Parent(s): a4a165d

huggingface init

README.md CHANGED
@@ -1,8 +1,9 @@
  ---
  title: KTH QA
  emoji: 🤖
- colorFrom: "#FFD700"
- colorTo: "#FF8C00"
- app_file: kth_qa/main.py
+ colorFrom: purple
+ colorTo: gray
+ sdk: docker
+ app_port: 7860
  pinned: true
  ---
kth_qa/config.py ADDED
@@ -0,0 +1,133 @@
+ import logging
+ logger = logging.getLogger()
+
+ import openai
+ from pydantic import BaseSettings
+
+ from langchain.chat_models import ChatOpenAI
+ from langchain.chains import RetrievalQAWithSourcesChain
+ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+ from langchain.chains import SequentialChain
+ from langchain.llms import OpenAI
+ from langchain.chains import LLMCheckerChain
+ from langchain.chains.query_constructor.base import AttributeInfo
+ from langchain.vectorstores import Pinecone
+ from langchain.embeddings.openai import OpenAIEmbeddings
+
+ from magic.prompts import PROMPT, EXAMPLE_PROMPT
+ from magic.vectordb import VectorIndex
+ from magic.self_query_retriever import SelfQueryRetriever
+
+ from kth_qa.utils import get_courses
+
+ import pinecone
+
+
+ class Settings(BaseSettings):
+     OPENAI_API_KEY: str = 'OPENAI_API_KEY'
+     OPENAI_CHAT_MODEL: str = 'gpt-3.5-turbo'
+     PINECONE_API_KEY: str = 'PINECONE_API_KEY'
+     PINECONE_INDEX_NAME: str = 'kth-qa'
+     PINECONE_ENV: str = 'us-west1-gcp-free'
+     class Config:
+         env_file = '.env'
+
+ def set_openai_key(key):
+     """Sets OpenAI key."""
+     openai.api_key = key
+
+ class State:
+     settings: Settings
+     store: Pinecone | VectorIndex
+     chain: RetrievalQAWithSourcesChain
+     courses: list
+
+     def __init__(self):
+         self.settings = Settings()
+
+         self.courses = get_courses()
+
+         # OPENAI
+         set_openai_key(self.settings.OPENAI_API_KEY)
+
+         # LOCAL VECTORSTORE
+         # self.store: VectorIndex = VectorIndex()
+
+         # PINECONE VECTORSTORE
+         embeddings = OpenAIEmbeddings()
+         pinecone.init(api_key=self.settings.PINECONE_API_KEY, environment=self.settings.PINECONE_ENV)
+         self.store: Pinecone = Pinecone.from_existing_index(self.settings.PINECONE_INDEX_NAME, embeddings, "text")
+         logger.info("Pinecone store initialized")
+
+         # CHAINS
+         doc_chain = self._load_doc_chain()
+         qa_chain = self._load_qa_chain(doc_chain, self_query=True)
+
+         # JUST QA
+         self.chain = qa_chain
+
+         # SEQ CHAIN with QA and CHECKER
+         # checker_chain = self._load_checker_chain()
+         # self.chain = self._load_seq_chain([qa_chain, checker_chain])
+
+     def _load_seq_chain(self, chains):
+         sequential_chain = SequentialChain(
+             chains=chains,
+             input_variables=["question"],
+             output_variables=["answer"],
+             verbose=True)
+         return sequential_chain
+
+     def _load_checker_chain(self):
+         llm = OpenAI(temperature=0)
+         checker_chain = LLMCheckerChain(llm=llm, verbose=True, input_key="answer", output_key="result")
+         return checker_chain
+
+     def _load_doc_chain(self):
+         doc_chain = load_qa_with_sources_chain(
+             ChatOpenAI(temperature=0, max_tokens=256, model=self.settings.OPENAI_CHAT_MODEL, request_timeout=120),
+             chain_type="stuff",
+             document_variable_name="context",
+             prompt=PROMPT,
+             document_prompt=EXAMPLE_PROMPT
+         )
+         return doc_chain
+
+     def _load_qa_chain(self, doc_chain, self_query=False):
+         """Load QA chain with retriever.
+         If self_query is True, the retriever will be a SelfQueryRetriever,
+         which extracts a metadata filter from the question and adds it to the vectorstore query.
+         """
+         if self_query:
+             metadata_field_info = [
+                 AttributeInfo(
+                     name="course",
+                     description="A course code for a course",
+                     type="string"
+                 )]
+             document_content_description = "Brief description of a course"
+             llm = OpenAI(temperature=0, model_name='text-davinci-002')
+             retriever = SelfQueryRetriever.from_llm(llm, self.store, document_content_description,
+                                                     metadata_field_info, verbose=True)
+             qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
+                                                    retriever=retriever,
+                                                    return_source_documents=False)
+         else:
+             qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
+                                                    retriever=self.store.as_retriever(),
+                                                    return_source_documents=False)
+         return qa_chain
+
+     def course_exists(self, course: str):
+         course = course.upper()
+         exists = course in self.courses
+         if exists:
+             logger.info(f'Course {course} exists')
+             return True
+         else:
+             logger.info(f'Course {course} does not exist')
+             return False
+
+ if __name__ == '__main__':
+     state = State()
+     print(state.settings)
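
Usage sketch (not part of this commit): assuming the .env file supplies real OPENAI_API_KEY and PINECONE_API_KEY values and the 'kth-qa' Pinecone index already exists, the chain built by State can be called directly, mirroring how magic/conversational.py invokes it:

    # Sketch only: requires a populated .env and an existing 'kth-qa' Pinecone index.
    from kth_qa.config import State

    state = State()  # reads .env, connects to Pinecone, builds the QA chain
    result = state.chain({"question": "How many credits is DD1315?"}, return_only_outputs=True)
    print(result["answer"])            # final answer text from the QA chain
    print(result.get("sources", ""))   # course codes the chain cites
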
kth_qa/courses.json ADDED
The diff for this file is too large to render. See raw diff
 
kth_qa/ingest.py ADDED
@@ -0,0 +1,70 @@
+ import logging
+
+ from utils import get_courses
+ logger = logging.getLogger()
+
+ import os
+ from langchain.docstore.document import Document
+ from langchain.text_splitter import NLTKTextSplitter
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.vectorstores import Chroma
+ from langchain.callbacks import get_openai_callback
+
+ PERSIST_DIR = 'db'
+ FILE_DIR = 'files'
+ KURS_URL = "https://www.kth.se/student/kurser/kurs/{course_code}?l={language}"
+ DEFAULT_LANGUAGE = "en"
+
+ def ingest():
+     # make sure pwd is kth_qa
+     with get_openai_callback() as cb:
+         pwd = os.getcwd()
+         if pwd.split('/')[-1] != 'kth_qa':
+             logger.error(f"pwd is not kth_qa, but {pwd}. Please run from kth_qa directory.")
+             return
+
+         CHUNK_SIZE = 1000
+
+         embedding = OpenAIEmbeddings(chunk_size=CHUNK_SIZE)
+
+         text_splitter = NLTKTextSplitter.from_tiktoken_encoder(
+             chunk_size=CHUNK_SIZE,
+             chunk_overlap=100,
+         )
+
+         file_folder_name = f'files/{DEFAULT_LANGUAGE}'
+         file_folder = os.listdir(file_folder_name)
+         all_langdocs = []
+         for file in file_folder:
+             raw_docs = []
+             with open(f'{file_folder_name}/{file}', 'r') as f:
+                 text = f.read()
+             filename = file.split('.')[0]
+             course_code, language = filename.split('?l=')
+             doc = Document(page_content=text, metadata={"source": course_code})
+             raw_docs.append(doc)
+             logger.debug(f"loaded file {file}")
+
+             langdocs = text_splitter.split_documents(raw_docs)
+             logger.debug(f"split documents into {len(langdocs)} chunks")
+             all_langdocs.extend(langdocs)
+
+         # add course title to page content in each document
+         logger.info(f"split all documents into {len(all_langdocs)} chunks")
+
+         logger.info("creating vector index in Chroma...")
+         vectordb = Chroma.from_documents(documents=all_langdocs,
+                                          embedding=embedding,
+                                          persist_directory=PERSIST_DIR)
+         logger.info("created vector index")
+         vectordb.persist()
+         logger.info("persisted vector index")
+         vectordb = None
+         logger.info("Done!")
+
+         logger.info(f"Total cost of openai api calls: {cb.total_cost}")
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     logger.setLevel(logging.INFO)
+     ingest()
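
A note on the expected input layout (an assumption inferred from the parsing above; no scraper is included in this commit): ingest() reads plain-text course descriptions from files/en/, and each filename must round-trip through split('.') and split('?l='), e.g. a hypothetical file named 'DD1315?l=en.txt':

    # Illustration of the filename parsing ingest() relies on; the .txt extension is an assumption.
    filename = "DD1315?l=en.txt".split('.')[0]       # -> 'DD1315?l=en'
    course_code, language = filename.split('?l=')    # -> ('DD1315', 'en')
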
kth_qa/ingest_pinecone.py ADDED
@@ -0,0 +1,59 @@
+ import logging
+ logger = logging.getLogger()
+
+ import os
+ from langchain.docstore.document import Document
+ from langchain.text_splitter import NLTKTextSplitter
+ from langchain.callbacks import get_openai_callback
+
+ from kth_qa.config import State
+
+ FILE_DIR = 'files'
+ KURS_URL = "https://www.kth.se/student/kurser/kurs/{course_code}?l={language}"
+ DEFAULT_LANGUAGE = "en"
+ CHUNK_SIZE = 1000
+
+ def ingest(state: State):
+     with get_openai_callback() as cb:
+         # make sure pwd is kth_qa
+         pwd = os.getcwd()
+         if pwd.split('/')[-1] != 'kth_qa':
+             logger.error(f"pwd is not kth_qa, but {pwd}. Please run from kth_qa directory.")
+             return
+
+         text_splitter = NLTKTextSplitter.from_tiktoken_encoder(
+             chunk_size=CHUNK_SIZE,
+             chunk_overlap=100,
+         )
+
+         file_folder_name = f'files/{DEFAULT_LANGUAGE}'
+         file_folder = os.listdir(file_folder_name)
+         all_langdocs = []
+         for file in file_folder:
+             raw_docs = []
+             with open(f'{file_folder_name}/{file}', 'r') as f:
+                 text = f.read()
+             filename = file.split('.')[0]
+             course_code, language = filename.split('?l=')
+             doc = Document(page_content=text, metadata={"course": course_code})
+             raw_docs.append(doc)
+             logger.debug(f"loaded file {file}")
+
+             langdocs = text_splitter.split_documents(raw_docs)
+             logger.debug(f"split documents into {len(langdocs)} chunks")
+             all_langdocs.extend(langdocs)
+
+         logger.info(f"split all documents into {len(all_langdocs)} chunks")
+
+         logger.info("Adding documents to pinecone...")
+         state.store.add_documents(all_langdocs)
+         logger.info("...done!")
+
+         logger.info(f"Total cost of openai api calls: {cb.total_cost}")
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     logger.setLevel(logging.INFO)
+
+     state = State()
+     ingest(state)
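
The metadata key differs from ingest.py on purpose: documents are stored with a "course" field, which is what the SelfQueryRetriever in config.py filters on. A rough, hand-written equivalent of that filtered search (sketch only, assuming a State instance as in config.py and that langchain's Pinecone wrapper passes the filter dict through to Pinecone):

    from kth_qa.config import State

    state = State()  # requires real API keys, as in config.py
    docs = state.store.similarity_search(
        "What are the prerequisites?",          # free-text part of the query
        k=4,
        filter={"course": {"$eq": "DD1315"}},   # Pinecone metadata filter on the "course" field
    )
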
kth_qa/magic/conversational.py ADDED
@@ -0,0 +1,85 @@
+ import asyncio
+ import re
+ import logging
+
+ from kth_qa.schema import Answer, Question
+ logger = logging.getLogger()
+ from ingest import KURS_URL, DEFAULT_LANGUAGE
+ from langchain.callbacks import get_openai_callback
+
+ from config import State
+
+ COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?"  # e.g. DD1315
+
+ def blocking_chain(chain, request):
+     return chain(request, return_only_outputs=True)
+
+ async def question_handler(question: Question, state: State) -> Answer:
+     question = question.question
+     logger.info(f"Q: {question}")
+
+     cost = 0
+     with get_openai_callback() as cb:
+         result = await asyncio.to_thread(blocking_chain, state.chain, {"question": question})
+         cost = cb.total_cost
+     logger.debug(f"result: {result}")
+
+     answer = result['answer']
+     logger.info(f"A: {answer}")
+
+     if answer.startswith("I cannot help"):
+         answer = "I'm sorry, " + answer
+         return Answer(answer=answer, urls=[])
+
+     sources = result.get('sources')
+     logger.info(f"Sources: {sources}")
+     if sources:
+         sources = re.findall(COURSE_PATTERN, sources)
+     else:
+         answer, sources = split_sources(answer)
+
+     courses = [source.upper() for source in sources if state.course_exists(source)]  # filter out courses that don't exist
+     courses = set(courses)
+     logger.info(f"unique courses: {courses}")
+
+     urls = [KURS_URL.format(course_code=course, language=DEFAULT_LANGUAGE) for course in courses]  # format into urls
+     logger.info(f"urls: {urls}")
+
+     answer = answer.strip().removesuffix("(").strip()
+
+     if (not answer or len(answer) < 3) and urls:
+         answer = "Something went wrong, but I found a link."
+
+     logging.info(f"Cost of query: ${'{0:.2g}'.format(cost)}")
+
+     return Answer(answer=answer, urls=urls if urls else [])
+
+ def split_sources(answer: str):
+     patterns = [
+         "Sources",
+         "Source",
+         "References",
+         "Reference",
+         "sources",
+         "source",
+         "SOURCE"
+     ]
+     for pattern in patterns:
+         if pattern in answer:
+             all_answers = answer.split(pattern)
+             if len(all_answers) == 2:
+                 ans, sources = all_answers
+                 courses = re.findall(COURSE_PATTERN, sources)
+             elif len(all_answers) > 2:
+                 ans = ""
+                 courses = []
+                 for i, a in enumerate(all_answers):
+                     if i % 2 == 0:
+                         ans += a
+                     else:
+                         # accumulate course codes from every source segment
+                         courses.extend(re.findall(COURSE_PATTERN, a))
+             return ans, courses
+
+     return answer, []
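
Illustrative behaviour of the fallback path (sketch, with a made-up answer string): when the chain folds the sources into the answer text instead of returning a separate 'sources' field, split_sources() peels them off again and COURSE_PATTERN pulls out the course codes:

    answer = "DD1315 covers basic programming in Python. SOURCES: DD1315"
    text, courses = split_sources(answer)
    # text    -> "DD1315 covers basic programming in Python. "
    # courses -> ["DD1315"]
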
kth_qa/magic/prompts.py ADDED
@@ -0,0 +1,22 @@
+
+ from langchain import PromptTemplate
+
+
+ EXAMPLE_PROMPT = PromptTemplate(
+     template=">Course Description\n{page_content}\n----------\nSource: {course}",
+     input_variables=["page_content", "course"],
+ )
+ template = """
+ You are a study counselor for KTH.
+ Given the following extracted parts of course descriptions and a question, create a short final answer with references ("SOURCES").
+ If you don't know the answer, just say that you don't know. Don't try to make up an answer.
+ Just answer the question and don't add any extra information.
+ Only return sources that contain the answer to the question.
+ ALWAYS return a "SOURCES" part in your answer.
+
+ QUESTION: {question}
+ =========
+ {context}
+ =========
+ FINAL ANSWER:"""
+ PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
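
How the two templates compose at runtime (sketch): load_qa_with_sources_chain in config.py renders each retrieved chunk with EXAMPLE_PROMPT, joins the results, and passes them as {context} to PROMPT:

    doc_text = EXAMPLE_PROMPT.format(
        page_content="Programming Techniques, 7.5 credits ...",  # made-up course snippet
        course="DD1315",
    )
    print(PROMPT.format(question="How many credits is DD1315?", context=doc_text))
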
kth_qa/magic/self_query_retriever.py ADDED
@@ -0,0 +1,128 @@
+ """Retriever that generates and executes structured queries over its own data source.
+
+ This code is adapted from the original implementation in the LangChain repo,
+ but has been modified to work with the KTH QA system.
+
+ """
+
+ import re
+ from typing import Any, Dict, List, Optional, Type, cast
+
+ from pydantic import BaseModel, Field, root_validator
+
+ from langchain import LLMChain
+ from langchain.base_language import BaseLanguageModel
+ from langchain.chains.query_constructor.base import load_query_constructor_chain
+ from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
+ from langchain.chains.query_constructor.schema import AttributeInfo
+ from langchain.retrievers.self_query.pinecone import PineconeTranslator
+ from langchain.schema import BaseRetriever, Document
+ from langchain.vectorstores import Pinecone, VectorStore
+
+ COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?"  # e.g. DD1315
+
+
+ def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
+     """Get the translator class corresponding to the vector store class."""
+     BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
+         Pinecone: PineconeTranslator
+     }
+     if vectorstore_cls not in BUILTIN_TRANSLATORS:
+         raise ValueError(
+             f"Self query retriever with Vector Store type {vectorstore_cls}"
+             f" not supported."
+         )
+     return BUILTIN_TRANSLATORS[vectorstore_cls]()
+
+
+ class SelfQueryRetriever(BaseRetriever, BaseModel):
+     """Retriever that wraps around a vector store and uses an LLM to generate
+     the vector store queries."""
+
+     vectorstore: VectorStore
+     """The underlying vector store from which documents will be retrieved."""
+     llm_chain: LLMChain
+     """The LLMChain for generating the vector store queries."""
+     search_type: str = "similarity"
+     """The search type to perform on the vector store."""
+     search_kwargs: dict = Field(default_factory=dict)
+     """Keyword arguments to pass in to the vector store search."""
+     structured_query_translator: Visitor
+     """Translator for turning internal query language into vectorstore search params."""
+     verbose: bool = False
+
+     class Config:
+         """Configuration for this pydantic object."""
+
+         arbitrary_types_allowed = True
+
+     @root_validator(pre=True)
+     def validate_translator(cls, values: Dict) -> Dict:
+         """Validate translator."""
+         if "structured_query_translator" not in values:
+             vectorstore_cls = values["vectorstore"].__class__
+             values["structured_query_translator"] = _get_builtin_translator(
+                 vectorstore_cls
+             )
+         return values
+
+     def get_relevant_documents(self, query: str) -> List[Document]:
+         """Get documents relevant for a query.
+
+         Args:
+             query: string to find relevant documents for
+
+         Returns:
+             List of relevant documents
+         """
+         if re.findall(COURSE_PATTERN, query):
+             inputs = self.llm_chain.prep_inputs(query)
+             structured_query = cast(
+                 StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs)
+             )
+             if self.verbose:
+                 print("Found course pattern in query, using structured query:")
+                 print(structured_query)
+             new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
+                 structured_query
+             )
+             search_kwargs = {**self.search_kwargs, **new_kwargs}
+         else:
+             search_kwargs = self.search_kwargs
+         docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
+         return docs
+
+     async def aget_relevant_documents(self, query: str) -> List[Document]:
+         raise NotImplementedError
+
+     @classmethod
+     def from_llm(
+         cls,
+         llm: BaseLanguageModel,
+         vectorstore: VectorStore,
+         document_contents: str,
+         metadata_field_info: List[AttributeInfo],
+         structured_query_translator: Optional[Visitor] = None,
+         chain_kwargs: Optional[Dict] = None,
+         **kwargs: Any,
+     ) -> "SelfQueryRetriever":
+         if structured_query_translator is None:
+             structured_query_translator = _get_builtin_translator(vectorstore.__class__)
+         chain_kwargs = chain_kwargs or {}
+         if "allowed_comparators" not in chain_kwargs:
+             chain_kwargs[
+                 "allowed_comparators"
+             ] = structured_query_translator.allowed_comparators
+         if "allowed_operators" not in chain_kwargs:
+             chain_kwargs[
+                 "allowed_operators"
+             ] = structured_query_translator.allowed_operators
+         llm_chain = load_query_constructor_chain(
+             llm, document_contents, metadata_field_info, **chain_kwargs
+         )
+         return cls(
+             llm_chain=llm_chain,
+             vectorstore=vectorstore,
+             structured_query_translator=structured_query_translator,
+             **kwargs,
+         )
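
The main local modification is the COURSE_PATTERN gate in get_relevant_documents(): the LLM query-constructor only runs when the question actually contains a course code, otherwise the retriever falls back to a plain similarity search. A quick check of that gate (illustration only):

    import re

    COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?"
    bool(re.findall(COURSE_PATTERN, "What is DD1315 about?"))        # True  -> structured query with metadata filter
    bool(re.findall(COURSE_PATTERN, "Which courses teach Python?"))  # False -> plain similarity search
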
kth_qa/magic/vectordb.py ADDED
@@ -0,0 +1,13 @@
+ import logging
+ logger = logging.getLogger()
+ import os
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.vectorstores import Chroma
+ from ingest import PERSIST_DIR
+
+ embedding = OpenAIEmbeddings()
+ class VectorIndex(Chroma):
+     def __init__(self):
+         if len(os.listdir(PERSIST_DIR)) < 2:  # check if there are files in the directory
+             logger.error(f"VectorIndex: No files in {PERSIST_DIR}, have you run ingest.py?")
+         super().__init__(persist_directory=PERSIST_DIR, embedding_function=embedding)
kth_qa/main.py ADDED
@@ -0,0 +1,100 @@
+ import json
+ import logging
+
+ from magic.conversational import question_handler
+ from schema import Answer
+
+ logger = logging.getLogger()
+ logging.basicConfig(encoding='utf-8', level=logging.INFO)
+
+ from pathlib import Path
+
+ from fastapi import FastAPI, Request
+ from fastapi.responses import HTMLResponse, JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.templating import Jinja2Templates
+ from fastapi.staticfiles import StaticFiles
+ from starlette.routing import WebSocketRoute
+ import uvicorn
+
+ from schema import Question
+ from config import State
+ import arel
+
+ # --- Setup ---
+
+ # hot reload
+ async def reload_data():
+     print("Reloading server data...")
+
+ BASE_PATH = Path(__file__).resolve().parent
+ static_path = str(BASE_PATH / "static")
+ template_path = str(BASE_PATH / "templates")
+
+ hotreload = arel.HotReload(
+     paths=[
+         arel.Path(static_path),
+         arel.Path(template_path),
+     ],
+ )
+
+ state = State()
+
+ app = FastAPI(
+     routes=[WebSocketRoute("/hot-reload", hotreload, name="hot-reload")],
+     on_startup=[hotreload.startup],
+     on_shutdown=[hotreload.shutdown],
+ )
+
+ # templates
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+ BASE_PATH = Path(__file__).resolve().parent
+ templates = Jinja2Templates(directory=template_path)
+ templates.env.globals["DEBUG"] = True
+ templates.env.globals["hotreload"] = hotreload
+
+ # CORS
+ origins = [
+     "http://localhost",
+     "http://localhost:5001",
+ ]
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=origins,
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # test questions
+ with open("test_response.json", "r") as f:
+     test_questions = json.load(f)
+
+ # --- Routes ---
+
+ @app.get("/", response_class=HTMLResponse)
+ def index(request: Request):
+     return templates.TemplateResponse(
+         "index.html",
+         {"request": request}
+     )
+
+ @app.post("/api/ask", response_class=JSONResponse)
+ async def ask(question: Question):
+     question_str = question.question
+     if question_str in test_questions:
+         return test_questions[question_str]
+
+     answer = None
+     try:
+         answer: Answer = await question_handler(question, state)
+     except Exception as e:
+         logger.exception(e)
+     if not answer:
+         return JSONResponse(status_code=404, content={"answer": "Something went wrong."})
+     return answer.dict(include={"answer", "urls"})
+
+
+ if __name__ == "__main__":
+     uvicorn.run("kth_qa:app", host="localhost", port=5001, reload=True, reload_excludes=['files/', 'logs/'], reload_dirs=['/templates', '/static'])
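
Manual smoke test of the /api/ask endpoint (sketch; assumes the server is running locally on port 5001 as in the uvicorn call above, and that the requests package is available even though it is not declared in this commit). The question "test" is answered straight from test_response.json without calling the chain:

    import requests  # assumption: not a dependency added by this commit

    resp = requests.post("http://localhost:5001/api/ask", json={"question": "test"})
    print(resp.json())  # {"answer": "Lorem ipsum ...", "urls": ["https://kth.se"]}
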
kth_qa/schema.py ADDED
@@ -0,0 +1,11 @@
+
+ from typing import List
+ from pydantic import BaseModel
+
+
+ class Question(BaseModel):
+     question: str
+
+ class Answer(BaseModel):
+     answer: str
+     urls: List[str]
kth_qa/static/styles.css ADDED
@@ -0,0 +1,102 @@
+ @import "https://fonts.googleapis.com/css?family=Poppins:300,400,500,600,700";
+ body {
+   font-family: 'Poppins', sans-serif;
+   background: #fafafa;
+   display: flex;
+   flex-direction: column;
+   align-items: center;
+   margin-left: 20%;
+   margin-right: 20%;
+ }
+
+ * {
+   font-family: 'Poppins', sans-serif;
+ }
+
+ .container {
+   display: flex;
+   flex-direction: column;
+   align-items: center;
+   justify-content: center;
+   justify-items: center;
+   margin-top: 5%;
+   margin-bottom: 5%;
+   width: 100%;
+ }
+
+ h1 {
+   font-size: 3em;
+   text-align: center;
+   color: blue;
+ }
+
+ form {
+   display: flex;
+   flex-direction: row;
+   margin-top: auto;
+   font-family: 'Poppins', sans-serif;
+ }
+
+ button {
+   margin: 0.5em;
+   padding: 0.5em;
+   background-color: blue;
+   color: white;
+   border: none;
+   border-radius: 0.5em;
+   padding-left: 1em;
+   padding-right: 1em;
+   font-size: 1.1em;
+   font-weight: 300;
+   line-height: 1.7em;
+   transition: all 0.3s;
+   cursor: pointer;
+   align-self: self-end;
+ }
+
+ button:disabled {
+   background-color: #ccc;
+   color: #666;
+   cursor: not-allowed;
+ }
+
+ button:hover {
+   background-color: #ace;
+   color: #fff;
+ }
+
+ input {
+   margin: 0.5em;
+   padding: 0.5em;
+   width: 35em;
+ }
+
+ p {
+   font-size: 1.1em;
+   font-weight: 300;
+   line-height: 1.7em;
+   color: #999;
+   max-width: 30em;
+ }
+
+ a {
+   color: blue;
+   text-decoration: none;
+   transition: all 0.3s;
+ }
+ a:hover,
+ a:focus {
+   color: cornflowerblue;
+   text-decoration: none;
+   transition: all 0.3s;
+ }
+
+ #content {
+   width: 100%;
+   padding: 20px;
+   min-height: 100vh;
+   transition: all 0.3s;
+   position: absolute;
+   top: 0;
+   right: 0;
+ }
kth_qa/templates/index.html ADDED
@@ -0,0 +1,60 @@
+ <html>
+
+ <head>
+   <title>KTH Q&A</title>
+   <link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
+ </head>
+
+ <body>
+   <div class="container">
+     <h1>KTH Q&A</h1>
+     <form name="form" method="post">
+       <input type="text" name="question" />
+       <button type="submit" id="ask">Ask</button>
+     </form>
+     <p id="answer"></p>
+     <p id="readmore"></p>
+     <ul id="urls"></ul>
+   </div>
+   <script>
+     const form = document.querySelector('form');
+     form.addEventListener('submit', async (event) => {
+       document.getElementById('ask').disabled = true;
+       event.preventDefault();
+       const formData = new FormData(form);
+       const question = formData.get('question');
+       const response = await fetch('/api/ask', {
+         method: 'POST',
+         body: JSON.stringify({ question }),
+         headers: {
+           'content-type': 'application/json'
+         }
+       });
+       const data = await response.json();
+       console.log(data);
+       document.querySelector('#answer').textContent = data.answer;
+       if (data.urls && data.urls.length > 0) {
+         document.getElementById('readmore').textContent = 'You might find related info at: ';
+         const urls = document.getElementById('urls');
+         urls.innerHTML = '';
+         data.urls.forEach(url => {
+           const li = document.createElement('li');
+           const a = document.createElement('a');
+           a.href = url;
+           a.textContent = url;
+           li.appendChild(a);
+           urls.appendChild(li);
+         });
+       } else {
+         document.getElementById('readmore').textContent = '';
+         document.getElementById('urls').innerHTML = '';
+       }
+       document.getElementById('ask').disabled = false;
+     });
+   </script>
+   {% if DEBUG %}
+   {{ hotreload.script(url_for('hot-reload')) | safe }}
+   {% endif %}
+ </body>
+
+ </html>
kth_qa/test_response.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "test": {
+         "answer": "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi ultrices nibh id molestie posuere. Aliquam mollis, nibh ut ultrices pretium, sem sem lacinia mi, consectetur interdum orci sapien et justo. In tempor sed eros quis tincidunt. Curabitur nunc ante, lobortis quis nunc eu, molestie consequat lacus. Pellentesque molestie interdum pellentesque. Phasellus non cursus risus. Suspendisse dictum tempor scelerisque. Vivamus et consectetur ante. Fusce bibendum augue mauris, sed scelerisque quam pharetra at. Vivamus scelerisque tristique elit eu commodo. Nam eu ante dui. Proin vestibulum quam id nisl commodo, quis malesuada erat sollicitudin. Sed lorem nisi, pellentesque eget pulvinar vel, faucibus a mi. Nulla quis velit porttitor, consequat lorem ut, egestas mi.",
+         "urls": ["https://kth.se"]
+     },
+     "no url": {
+         "answer": "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi ultrices nibh id molestie posuere. Aliquam mollis, nibh ut ultrices pretium, sem sem lacinia mi, consectetur interdum orci sapien et justo. In tempor sed eros quis tincidunt. Curabitur nunc ante, lobortis quis nunc eu, molestie consequat lacus. Pellentesque molestie interdum pellentesque. Phasellus non cursus risus. Suspendisse dictum tempor scelerisque. Vivamus et consectetur ante. Fusce bibendum augue mauris, sed scelerisque quam pharetra at. Vivamus scelerisque tristique elit eu commodo. Nam eu ante dui. Proin vestibulum quam id nisl commodo, quis malesuada erat sollicitudin. Sed lorem nisi, pellentesque eget pulvinar vel, faucibus a mi. Nulla quis velit porttitor, consequat lorem ut, egestas mi."
+     },
+     "empty url": {
+         "answer": "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi ultrices nibh id molestie posuere. Aliquam mollis, nibh ut ultrices pretium, sem sem lacinia mi, consectetur interdum orci sapien et justo. In tempor sed eros quis tincidunt. Curabitur nunc ante, lobortis quis nunc eu, molestie consequat lacus. Pellentesque molestie interdum pellentesque. Phasellus non cursus risus. Suspendisse dictum tempor scelerisque. Vivamus et consectetur ante. Fusce bibendum augue mauris, sed scelerisque quam pharetra at. Vivamus scelerisque tristique elit eu commodo. Nam eu ante dui. Proin vestibulum quam id nisl commodo, quis malesuada erat sollicitudin. Sed lorem nisi, pellentesque eget pulvinar vel, faucibus a mi. Nulla quis velit porttitor, consequat lorem ut, egestas mi.",
+         "urls": []
+     },
+     "many urls": {
+         "answer": "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi ultrices nibh id molestie posuere. Aliquam mollis, nibh ut ultrices pretium, sem sem lacinia mi, consectetur interdum orci sapien et justo. In tempor sed eros quis tincidunt. Curabitur nunc ante, lobortis quis nunc eu, molestie consequat lacus. Pellentesque molestie interdum pellentesque. Phasellus non cursus risus. Suspendisse dictum tempor scelerisque. Vivamus et consectetur ante. Fusce bibendum augue mauris, sed scelerisque quam pharetra at. Vivamus scelerisque tristique elit eu commodo. Nam eu ante dui. Proin vestibulum quam id nisl commodo, quis malesuada erat sollicitudin. Sed lorem nisi, pellentesque eget pulvinar vel, faucibus a mi. Nulla quis velit porttitor, consequat lorem ut, egestas mi.",
+         "url": "",
+         "urls": ["https://kth.se", "https://kth.se", "https://kth.se"]
+     }
+ }
kth_qa/utils.py ADDED
@@ -0,0 +1,29 @@
+ import json
+ import os
+
+ def touch_folder(folder):
+     if not os.path.exists(folder):
+         os.makedirs(folder)
+
+ def get_courses():
+     try:
+         with open('kth_qa/courses.json', 'r') as f:
+             data = json.load(f)
+     except FileNotFoundError:
+         try:
+             with open('courses.json', 'r') as f:
+                 data = json.load(f)
+         except FileNotFoundError:
+             raise FileNotFoundError('courses.json not found')
+     courses = data.get('courses')
+     return courses
+
+ if __name__ == '__main__':
+     courses = get_courses()
+     print(len(courses))
+     new_courses = {}
+     for c in courses.keys():
+         if c[:2] in ['ME', 'DA', 'DM', 'DT', 'DH', 'MF', "EI"]:
+             new_courses[c] = courses[c]
+
+     print(len(new_courses))
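
courses.json itself is too large to render in this diff, but get_courses() and State.course_exists() imply a top-level "courses" object keyed by course code, roughly (hypothetical contents):

    data = {"courses": {"DD1315": {"title": "Programming Techniques"}}}  # hypothetical entry; real fields unknown
    courses = data.get('courses')
    print("DD1315" in courses)  # the membership test State.course_exists relies on
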