"""Helpers for loading LeMat-Bulk, indexing it by composition, and rendering results.

The module loads the LeMat-Bulk subsets from the Hugging Face Hub, builds a
normalized element-composition matrix used for similarity search, and formats
the properties of a single entry as a Dash HTML table.
"""

import os
import pickle
import re
from pathlib import Path

import crystal_toolkit.components as ctc
import numpy as np
import periodictable
from dash import dcc, html
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer

# Hugging Face access token; ``None`` when the environment variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Maximum number of search hits returned by ``search_materials``.
top_k = 500

# Element symbol followed by an optional count, used when indexing formulas.
# Groups are positional (symbol, count); columns are named after extraction.
_INDEX_TOKEN = re.compile(r"([A-Z][a-z]?)(\d*)")

# Element symbol (up to three letters) followed by an optional count, used
# when parsing a user-supplied formula query.
_QUERY_TOKEN = re.compile(r"([A-Z][a-z]{0,2})(\d*)")


def get_dataset():
    """Load the ``train`` split of every LeMat-Bulk subset and concatenate them.

    Returns:
        The dataset produced by ``concatenate_datasets`` over the four subsets,
        restricted to the columns listed below.
    """
    subsets = [
        "compatible_pbe",
        "compatible_pbesol",
        "compatible_scan",
        "non_compatible",
    ]
    # "functional" used to be listed twice; each column is requested once.
    columns = [
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        # "energy_corrected",  # not yet available in LeMat-Bulk
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        # "band_gap_direct",  # future release
        # "band_gap_indirect",  # future release
        "dos_ef",
        # "charges",  # future release
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
        "entalpic_fingerprint",
    ]

    # Load only the train split of each subset.
    train_splits = [
        load_dataset(
            "LeMaterial/leMat-Bulk",
            subset,
            token=HF_TOKEN,
            columns=columns,
        )["train"]
        for subset in subsets
    ]
    return concatenate_datasets(train_splits)


display_columns = [
    "chemical_formula_descriptive",
    "functional",
    "immutable_id",
    "energy",
]

display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}

# Global shared variables
# Maps the position of a row in the displayed results to its dataset index.
mapping_table_idx_dataset_idx = {}


def _load_pickle(path):
    """Unpickle the object stored at *path*, closing the file afterwards."""
    # NOTE(security): unpickling can execute arbitrary code; only point this
    # at files produced by a trusted process.
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _normalize_rows(matrix):
    """Scale each row to sum to one, then to unit Euclidean length.

    Rows that are entirely zero are left as zeros instead of becoming NaN,
    which would otherwise float to the top of a descending ``argsort``.
    """
    matrix = np.asarray(matrix, dtype=float)
    totals = matrix.sum(axis=1, keepdims=True)
    matrix = np.divide(matrix, totals, out=np.zeros_like(matrix), where=totals != 0)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    return np.divide(matrix, norms, out=np.zeros_like(matrix), where=norms != 0)


def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
    """Build the composition matrix used for formula similarity search.

    Args:
        dataset: Dataset exposing ``select`` and ``select_columns`` with the
            ``chemical_formula_descriptive`` and ``immutable_id`` columns.
        index_range: Optional indices passed to ``dataset.select`` to restrict
            the rows that are indexed.
        cache_path: Optional directory holding ``train_df.pkl`` and
            ``dataset_index.pkl``; when given, both are loaded instead of
            being recomputed.
        empty_data: When true, return a placeholder ``(1, 1)`` zero matrix and
            an empty mapping without touching ``dataset``.

    Returns:
        A ``(dataset_index, immutable_id_to_idx)`` tuple: the per-row element
        matrix and a dict mapping each ``immutable_id`` to its row label.
    """
    if empty_data:
        return np.zeros((1, 1)), {}

    use_dataset = dataset if index_range is None else dataset.select(index_range)

    if cache_path is not None:
        cache_dir = Path(cache_path)
        train_df = _load_pickle(cache_dir / "train_df.pkl")
        # NOTE(review): the cached matrix is used as stored and is presumably
        # already normalized -- confirm against the code that writes it.
        dataset_index = _load_pickle(cache_dir / "dataset_index.pkl")
    else:
        # Preprocessing step to create an index for the dataset
        train_df = use_dataset.select_columns(
            ["chemical_formula_descriptive", "immutable_id"]
        ).to_pandas()

        extracted = train_df["chemical_formula_descriptive"].str.extractall(
            _INDEX_TOKEN
        )
        # The pattern's groups are unnamed, so name the columns explicitly.
        extracted.columns = ["element", "count"]
        # A symbol without an explicit count stands for a single atom.
        extracted["count"] = extracted["count"].replace("", "1").astype(int)

        wide_df = extracted.reset_index().pivot_table(
            index="level_0",  # original row index
            columns="element",
            values="count",
            aggfunc="sum",
            fill_value=0,
        )

        all_elements = [el.symbol for el in periodictable.elements]  # full element list
        # Reindex rows as well as columns: formulas that produced no match are
        # absent from the pivot, and dropping them would shift every later row
        # out of alignment with the dataset.
        wide_df = wide_df.reindex(
            index=train_df.index, columns=all_elements, fill_value=0
        )

        dataset_index = _normalize_rows(wide_df.values)

    immutable_id_to_idx = {
        immutable_id: row_label
        for row_label, immutable_id in train_df["immutable_id"].items()
    }

    return dataset_index, immutable_id_to_idx


