Spaces:
Runtime error
Runtime error
File size: 9,398 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 |
from typing import Any, Dict, List, Optional, Tuple, Union
class NeptuneQueryException(Exception):
    """Raised when a Neptune query or API call fails to execute.

    Args:
        exception: Either a plain error message, or a dict that may carry
            ``"message"`` and ``"details"`` entries. Entries missing from
            the dict, and the details of a plain-string error, are reported
            as ``"unknown"``.
    """

    def __init__(self, exception: Union[str, Dict]):
        # Forward the raw payload to the base class so ``args`` is set
        # explicitly rather than relying on ``BaseException.__new__``.
        super().__init__(exception)
        if isinstance(exception, dict):
            self.message = exception.get("message", "unknown")
            self.details = exception.get("details", "unknown")
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        """Return the error message."""
        return self.message

    def get_details(self) -> Any:
        """Return the error details, or ``"unknown"`` when none were given."""
        return self.details
class NeptuneGraph:
    """Neptune wrapper for graph operations.

    Args:
        host: endpoint for the database instance
        port: port number for the database instance, default is 8182
        use_https: whether to use secure connection, default is True
        client: optional boto3 Neptune client
        credentials_profile_name: optional AWS profile name
        region_name: optional AWS region, e.g., us-west-2
        service: optional service name, default is neptunedata
        sign: optional, whether to sign the request payload, default is True

    Example:
        .. code-block:: python

            graph = NeptuneGraph(
                host='<my-cluster>',
                port=8182
            )

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        host: str,
        port: int = 8182,
        use_https: bool = True,
        client: Any = None,
        credentials_profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        service: str = "neptunedata",
        sign: bool = True,
    ) -> None:
        """Create a new Neptune graph wrapper instance.

        Uses ``client`` as-is when one is supplied; otherwise builds a boto3
        client for ``service`` pointing at ``host:port``. The graph schema is
        loaded immediately afterwards.

        Raises:
            ModuleNotFoundError: if boto3/botocore cannot be imported, or the
                installed boto3 does not know the requested service.
            ValueError: if the client cannot be created for any other reason.
            NeptuneQueryException: if the schema cannot be retrieved.
        """
        try:
            if client is not None:
                self.client = client
            else:
                import boto3

                if credentials_profile_name is not None:
                    session = boto3.Session(profile_name=credentials_profile_name)
                else:
                    # use default credentials
                    session = boto3.Session()

                client_params = {}
                if region_name:
                    client_params["region_name"] = region_name

                protocol = "https" if use_https else "http"
                client_params["endpoint_url"] = f"{protocol}://{host}:{port}"

                if sign:
                    self.client = session.client(service, **client_params)
                else:
                    from botocore import UNSIGNED
                    from botocore.config import Config

                    self.client = session.client(
                        service,
                        **client_params,
                        config=Config(signature_version=UNSIGNED),
                    )
        except ImportError as e:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        except Exception as e:
            # Matched by name so botocore need not be imported just to
            # reference the exception class.
            if type(e).__name__ == "UnknownServiceError":
                raise ModuleNotFoundError(
                    "NeptuneGraph requires a boto3 version 1.28.38 or greater. "
                    "Please install it with `pip install -U boto3`."
                ) from e
            else:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        try:
            self._refresh_schema()
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "Could not get schema for Neptune database",
                    # Key must be "details": that is what the exception reads.
                    "details": str(e),
                }
            ) from e

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Neptune database"""
        return self.schema

    def query(self, query: str, params: Optional[dict] = None) -> Dict[str, Any]:
        """Query Neptune database.

        Args:
            query: openCypher query text.
            params: optional query parameters. When non-empty they are sent
                JSON-encoded as the ``parameters`` argument of the client
                call; when empty or omitted, only the query is sent.

        Returns:
            The response returned by the client's
            ``execute_open_cypher_query`` call.

        Raises:
            NeptuneQueryException: if encoding the parameters or executing
                the query fails.
        """
        try:
            request: Dict[str, Any] = {"openCypherQuery": query}
            if params:
                import json

                request["parameters"] = json.dumps(params)
            return self.client.execute_open_cypher_query(**request)
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "An error occurred while executing the query.",
                    "details": str(e),
                }
            ) from e

    @staticmethod
    def _response_details(response: Any) -> str:
        """Render an unexpected summary response as text for error details."""
        content = getattr(response, "content", None)
        if isinstance(content, (bytes, bytearray)):
            return content.decode()
        # The summary response is indexed like a mapping by the caller, so it
        # normally has no ``content`` attribute; fall back to its string form.
        return str(response)

    def _get_summary(self) -> Dict:
        """Fetch the property-graph summary from the client.

        Returns:
            The ``payload.graphSummary`` entry of the summary response.

        Raises:
            NeptuneQueryException: if the summary call fails or the response
                does not contain ``payload.graphSummary``.
        """
        try:
            response = self.client.get_propertygraph_summary()
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": (
                        "Summary API is not available for this instance of Neptune, "
                        "ensure the engine version is >=1.2.1.0"
                    ),
                    "details": str(e),
                }
            ) from e

        try:
            return response["payload"]["graphSummary"]
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "Summary API did not return a valid response.",
                    "details": self._response_details(response),
                }
            ) from e

    def _get_labels(self) -> Tuple[List[str], List[str]]:
        """Get node and edge labels from the Neptune statistics summary"""
        summary = self._get_summary()
        n_labels = summary["nodeLabels"]
        e_labels = summary["edgeLabels"]
        return n_labels, e_labels

    def _get_triples(self, e_labels: List[str]) -> List[str]:
        """Build ``(:`A`)-[:`E`]->(:`B`)`` patterns for each edge label.

        Runs one sampling query per edge label and uses the first label of
        each endpoint node in the rendered pattern.
        """
        triple_query = (
            "\n"
            "MATCH (a)-[e:`{e_label}`]->(b)\n"
            "WITH a,e,b LIMIT 3000\n"
            "RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to\n"
            "LIMIT 10\n"
        )
        triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)"

        triple_schema = []
        for label in e_labels:
            data = self.query(triple_query.format(e_label=label))
            triple_schema.extend(
                triple_template.format(a=d["from"][0], e=d["edge"], b=d["to"][0])
                for d in data["results"]
            )
        return triple_schema

    def _collect_property_types(
        self, q: str, types: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Run ``q`` and list the distinct (property, type) pairs it returns.

        Each result row must expose a ``props`` mapping. Pairs are sorted so
        that the generated schema text is stable from run to run.
        """
        seen = set()
        for row in self.query(q)["results"]:
            for key, value in row["props"].items():
                seen.add((key, types[type(value).__name__]))
        ordered = sorted(seen, key=lambda pair: (pair[0], str(pair[1])))
        return [{"property": k, "type": v} for k, v in ordered]

    def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
        """Describe the properties seen on a sample of nodes for each label.

        Args:
            n_labels: node labels to inspect.
            types: mapping from Python type name to schema type name.

        Returns:
            One ``{"properties": [...], "labels": label}`` dict per label.
        """
        node_properties_query = (
            "\n"
            "MATCH (a:`{n_label}`)\n"
            "RETURN properties(a) AS props\n"
            "LIMIT 100\n"
        )
        return [
            {
                "properties": self._collect_property_types(
                    node_properties_query.format(n_label=label), types
                ),
                "labels": label,
            }
            for label in n_labels
        ]

    def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
        """Describe the properties seen on a sample of edges for each label.

        Args:
            e_labels: edge labels to inspect.
            types: mapping from Python type name to schema type name.

        Returns:
            One ``{"type": label, "properties": [...]}`` dict per label.
        """
        edge_properties_query = (
            "\n"
            "MATCH ()-[e:`{e_label}`]->()\n"
            "RETURN properties(e) AS props\n"
            "LIMIT 100\n"
        )
        return [
            {
                "type": label,
                "properties": self._collect_property_types(
                    edge_properties_query.format(e_label=label), types
                ),
            }
            for label in e_labels
        ]

    def _refresh_schema(self) -> None:
        """
        Refreshes the Neptune graph schema information.
        """
        # Python type name of a returned property value -> schema type name.
        types = {
            "str": "STRING",
            "float": "DOUBLE",
            "int": "INTEGER",
            "list": "LIST",
            "dict": "MAP",
            "bool": "BOOLEAN",
        }
        n_labels, e_labels = self._get_labels()
        triple_schema = self._get_triples(e_labels)
        node_properties = self._get_node_properties(n_labels, types)
        edge_properties = self._get_edge_properties(e_labels, types)

        self.schema = (
            "\n"
            "Node properties are the following:\n"
            f"{node_properties}\n"
            "Relationship properties are the following:\n"
            f"{edge_properties}\n"
            "The relationships are the following:\n"
            f"{triple_schema}\n"
        )
|