| import hdbscan |
| import numpy as np |
|
|
| |
def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
    """Dispatch an HDBSCAN-style operation to the backend matching `model`.

    Two backends are recognised: the CPU `hdbscan` package (detected with
    `isinstance(model, hdbscan.HDBSCAN)`) and cuML's HDBSCAN (detected by
    looking for "cuml" and "hdbscan" in the lower-cased type string, so that
    cuML is only imported when such a model is actually passed in).

    Arguments:
        model: The cluster model.
        func: The function to use. Options:
                - "approximate_predict"
                - "all_points_membership_vectors"
                - "membership_vector"
        embeddings: Input embeddings for "approximate_predict"
                    and "membership_vector". Not used by
                    "all_points_membership_vectors".

    Returns:
        - "approximate_predict": a `(predictions, probabilities)` tuple. For
          a model that is neither backend, `model.predict(embeddings)` is
          used and `probabilities` is `None`.
        - "all_points_membership_vectors" / "membership_vector": whatever the
          backend function of the same name returns, or `None` when the
          model is neither backend.
        - any other `func`: `None`.
    """
    is_cpu_hdbscan = isinstance(model, hdbscan.HDBSCAN)
    # Only inspect the type string when the cheap isinstance check fails.
    is_cuml_hdbscan = not is_cpu_hdbscan and _is_cuml_hdbscan(model)

    if func == "approximate_predict":
        if is_cpu_hdbscan:
            predictions, probabilities = hdbscan.approximate_predict(model, embeddings)
            return predictions, probabilities

        if is_cuml_hdbscan:
            # Imported lazily: cuML is an optional dependency.
            from cuml.cluster import hdbscan as cuml_hdbscan
            predictions, probabilities = cuml_hdbscan.approximate_predict(model, embeddings)
            return predictions, probabilities

        # Generic cluster model: hard assignments only, no probabilities.
        predictions = model.predict(embeddings)
        return predictions, None

    if func == "all_points_membership_vectors":
        if is_cpu_hdbscan:
            return hdbscan.all_points_membership_vectors(model)

        if is_cuml_hdbscan:
            from cuml.cluster import hdbscan as cuml_hdbscan
            return cuml_hdbscan.all_points_membership_vectors(model)

        return None

    if func == "membership_vector":
        if is_cpu_hdbscan:
            return hdbscan.membership_vector(model, embeddings)

        if is_cuml_hdbscan:
            # Use cuML's membership_vector here. The previous code called
            # approximate_predict, which yields a (labels, probabilities)
            # pair rather than the membership vectors returned by the CPU
            # branch above.
            from cuml.cluster import hdbscan as cuml_hdbscan
            return cuml_hdbscan.membership_vector(model, embeddings)

        return None

    # Unrecognised `func`: keep the historical behaviour of returning None.
    return None


def _is_cuml_hdbscan(model) -> bool:
    """Return True when the type string of `model` names a cuML HDBSCAN class."""
    type_name = str(type(model)).lower()
    return "cuml" in type_name and "hdbscan" in type_name
|
|
|
|
def is_supported_hdbscan(model):
    """Tell whether `model` is an HDBSCAN-like model this module can handle.

    A model qualifies when it is an instance of `hdbscan.HDBSCAN`, or when
    the lower-cased string form of its type contains both "cuml" and
    "hdbscan".
    """
    if isinstance(model, hdbscan.HDBSCAN):
        return True

    # String-based check, so cuML does not have to be imported here.
    type_description = str(type(model)).lower()
    return "cuml" in type_description and "hdbscan" in type_description
|
|