# Spaces:
# Runtime error
# Runtime error
import json | |
import logging | |
from fastapi.encoders import jsonable_encoder | |
from langchain.chains import RetrievalQA | |
from langchain.document_loaders import DirectoryLoader, TextLoader | |
from langchain.text_splitter import CharacterTextSplitter | |
from qdrant_client.http.models import Distance, VectorParams | |
from edu_assistant.utils.common_utils import init_local_logging | |
from edu_assistant.utils.langchain_utils import load_llm, load_vectorstore | |
from edu_assistant.utils.qdrant_utils import load_qdrant_client | |
# Call the project's local-logging initialiser with DEBUG as its level
# argument; this runs at import time, before any function below is used.
init_local_logging(logging.DEBUG)
def create_collection(collection_name: str, vector_size: int = 1536):
    """(Re)create a Qdrant collection and print its description.

    Args:
        collection_name: Name of the collection to (re)create.
        vector_size: Dimensionality of the vectors the collection stores.
            Defaults to 1536, the value that used to be hard-coded here, so
            existing callers are unaffected.

    Raises:
        ValueError: If ``vector_size`` is not a positive number; raised
            before any Qdrant call is made.

    Note:
        ``recreate_collection`` is used, so a collection that already exists
        under ``collection_name`` is replaced rather than kept.
    """
    if vector_size <= 0:
        raise ValueError(f"vector_size must be positive, got {vector_size}")

    client = load_qdrant_client()
    client.recreate_collection(
        collection_name=collection_name,
        # Dot-product distance, as in the original configuration.
        vectors_config=VectorParams(size=vector_size, distance=Distance.DOT),
    )
    collection_info = client.get_collection(collection_name=collection_name)
    print(collection_info)
def delete_collection(collection_name: str):
    """Delete the Qdrant collection called ``collection_name``."""
    qdrant = load_qdrant_client()
    qdrant.delete_collection(collection_name=collection_name)
def add_docs(path: str, collection_name: str):
    """Load ``*.txt`` files from ``path``, split them, and index the chunks.

    Args:
        path: Directory handed to ``DirectoryLoader``; files matching the
            glob ``*.txt`` are read with ``TextLoader``.
        collection_name: Name of the collection whose vector store receives
            the chunks.

    The documents are split with ``CharacterTextSplitter`` (chunk size 1000,
    overlap 50) and then added to the vector store in a single
    ``add_documents`` call instead of one call per chunk.
    """
    vs = load_vectorstore(collection_name=collection_name)
    loader = DirectoryLoader(path=path, glob="*.txt", loader_cls=TextLoader)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
    docs = text_splitter.split_documents(documents)

    if not docs:
        # Nothing matched the glob (or every file was empty): make that
        # visible instead of silently doing nothing.
        logging.getLogger(__name__).warning(
            "No documents found under %s; nothing was added to %s",
            path,
            collection_name,
        )
        return

    # One batched call lets the vector store group embedding and upsert
    # requests, rather than paying a round-trip for every single chunk.
    vs.add_documents(docs)
def qa(collection_name: str, question: str):
    """Answer ``question`` from a collection and print the result as JSON.

    Args:
        collection_name: Name of the collection used to build the retriever.
        question: The question passed to the ``RetrievalQA`` chain.

    The chain is built with ``return_source_documents=True``; its result is
    converted with ``jsonable_encoder`` and printed as indented JSON with
    non-ASCII characters left unescaped.
    """
    llm = load_llm()
    # The number of retrieved documents has to be passed through
    # ``search_kwargs``; a bare ``k=1`` keyword is not a retriever field and
    # does not limit the search.
    retriever = load_vectorstore(collection_name=collection_name).as_retriever(
        search_kwargs={"k": 1},
    )
    chain = RetrievalQA.from_llm(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
    )
    result = chain(question)
    print(json.dumps(jsonable_encoder(result), ensure_ascii=False, indent=4))
if __name__ == "__main__":
    # Demo settings: collection name, source directory, and the query text.
    name, path, question = "example", "examples/docs", "C++有哪些数据类型修饰符?"

    # Disabled in the original; presumably run once to (re)build the
    # collection before querying it.
    # delete_collection(name)
    # create_collection(name)
    # add_docs(path, name)

    qa(name, question)