# TODO: Just load the index from a file
def build_embeddings_index(empty_data=False):
    """Build a FAISS index from the vectors pickled in ``features_dict.pkl``.

    Args:
        empty_data: When true, skip loading and return empty placeholders.

    Returns:
        A ``(index, features_dict, idx_to_immutable_id)`` tuple, or
        ``(None, {}, {})`` when ``empty_data`` is true. ``idx_to_immutable_id``
        maps the insertion position of each vector to its ``features_dict`` key.
    """
    if empty_data:
        return None, {}, {}

    # Path is relative to the current working directory.
    features_dict = _load_pickle("features_dict.pkl")

    from indexer import FAISSIndex

    index = FAISSIndex()
    # Vectors are added in dict order so that positions match the mapping below.
    for vector in features_dict.values():
        index.index.add(vector.reshape(1, -1))

    idx_to_immutable_id = dict(enumerate(features_dict))

    # index = FAISSIndex.from_store("index.faiss")

    return index, features_dict, idx_to_immutable_id


def search_materials(
    query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
):
    """Return the dataset rows whose composition best matches *query*.

    Args:
        query: Either a comma-separated list of element symbols (``"Fe, O"``)
            or a chemical formula (``"Fe2O3"``).
        dataset: Indexable collection of rows; ``dataset[int]`` must work.
        dataset_index: Matrix with one row per dataset entry and one column
            per entry of ``map_periodic_table``.
        mapping_table_idx_dataset_idx: Dict that is cleared and refilled in
            place with ``result position -> dataset index``.
        map_periodic_table: Dict mapping an element symbol to its column.

    Returns:
        A list of at most ``top_k`` rows ordered by decreasing similarity.
        The list is empty when the query contains no element.

    Raises:
        KeyError: If the query names a symbol missing from
            ``map_periodic_table``.
    """
    query_vector = np.zeros(len(map_periodic_table))

    if "," in query:
        # Element list: every named element gets the same weight.
        for symbol in (part.strip() for part in query.split(",")):
            if symbol:  # ignore blanks left by trailing or doubled commas
                query_vector[map_periodic_table[symbol]] = 1
    else:
        # Formula: accumulate counts so repeated symbols (e.g. "CH3COOH")
        # are summed, matching how the dataset index is built.
        for symbol, count in _QUERY_TOKEN.findall(query):
            query_vector[map_periodic_table[symbol]] += int(count) if count else 1

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing to match against; dividing by zero would rank rows by NaN.
        mapping_table_idx_dataset_idx.clear()
        return []

    similarity = np.dot(dataset_index, query_vector) / query_norm
    indices = np.argsort(similarity)[::-1][:top_k]

    options = [dataset[int(i)] for i in indices]

    mapping_table_idx_dataset_idx.clear()
    for position, dataset_idx in enumerate(indices):
        mapping_table_idx_dataset_idx[int(position)] = int(dataset_idx)

    return options


def _round_array_or_none(values, decimals=3):
    """Round array-like *values* with NumPy, passing ``None`` through."""
    return None if values is None else np.round(values, decimals)


def get_properties_table(
    row, structure, sga, properties_container_update, container_type="query"
):
    """Render the properties of one dataset row as a Dash HTML table.

    Args:
        row: Mapping with the dataset columns read below.
        structure: Object exposing a ``density`` attribute.
        sga: Symmetry analyzer exposing ``get_crystal_system`` and
            ``get_symmetry_dataset``.
        properties_container_update: Mutable sequence; slot 0 receives the
            properties when ``container_type`` is ``"query"``, slot 1 otherwise.
        container_type: Selects which slot of the container is overwritten.

    Returns:
        An ``html.Table`` with one row per property.
    """
    total_magnetization = row["total_magnetization"]
    fermi_level = row["dos_ef"]

    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": round(
            row["energy"] / len(row["species_at_sites"]), 3
        ),
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"],  # future release
        "Total Magnetization (μB)": (
            round(total_magnetization, 3) if total_magnetization is not None else None
        ),
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": (
            round(fermi_level, 3) if fermi_level is not None else None
        ),
        "Crystal system": sga.get_crystal_system(),
        "International Spacegroup": sga.get_symmetry_dataset().international,
        # Array-valued fields get the same None guard as the scalar fields
        # above; np.round(None) would raise.
        "Magnetic moments (μB)": _round_array_or_none(row["magnetic_moments"]),
        "Stress tensor (kB)": _round_array_or_none(row["stress_tensor"]),
        "Forces on atoms (eV/A)": _round_array_or_none(row["forces"]),
        # "Bader charges (e-)": np.round(row["charges"], 3),  # future release
        "DFT Functional": row["functional"],
        "Entalpic fingerprint": row["entalpic_fingerprint"],
    }

    if container_type == "query":
        properties_container_update[0] = properties
    else:
        properties_container_update[1] = properties

    # TODO: highlight (backgroundColor "#e6f7ff") str/float values that are
    # equal in properties_container_update[0] and properties_container_update[1].
    value_style = {
        "padding": "10px",
        "borderBottom": "1px solid #ddd",
    }

    # Format properties as an HTML table
    table_rows = [
        html.Tr(
            [
                html.Th(
                    key,
                    style={
                        "padding": "10px",
                        "verticalAlign": "middle",
                    },
                ),
                html.Td(
                    str(value),
                    style=value_style,
                ),
            ],
        )
        for key, value in properties.items()
    ]

    return html.Table(
        [html.Tbody(table_rows)],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )


def get_crystal_plot(structure):
    """Build the crystal-toolkit layout and symmetry analyzer for *structure*.

    Returns:
        A ``(layout, sga)`` tuple: the layout of a
        ``StructureMoleculeComponent`` and the ``SpacegroupAnalyzer`` created
        for the same structure.
    """
    sga = SpacegroupAnalyzer(structure)
    # Create the StructureMoleculeComponent
    structure_component = ctc.StructureMoleculeComponent(structure)
    return structure_component.layout(), sga