# Last commit by Ramlaoui: 2dd66b7 "Faster search and Table" (file size 8.02 kB).
# (Header copied from a web file view: raw / history / blame.)
import os
import crystal_toolkit.components as ctc
import dash
import dash_mp_components as dmp
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Maximum number of results returned by one search.
top_k = 100
# Load only the train split of the dataset, restricted to the columns this app
# reads. Each column is requested exactly once.
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)
# Columns shown in the results table, in display order.
display_columns = [
    "chemical_formula_descriptive",
    "functional",
    "immutable_id",
    "energy",
]
# Column header text for each entry of display_columns.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
# Maps a row position in the results table to the dataset row it shows;
# cleared and rebuilt by every search.
mapping_table_idx_dataset_idx = {}
import numpy as np
import periodictable
# Map each element symbol to its position in periodictable.elements. These
# positions are used as column indices of the composition vectors stored in
# dataset_index and of the query vectors built by search_materials.
# NOTE(review): periodictable.elements presumably starts with the neutron
# ("n", Z=0), which would make each index equal the atomic number Z — confirm,
# since Og (Z=118) would then fall outside a 118-wide vector.
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
# The commented-out code below appears to be the one-off recipe that produced
# dataset_index.npy: one row per dataset entry, one column per element, holding
# that element's count parsed from chemical_formula_descriptive.
# import re
#
# dataset_index = np.zeros((len(dataset), 118))
# import tqdm
#
# for i, row in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
#     for el in row["chemical_formula_descriptive"].split(" "):
#         matches = re.findall(r"([a-zA-Z]+)([0-9]*)", el)
#         el = matches[0][0]
#         numb = int(matches[0][1]) if matches[0][1] else 1
#         dataset_index[i][map_periodic_table[el]] = numb
# Precomputed composition matrix; search_materials treats its row positions as
# dataset row indices, so it must stay aligned with `dataset`.
dataset_index = np.load("dataset_index.npy")
# Initialize the Dash app, using crystal_toolkit's assets folder
app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
server = app.server  # Expose the server for deployment
# Define the app layout: search input, results table, display button, and a
# side-by-side area for the structure viewer and its property table.
layout = html.Div(
    [
        html.H1("Interactive Crystal Viewer"),
        # Search section
        html.Div(
            [
                html.Div(
                    [
                        html.H3("Search for materials by elements (eg. 'Ac,Cd,Ge')"),
                        dmp.MaterialsInput(
                            allowedInputTypes=["elements", "formula"],
                            hidePeriodicTable=False,
                            periodicTableMode="toggle",
                            showSubmitButton=True,
                            submitButtonText="Search",
                            type="elements",
                            id="materials-input",
                        ),
                    ],
                    style={
                        "width": "100%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"margin-bottom": "20px"},
        ),
        # Results section: one row per search hit; the selected cell picks the
        # material that display_material renders.
        html.Div(
            [
                html.Label("Select Material"),
                dash.dash_table.DataTable(
                    id="table",
                    columns=[
                        {"name": display_names[col], "id": col}
                        for col in display_columns
                    ],
                    # Start with no rows: a blank placeholder row ([{}]) could be
                    # selected before any search has filled
                    # mapping_table_idx_dataset_idx.
                    data=[],
                    style_table={
                        "overflowX": "auto",
                        "height": "400px",
                        "overflowY": "auto",
                    },
                    style_cell={"textAlign": "left"},
                ),
            ],
            style={"margin-bottom": "20px"},
        ),
        html.Button("Display Material", id="display-button", n_clicks=0),
        # Output section: structure viewer (left) and properties table (right)
        html.Div(
            [
                html.Div(
                    id="structure-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    id="properties-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "paddingLeft": "4%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"margin-top": "20px"},
        ),
    ],
    style={
        "margin-left": "10px",
        "margin-right": "10px",
    },
)
def search_materials(query):
    """Return the dataset rows whose composition best matches *query*.

    The query is converted to a composition vector indexed like the columns of
    ``dataset_index``:

    * a comma-separated list of element symbols (e.g. ``"Ac,Cd,Ge"``) sets
      each listed element to 1;
    * anything else is parsed as a formula (e.g. ``"Fe2O3"``): each element
      gets its count (1 when no count is written), and repeated symbols add up.

    Rows are ranked by cosine similarity between their composition vector and
    the query vector, and the ``top_k`` best rows are returned, best first.

    Side effect: ``mapping_table_idx_dataset_idx`` is cleared and refilled so
    that result position ``i`` maps to its dataset row index.
    NOTE(review): that mapping is module-level state, so concurrent users of
    the app could overwrite each other's results.

    Unknown element symbols are ignored. If no symbol in the query is
    recognised, the mapping is cleared and an empty list is returned.
    """
    import re

    width = dataset_index.shape[1]
    query_vector = np.zeros(width)

    def _column(symbol):
        # Column for `symbol`, or None if the symbol is unknown or falls
        # outside the index width.
        col = map_periodic_table.get(symbol)
        return col if col is not None and col < width else None

    if "," in query:
        # Element list: mark each listed element as present.
        for symbol in (el.strip() for el in query.split(",")):
            col = _column(symbol)
            if col is not None:
                query_vector[col] = 1
    else:
        # Formula: weight each element by its count; repeated symbols
        # (e.g. "CH3COOH") accumulate instead of overwriting each other.
        for symbol, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            col = _column(symbol)
            if col is not None:
                query_vector[col] += int(count) if count else 1

    mapping_table_idx_dataset_idx.clear()
    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing recognisable in the query, so there is no meaningful ranking
        # (dividing by a zero norm would only produce NaNs).
        return []

    # Cosine similarity needs each row's own norm. A norm over the whole
    # matrix rescales every score by the same constant, which would rank rows
    # by raw dot product instead. einsum avoids a full squared-matrix temporary.
    dots = dataset_index @ query_vector
    row_norms = np.sqrt(np.einsum("ij,ij->i", dataset_index, dataset_index))
    similarity = np.divide(
        dots,
        row_norms * query_norm,
        out=np.zeros(dots.shape),
        where=row_norms > 0,  # empty rows score 0 instead of NaN
    )

    indices = np.argsort(similarity)[::-1][:top_k]
    mapping_table_idx_dataset_idx.update(
        {pos: int(idx) for pos, idx in enumerate(indices)}
    )
    return [dataset[int(idx)] for idx in indices]
# Callback to update the table when a search is submitted
@app.callback(
    Output("table", "data"),
    Input("materials-input", "submitButtonClicks"),
    # State, not Input: the query is read only when Search is pressed, instead
    # of re-running the full-dataset search on every change to the input.
    State("materials-input", "value"),
)
def on_submit_materials_input(n_clicks, query):
    """Fill the results table with the materials matching *query*.

    Args:
        n_clicks: Search-button click count (may be None before any click).
        query: current contents of the materials input.

    Returns:
        One dict per search result, restricted to ``display_columns``; an
        empty list when nothing has been submitted or the query is empty.
    """
    if n_clicks is None or not query:
        return []
    entries = search_materials(query)
    return [{col: entry[col] for col in display_columns} for entry in entries]
# Callback to display the selected material
@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("display-button", "n_clicks"),
    Input("table", "active_cell"),
)
def display_material(n_clicks, active_cell):
    """Render the structure viewer and a property table for the selected row.

    Args:
        n_clicks: "Display Material" button click count.
        active_cell: the results table's active cell (a dict with a "row" key),
            or a falsy value when nothing is selected.

    Returns:
        A pair (structure component layout, properties table). Both are ""
        when no cell is selected, or when the selected row has no entry in
        ``mapping_table_idx_dataset_idx`` (e.g. before any search has run).
    """
    if n_clicks is None or not active_cell:
        return "", ""
    dataset_idx = mapping_table_idx_dataset_idx.get(active_cell["row"])
    if dataset_idx is None:
        return "", ""
    row = dataset[dataset_idx]

    # lattice_vectors is flattened into a single list of numbers before being
    # passed as the lattice; site positions are Cartesian.
    structure = Structure(
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    structure_component = ctc.StructureMoleculeComponent(structure)

    # Energy per atom; left as None when the energy is missing.
    energy = row["energy"]
    n_sites = len(row["species_at_sites"])
    energy_per_atom = energy / n_sites if energy is not None and n_sites else None
    # Prefer the direct gap; fall back only when it is missing. A plain `or`
    # would wrongly discard a legitimate 0.0 direct gap.
    band_gap = row["band_gap_direct"]
    if band_gap is None:
        band_gap = row["band_gap_indirect"]

    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": energy_per_atom,
        "Band Gap (eV)": band_gap,
        "Total Magnetization (μB/f.u.)": row["total_magnetization"],
    }

    # One header/value row per property.
    properties_html = html.Table(
        [
            html.Tbody(
                [
                    html.Tr([html.Th(key), html.Td(str(value))])
                    for key, value in properties.items()
                ]
            )
        ],
        style={
            "border": "1px solid black",
            "width": "100%",
            "borderCollapse": "collapse",
        },
    )
    return structure_component.layout(), properties_html
# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)
if __name__ == "__main__":
    # Dash.run replaces Dash.run_server, which is deprecated and removed in
    # Dash 3; fall back to run_server on older Dash releases without run.
    run = getattr(app, "run", None) or app.run_server
    run(debug=True, port=7860, host="0.0.0.0")