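"""Inference utilities for synthesis-condition prediction.

Loads the trained models and preprocessing artifacts once, builds a feature
vector from a raw synthesis description, and predicts the temperature bin
and atmosphere category.
"""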

import logging
import os
import re

import joblib
import numpy as np
import pandas as pd
from pymatgen.core import Composition

from .constants import (KNOWN_ELEMENT_SYMBOLS, ATMOSPHERE_CONFIG, MIXING_METHOD_CONFIG,
                        MAGPIE_FEATURIZER, MAGPIE_LABELS, matminer_available)
from .feature_engineering_utils import standardize_chemical_formula, generate_compositional_features
from .process_feature_utils import generate_process_features_for_input, generate_stoichiometry_features_for_input

# Artifact locations, resolved relative to this file at load time.
MODEL_DIR = "../models"
PREPROCESSOR_DIR = "../models"
ELEMENTAL_DATA_PATH = os.path.join(MODEL_DIR, "df_elements_processed.pkl")

# Populated once by load_all_artifacts_once().
ESSENTIAL_OBJECTS = {}
DF_ELEMENTS_PROCESSED_GLOBAL = None


def load_all_artifacts_once():
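    """Load models, encoders, imputers, scalers, and feature lists once.

    Results are cached in the module-level ESSENTIAL_OBJECTS dict, so repeated
    successful calls return immediately. Returns True only if every artifact
    loaded.
    """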
    global DF_ELEMENTS_PROCESSED_GLOBAL, ESSENTIAL_OBJECTS, matminer_available, MAGPIE_FEATURIZER, MAGPIE_LABELS
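    # NOTE: 'global' rebinds these names in this module's namespace only;
    # code reading them directly from .constants will not see the updates.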
    if ESSENTIAL_OBJECTS.get("loaded_successfully"):
        logging.info("Artifacts already loaded.")
        return True

    logging.info("--- Loading Essential Artifacts for Prediction ---")
    # abspath() guards against an empty dirname when run from the package directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Resolve the path outside the try block so the except message can always report it.
    elemental_data_full_path = os.path.join(script_dir, ELEMENTAL_DATA_PATH)
    try:
        DF_ELEMENTS_PROCESSED_GLOBAL = pd.read_pickle(elemental_data_full_path)
        ESSENTIAL_OBJECTS["elemental_data"] = DF_ELEMENTS_PROCESSED_GLOBAL
        logging.info(f"Loaded processed elemental data from {elemental_data_full_path}")
    except Exception as e:
        logging.critical(f"CRITICAL: Error loading elemental data from {elemental_data_full_path}: {e}")
        return False

    if not matminer_available:
        try:
            from matminer.featurizers.composition import ElementProperty
            MAGPIE_FEATURIZER = ElementProperty.from_preset("magpie", impute_nan=True)
            MAGPIE_LABELS = [f'magpie_{label.replace(" ", "_")}' for label in MAGPIE_FEATURIZER.feature_labels()]
            matminer_available = True
            logging.info("Matminer re-initialized in inference script.")
        except Exception as e:
            logging.warning(f"Matminer could not be re-initialized in inference script: {e}")
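
    # For each prediction target, five artifacts saved at training time are
    # loaded: the LightGBM classifier, its label encoder, and the imputer,
    # scaler, and feature-column list used to preprocess its inputs.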
    ESSENTIAL_OBJECTS["models"] = {}
    ESSENTIAL_OBJECTS["encoders"] = {}
    ESSENTIAL_OBJECTS["imputers"] = {}
    ESSENTIAL_OBJECTS["scalers"] = {}
    ESSENTIAL_OBJECTS["feature_columns"] = {}

    all_loaded_successfully = True
    for model_type_key in ["temperature_bin", "atmosphere_category"]:
        model_artifact_name = f"{model_type_key}_tuned"
        try:
            ESSENTIAL_OBJECTS["models"][model_type_key] = joblib.load(os.path.join(script_dir, MODEL_DIR, f"{model_artifact_name}_lgbm_model.joblib"))
            ESSENTIAL_OBJECTS["encoders"][model_type_key] = joblib.load(os.path.join(script_dir, MODEL_DIR, f"{model_artifact_name}_label_encoder.joblib"))
            ESSENTIAL_OBJECTS["imputers"][model_type_key] = joblib.load(os.path.join(script_dir, PREPROCESSOR_DIR, f"{model_artifact_name}_imputer.joblib"))
            ESSENTIAL_OBJECTS["scalers"][model_type_key] = joblib.load(os.path.join(script_dir, PREPROCESSOR_DIR, f"{model_artifact_name}_scaler.joblib"))
            ESSENTIAL_OBJECTS["feature_columns"][model_type_key] = joblib.load(os.path.join(script_dir, PREPROCESSOR_DIR, f"{model_artifact_name}_feature_columns.joblib"))
            logging.info(f"Loaded artifacts for {model_artifact_name} model.")
        except Exception as e:
            logging.error(f"Error loading one or more artifacts for '{model_artifact_name}': {e}. Predictions for this target will be unavailable.")
            ESSENTIAL_OBJECTS["models"][model_type_key] = None
            all_loaded_successfully = False

    ESSENTIAL_OBJECTS["loaded_successfully"] = all_loaded_successfully
    return all_loaded_successfully


def create_feature_vector_for_prediction(raw_synthesis_input, model_target_name):
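    """Build a one-row feature DataFrame for the given prediction target.

    Combines target-composition, precursor, process, and stoichiometry
    features in the column order saved at training time, then applies the
    fitted imputer and scaler. Returns None if required artifacts are missing
    or a transform fails.
    """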
    global DF_ELEMENTS_PROCESSED_GLOBAL, ESSENTIAL_OBJECTS

    if DF_ELEMENTS_PROCESSED_GLOBAL is None:
        logging.error("Elemental data not loaded. Call load_all_artifacts_once() first.")
        return None

    expected_feature_cols = ESSENTIAL_OBJECTS["feature_columns"].get(model_target_name)
    if not expected_feature_cols:
        logging.error(f"Feature column list for '{model_target_name}' not found in loaded artifacts.")
        return None

    # Initialize every expected column: indicator-style columns default to 0,
    # all others to NaN so the imputer can fill them later.
    feature_dict = {
        col: (0 if col.startswith(("ops_", "proc_has_", "elem_block_"))
                   or "is_stoichiometric" in col or "is_elements_only" in col
              else np.nan)
        for col in expected_feature_cols
    }

    # --- Target composition features ---
    std_target_output = standardize_chemical_formula(raw_synthesis_input.get('target_formula_raw'), "predict_target")
    target_comp_feats = generate_compositional_features(std_target_output, DF_ELEMENTS_PROCESSED_GLOBAL, "predict_target_comp")
    for k, v in target_comp_feats.items():
        feature_key = f'target_{k}'
        if feature_key in feature_dict:
            feature_dict[feature_key] = v

    # --- Precursor composition features ---
    # Standardize each formula, count how many parsed successfully, and flag
    # stoichiometric / elements-only inputs.
    precursor_formulas_raw = raw_synthesis_input.get('precursor_formulas_raw', [])
    std_precursors_outputs = [standardize_chemical_formula(p, f"predict_prec_{i}") for i, p in enumerate(precursor_formulas_raw)]
    num_valid_precursors, num_stoich_precursors, num_elements_only_precursors = 0, 0, 0
    precursor_comp_feats_list = []
    for std_p_output in std_precursors_outputs:
        if std_p_output is not None:
            num_valid_precursors += 1
            if isinstance(std_p_output, str):
                num_stoich_precursors += 1
            elif isinstance(std_p_output, dict) and std_p_output.get('type') == 'elements_only':
                num_elements_only_precursors += 1
            precursor_comp_feats_list.append(generate_compositional_features(std_p_output, DF_ELEMENTS_PROCESSED_GLOBAL, "predict_prec_comp"))

    feature_dict['num_valid_precursors'] = num_valid_precursors
    feature_dict['all_prec_are_stoichiometric'] = (num_stoich_precursors == num_valid_precursors) if num_valid_precursors > 0 else False
    feature_dict['any_prec_is_elements_only'] = num_elements_only_precursors > 0

    # Aggregate the per-precursor features (mean/std/min/max/sum) to match the
    # aggregated columns created at training time.
    if precursor_comp_feats_list:
        df_prec_feats = pd.DataFrame(precursor_comp_feats_list)
        numeric_cols_df_prec = df_prec_feats.select_dtypes(include=np.number)
        if not numeric_cols_df_prec.empty:
            # Featurize a throwaway formula ("H2O") just to discover which
            # compositional feature keys are numeric.
            temp_sample_df = pd.DataFrame([generate_compositional_features("H2O", DF_ELEMENTS_PROCESSED_GLOBAL)])
            numeric_sample_comp_keys = [k for k in temp_sample_df.columns
                                        if pd.api.types.is_numeric_dtype(temp_sample_df[k])
                                        and k not in ['is_stoichiometric_formula']]
            for agg_func_name in ['mean', 'std', 'min', 'max', 'sum']:
                aggregated_vals = getattr(numeric_cols_df_prec, agg_func_name)()
                for feat_name_suffix in numeric_sample_comp_keys:
                    agg_feat_key = f"{agg_func_name}_prec_{feat_name_suffix}"
                    if agg_feat_key in feature_dict and feat_name_suffix in aggregated_vals:
                        feature_dict[agg_feat_key] = aggregated_vals[feat_name_suffix]

    # --- Process features ---
    # The valid atmosphere categories and mixing methods are recovered from
    # the one-hot column names the model expects.
    process_input_ops_list = raw_synthesis_input.get('operations_simplified_list', [])
    all_atm_cats = list({col.split('ops_atm_cat_')[-1] for col in expected_feature_cols if col.startswith('ops_atm_cat_')})
    all_mix_meths = list({col.split('ops_mix_meth_')[-1] for col in expected_feature_cols if col.startswith('ops_mix_meth_')})
    proc_feats_generated = generate_process_features_for_input(process_input_ops_list, all_atm_cats, all_mix_meths)
    for k, v in proc_feats_generated.items():
        if k in feature_dict:
            feature_dict[k] = v

    # --- Stoichiometry features ---
    reactants_simplified = raw_synthesis_input.get('reactants_simplified', [])
    products_simplified = raw_synthesis_input.get('products_simplified', [])
    stoich_feats_generated = generate_stoichiometry_features_for_input(reactants_simplified, products_simplified, standardize_chemical_formula)
    for k, v in stoich_feats_generated.items():
        if k in feature_dict:
            feature_dict[k] = v

    # Single-row DataFrame in the exact column order the model was trained on.
    feature_vector_df = pd.DataFrame([feature_dict], columns=expected_feature_cols)

    # --- Imputation and scaling ---
    imputer = ESSENTIAL_OBJECTS["imputers"].get(model_target_name)
    scaler = ESSENTIAL_OBJECTS["scalers"].get(model_target_name)

    # Only continuous columns are transformed; binary indicators and counts
    # are excluded, mirroring the training-time preprocessing.
    numerical_features_for_transform = [
        col for col in expected_feature_cols
        if col in feature_vector_df.columns
        and pd.api.types.is_numeric_dtype(feature_vector_df[col].dtype)
        and not col.startswith(('ops_', 'proc_has_', 'elem_block_'))
        and col not in ['is_stoichiometric_formula', 'all_prec_are_stoichiometric',
                        'any_prec_is_elements_only', 'num_valid_precursors']
    ]

    if imputer is not None and scaler is not None and numerical_features_for_transform:
        try:
            feature_vector_df[numerical_features_for_transform] = feature_vector_df[numerical_features_for_transform].astype(np.float64)
            feature_vector_df[numerical_features_for_transform] = imputer.transform(feature_vector_df[numerical_features_for_transform])
            feature_vector_df[numerical_features_for_transform] = scaler.transform(feature_vector_df[numerical_features_for_transform])
            logging.info("Feature vector imputed and scaled for prediction.")
        except Exception as e_transform:
            logging.error(f"Error during imputation/scaling for prediction: {e_transform}", exc_info=True)
            return None
    else:
        logging.warning("Imputer, scaler, or numerical features missing for prediction. Returning untransformed features.")

    return feature_vector_df


def predict_synthesis_outcome(raw_synthesis_input):
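    """Predict the temperature bin and atmosphere category for one input.

    Loads artifacts lazily on first call. Returns a dict keyed by target
    name; each value is a {'predicted_label', 'probabilities'} dict on
    success or an error string on failure.
    """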
    global ESSENTIAL_OBJECTS
    if not ESSENTIAL_OBJECTS.get("loaded_successfully"):
        success = load_all_artifacts_once()
        if not success:
            logging.error("Essential artifacts could not be loaded. Cannot make predictions.")
            return {}

    predictions = {}
    model_types_to_predict = ["temperature_bin", "atmosphere_category"]

    for model_type in model_types_to_predict:
        if ESSENTIAL_OBJECTS["models"].get(model_type):
            logging.info(f"\n--- Predicting {model_type} ---")
            feature_vector = create_feature_vector_for_prediction(raw_synthesis_input, model_type)

            if feature_vector is not None:
                model = ESSENTIAL_OBJECTS["models"][model_type]
                encoder = ESSENTIAL_OBJECTS["encoders"][model_type]
                try:
                    pred_encoded = model.predict(feature_vector)
                    pred_proba = model.predict_proba(feature_vector)
                    pred_label = encoder.inverse_transform(pred_encoded)[0]

                    predictions[model_type] = {
                        'predicted_label': pred_label,
                        'probabilities': {str(cls): prob for cls, prob in zip(encoder.classes_, pred_proba[0])}
                    }
                    logging.info(f"Predicted {model_type}: {pred_label}")
                    logging.info(f"Probabilities: {predictions[model_type]['probabilities']}")
                except Exception as e:
                    logging.error(f"Error during {model_type} prediction: {e}", exc_info=True)
                    predictions[model_type] = f"Prediction Error: {e}"
            else:
                logging.error(f"Could not create feature vector for {model_type} model.")
                predictions[model_type] = "Feature vector creation error"
        else:
            logging.warning(f"{model_type} model not available for prediction.")

    return predictions


if __name__ == '__main__':
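    # A minimal logging setup so the INFO messages emitted above are visible
    # when the module is run directly; the INFO level is an assumption.
    logging.basicConfig(level=logging.INFO)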

    if not load_all_artifacts_once():
        print("Exiting due to failure in loading essential artifacts.")
    else:
        print("\n--- Example Interactive Prediction ---")
        example_input_with_ops_list = {
            'target_formula_raw': "YBa2Cu3O7",
            'precursor_formulas_raw': ["Y2O3", "BaCO3", "CuO"],
            'operations_simplified_list': [
                {'type': 'MixingOperation', 'string': 'Mix precursors by ball milling for 4h',
                 'conditions': {'duration': [{'value': 4, 'unit': 'h'}]}},
                {'type': 'HeatingOperation', 'string': 'Calcined at 900C for 12h in air',
                 'conditions': {'heating_temperature': [{'value': 900, 'unit': 'C'}],
                                'heating_time': [{'value': 12, 'unit': 'h'}],
                                'atmosphere': 'Air'}},
                {'type': 'HeatingOperation', 'string': 'Sintered at 950C for 24h in O2',
                 'conditions': {'heating_temperature': [{'value': 950, 'unit': 'C'}],
                                'heating_time': [{'value': 24, 'unit': 'h'}],
                                'atmosphere': 'Oxygen'}}
            ],
            'reactants_simplified': [{'material': 'Y2O3', 'amount': 0.5},
                                     {'material': 'BaCO3', 'amount': 2.0},
                                     {'material': 'CuO', 'amount': 3.0}],
            'products_simplified': [{'material': 'YBa2Cu3O7', 'amount': 1.0}]
        }
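
        # The result maps each target name to {'predicted_label', 'probabilities'}
        # on success, or to an error string on failure.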
        predictions = predict_synthesis_outcome(example_input_with_ops_list)
        print(f"\nFinal Predictions for example input: {predictions}")