| from __future__ import annotations |
|
|
| from typing import TypedDict |
| import os |
| import folder_paths |
| import glob |
| from aiohttp import web |
| import hashlib |
|
|
|
|
class Source:
    """String constants naming where a subgraph entry came from.

    The value is stored in each entry's ``"source"`` field and is also mixed
    into the hash used as the entry id.
    """

    # Subgraph shipped inside a custom node pack's ``subgraphs`` folder.
    custom_node = "custom_node"
|
|
class _SubgraphEntryBase(TypedDict):
    """Keys present on every subgraph entry as soon as it is discovered."""

    source: str
    """
    Source of the subgraph, one of the ``Source`` constants (currently only
    ``"custom_node"``).
    """
    path: str
    """
    Filesystem path of the subgraph file (``/`` separators), e.g.
    ``<custom_nodes_dir>/<node_pack>/subgraphs/<name>.json``. Server-side only;
    it is stripped before entries are returned to clients.
    """
    name: str
    """
    Name of the subgraph file, without the ``.json`` extension.
    """
    info: CustomNodeSubgraphEntryInfo
    """
    Additional source-specific info; for custom nodes, contains the node pack name.
    """


class SubgraphEntry(_SubgraphEntryBase, total=False):
    """A discovered subgraph.

    ``data`` is optional because entries are created from a directory scan
    without reading the file; it is filled in the first time the subgraph's
    contents are requested.
    """

    data: str
    """
    Raw JSON text of the subgraph file, loaded lazily.
    """
|
|
class CustomNodeSubgraphEntryInfo(TypedDict):
    """Extra metadata attached to subgraph entries that come from custom nodes."""

    # Node pack name, written by the scanner as ``custom_nodes.<pack_dir>``.
    node_pack: str
|
|
class SubgraphManager:
    """Discover, cache, and serve subgraph JSON files shipped by custom node packs.

    Packs may ship subgraphs as ``<custom_nodes_dir>/<pack>/subgraphs/<name>.json``.
    The directory scan is cached in memory; a file's contents are read lazily
    the first time that subgraph is requested and then kept on the cached entry.
    """

    def __init__(self):
        # Entry id (sha256 of source + path) -> entry; None until the first scan.
        self.cached_custom_node_subgraphs: dict[str, SubgraphEntry] | None = None

    async def load_entry_data(self, entry: SubgraphEntry) -> SubgraphEntry:
        """Read the file at ``entry['path']`` into ``entry['data']``.

        The entry is mutated in place (so cached entries keep their data) and
        is also returned for convenience.
        """
        # Subgraphs are JSON, which is UTF-8 (RFC 8259); don't depend on the
        # platform's locale encoding (e.g. cp1252 on Windows).
        with open(entry['path'], 'r', encoding='utf-8') as f:
            entry['data'] = f.read()
        return entry

    async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
        """Return a shallow copy of ``entry`` that is safe to send to clients.

        The server-side ``path`` is always removed; ``data`` is removed too when
        ``remove_data`` is true. ``None`` is passed through unchanged.
        """
        if entry is None:
            return None
        entry = entry.copy()
        entry.pop('path', None)
        if remove_data:
            entry.pop('data', None)
        return entry

    async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
        """Apply ``sanitize_entry`` to every value, returning a new dict with the same keys."""
        return {
            key: await self.sanitize_entry(value, remove_data)
            for key, value in entries.items()
        }

    @staticmethod
    def _find_subgraph_files(folder: str, subfolder: str = "subgraphs") -> list[str]:
        """Return sorted, ``/``-normalized paths matching ``<folder>/*/<subfolder>/*.json``.

        ``folder`` is escaped so glob metacharacters in the base directory name
        (e.g. ``[`` in ``ComfyUI [portable]``) are matched literally instead of
        being treated as a pattern.
        """
        pattern = os.path.join(glob.escape(folder), "*", subfolder, "*.json")
        # Sort so the listing order is deterministic across runs/filesystems.
        return sorted(path.replace('\\', '/') for path in glob.glob(pattern))

    async def get_custom_node_subgraphs(self, loadedModules, force_reload=False) -> dict[str, SubgraphEntry]:
        """Return every subgraph found in custom node packs, keyed by entry id.

        The scan result is cached; pass ``force_reload=True`` to rescan (which
        also discards any file data loaded into previously cached entries).
        ``loadedModules`` is not used by this implementation; it is accepted so
        callers can keep passing it.
        """
        if not force_reload and self.cached_custom_node_subgraphs is not None:
            return self.cached_custom_node_subgraphs

        source = Source.custom_node
        subgraphs_dict: dict[str, SubgraphEntry] = {}
        for folder in folder_paths.get_folder_paths("custom_nodes"):
            for file in self._find_subgraph_files(folder, "subgraphs"):
                # file is <folder>/<pack>/subgraphs/<name>.json, so [-3] is the pack dir.
                info: CustomNodeSubgraphEntryInfo = {
                    "node_pack": "custom_nodes." + file.split('/')[-3],
                }
                # Id is stable for a given source + file path.
                entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
                subgraphs_dict[entry_id] = {
                    "source": source,
                    "name": os.path.splitext(os.path.basename(file))[0],
                    "path": file,
                    "info": info,
                }
        self.cached_custom_node_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_custom_node_subgraph(self, id: str, loadedModules) -> SubgraphEntry | None:
        """Return the cached entry for ``id`` with its file data loaded, or ``None`` if unknown."""
        subgraphs = await self.get_custom_node_subgraphs(loadedModules)
        entry: SubgraphEntry | None = subgraphs.get(id)
        if entry is not None and entry.get('data') is None:
            await self.load_entry_data(entry)
        return entry

    def add_routes(self, routes, loadedModules):
        """Register the ``/global_subgraphs`` HTTP endpoints on ``routes``."""

        @routes.get("/global_subgraphs")
        async def get_global_subgraphs(request):
            # Listing only: no file contents, no server-side paths.
            subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
            return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))

        @routes.get("/global_subgraphs/{id}")
        async def get_global_subgraph(request):
            # NOTE: an unknown id yields a JSON ``null`` body (status 200).
            subgraph_id = request.match_info.get("id", None)
            subgraph = await self.get_custom_node_subgraph(subgraph_id, loadedModules)
            return web.json_response(await self.sanitize_entry(subgraph))
|
|