Spaces:
Runtime error
Runtime error
"""Weaviate reader.""" | |
from typing import Any, List, Optional | |
from gpt_index.readers.base import BaseReader | |
from gpt_index.readers.schema.base import Document | |
class WeaviateReader(BaseReader): | |
"""Weaviate reader. | |
Retrieves documents from Weaviate through vector lookup. Allows option | |
to concatenate retrieved documents into one Document, or to return | |
separate Document objects per document. | |
Args: | |
host (str): host. | |
auth_client_secret (Optional[weaviate.auth.AuthCredentials]): | |
auth_client_secret. | |
""" | |
def __init__( | |
self, | |
host: str, | |
auth_client_secret: Optional[Any] = None, | |
) -> None: | |
"""Initialize with parameters.""" | |
try: | |
import weaviate # noqa: F401 | |
from weaviate import Client # noqa: F401 | |
from weaviate.auth import AuthCredentials # noqa: F401 | |
except ImportError: | |
raise ImportError( | |
"`weaviate` package not found, please run `pip install weaviate-client`" | |
) | |
self.client: Client = Client(host, auth_client_secret=auth_client_secret) | |
def load_data( | |
self, | |
class_name: Optional[str] = None, | |
properties: Optional[List[str]] = None, | |
graphql_query: Optional[str] = None, | |
separate_documents: Optional[bool] = True, | |
) -> List[Document]: | |
"""Load data from Weaviate. | |
If `graphql_query` is not found in load_kwargs, we assume that | |
`class_name` and `properties` are provided. | |
Args: | |
class_name (Optional[str]): class_name to retrieve documents from. | |
properties (Optional[List[str]]): properties to retrieve from documents. | |
graphql_query (Optional[str]): Raw GraphQL Query. | |
We assume that the query is a Get query. | |
separate_documents (Optional[bool]): Whether to return separate | |
documents. Defaults to True. | |
Returns: | |
List[Document]: A list of documents. | |
""" | |
if class_name is not None and properties is not None: | |
props_txt = "\n".join(properties) | |
graphql_query = f""" | |
{{ | |
Get {{ | |
{class_name} {{ | |
{props_txt} | |
}} | |
}} | |
}} | |
""" | |
elif graphql_query is not None: | |
pass | |
else: | |
raise ValueError( | |
"Either `class_name` and `properties` must be specified, " | |
"or `graphql_query` must be specified." | |
) | |
response = self.client.query.raw(graphql_query) | |
if "errors" in response: | |
raise ValueError("Invalid query, got errors: {}".format(response["errors"])) | |
data_response = response["data"] | |
if "Get" not in data_response: | |
raise ValueError("Invalid query response, must be a Get query.") | |
if class_name is None: | |
# infer class_name if only graphql_query was provided | |
class_name = list(data_response["Get"].keys())[0] | |
entries = data_response["Get"][class_name] | |
documents = [] | |
for entry in entries: | |
embedding = None | |
# for each entry, join properties into <property>:<value> | |
# separated by newlines | |
text_list = [] | |
for k, v in entry.items(): | |
if k == "_additional": | |
if "vector" in v: | |
embedding = v["vector"] | |
continue | |
text_list.append(f"{k}: {v}") | |
text = "\n".join(text_list) | |
documents.append(Document(text=text, embedding=embedding)) | |
if not separate_documents: | |
# join all documents into one | |
text_list = [doc.get_text() for doc in documents] | |
text = "\n\n".join(text_list) | |
documents = [Document(text=text)] | |
return documents | |