"""Dash app for browsing crystal structures from the LeMaterial dataset.

At import time the ``train`` split of ``LeMaterial/leDataset`` is loaded into a
pandas DataFrame. The user can then search it by chemical system
(e.g. ``Ac-Cd-Ge``), pick an entry, and view its structure (via Crystal
Toolkit) next to a table of key properties.
"""

import os

import crystal_toolkit.components as ctc
import dash
import pandas as pd
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester  # noqa: F401  (currently unused)

HF_TOKEN = os.environ.get("HF_TOKEN")

# Columns fetched from the dataset. Each name must appear only once; a
# duplicate yields duplicate DataFrame columns, and attribute access such as
# ``row.functional`` then returns a Series instead of a scalar.
DATASET_COLUMNS = [
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    "energy_corrected",
    "immutable_id",
    "elements",
    "functional",
    "stress_tensor",
    "magnetic_moments",
    "forces",
    "band_gap_direct",
    "band_gap_indirect",
    "dos_ef",
    "charges",
    "chemical_formula_reduced",
    "chemical_formula_descriptive",
    "total_magnetization",
]

# Load only the train split of the dataset.
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=DATASET_COLUMNS,
)

# Convert the train split to a pandas DataFrame, then drop the dataset object
# so only one in-memory copy is kept.
train_df = dataset.to_pandas()
del dataset

# Initialize the Dash app.
app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
server = app.server  # Expose the underlying server for deployment.

# Define the app layout.
layout = html.Div([
    dcc.Markdown("## Interactive Crystal Viewer"),
    html.Div([
        html.Div([
            html.Label("Search by Chemical System (e.g., 'Ac-Cd-Ge')"),
            dcc.Input(
                id="query-input",
                type="text",
                value="Ac-Cd-Ge",
                placeholder="Ac-Cd-Ge",
                style={"width": "100%"},
            ),
        ], style={"width": "70%", "display": "inline-block", "verticalAlign": "top"}),
        html.Div([
            html.Button("Search", id="search-button", n_clicks=0),
        ], style={
            "width": "28%",
            "display": "inline-block",
            "paddingLeft": "2%",
            "verticalAlign": "top",
        }),
    ], style={"marginBottom": "20px"}),
    html.Div([
        html.Label("Select Material"),
        dcc.Dropdown(
            id="material-dropdown",
            options=[],  # Populated by the search callback.
            value=None,
        ),
    ], style={"marginBottom": "20px"}),
    html.Button("Display Material", id="display-button", n_clicks=0),
    html.Div([
        html.Div(
            id="structure-container",
            style={"width": "48%", "display": "inline-block", "verticalAlign": "top"},
        ),
        html.Div(
            id="properties-container",
            style={
                "width": "48%",
                "display": "inline-block",
                "paddingLeft": "4%",
                "verticalAlign": "top",
            },
        ),
    ], style={"marginTop": "20px"}),
])


def search_materials(query):
    """Return dropdown options for entries contained in a chemical system.

    An entry matches when it has at least one element and every one of its
    elements appears in the queried system.

    Args:
        query: Hyphen-separated element symbols, e.g. ``"Ac-Cd-Ge"``.
            Whitespace around symbols and empty tokens are ignored.

    Returns:
        A list of ``{"label": str, "value": index_label}`` dicts, one per
        matching row of ``train_df``. ``value`` is the row's index label.
    """
    allowed = {el.strip() for el in query.split("-") if el.strip()}
    if not allowed:
        return []

    def in_system(elements):
        # Treat a null element list as "no elements", which never matches.
        present = set(elements) if elements is not None else set()
        return bool(present) and present <= allowed

    mask = train_df["elements"].map(in_system).astype(bool)
    matches = train_df[mask]
    return [
        {
            "label": (
                f"{res.chemical_formula_reduced} ({res.immutable_id}) "
                f"Calculated with {res.functional}"
            ),
            "value": res.Index,
        }
        for res in matches.itertuples()
    ]


@app.callback(
    [Output("material-dropdown", "options"), Output("material-dropdown", "value")],
    Input("search-button", "n_clicks"),
    State("query-input", "value"),
)
def update_material_dropdown(n_clicks, query):
    """Fill the material dropdown from a search and preselect the first hit.

    Returns:
        ``(options, value)``: the options list, plus the first option's value,
        or ``([], None)`` when there is no query or no match.
    """
    if n_clicks is None or not query:
        return [], None
    options = search_materials(query)
    if not options:
        return [], None
    return options, options[0]["value"]


def _band_gap(row):
    """Return the direct band gap, or the indirect one when direct is missing.

    ``row.band_gap_direct or row.band_gap_indirect`` would not work here:
    NaN is truthy, so a missing direct gap would hide a valid indirect one.
    """
    direct = row.band_gap_direct
    return direct if pd.notna(direct) else row.band_gap_indirect


@app.callback(
    [Output("structure-container", "children"), Output("properties-container", "children")],
    Input("display-button", "n_clicks"),
    State("material-dropdown", "value"),
)
def display_material(n_clicks, material_id):
    """Render the selected entry's structure and a table of its properties.

    Args:
        n_clicks: Click count of the "Display Material" button.
        material_id: Index label of the selected ``train_df`` row (the
            dropdown ``value`` produced by ``search_materials``).

    Returns:
        ``(structure_layout, properties_table)``, or ``("", "")`` when nothing
        is selected.
    """
    # Compare with None explicitly: index label 0 is a valid selection.
    if n_clicks is None or material_id is None:
        return "", ""

    # The dropdown value is an index label, so use label-based lookup.
    row = train_df.loc[material_id]
    structure = Structure(
        [x for y in row["lattice_vectors"] for x in y],  # flattened 3x3 lattice
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )

    structure_component = ctc.StructureMoleculeComponent(structure)

    properties = {
        "Material ID": row.immutable_id,
        "Formula": row.chemical_formula_descriptive,
        "Energy per atom (eV/atom)": row.energy / len(row.species_at_sites),
        "Band Gap (eV)": _band_gap(row),
        "Total Magnetization (μB/f.u.)": row.total_magnetization,
    }

    properties_html = html.Table([
        html.Tbody([
            html.Tr([html.Th(key), html.Td(str(value))])
            for key, value in properties.items()
        ])
    ], style={"border": "1px solid black", "width": "100%", "borderCollapse": "collapse"})

    return structure_component.layout(), properties_html


# Register Crystal Toolkit with the app so its components and callbacks work.
ctc.register_crystal_toolkit(app, layout)

if __name__ == "__main__":
    # ``run_server`` is deprecated (removed in Dash 3); ``run`` is its replacement.
    # NOTE(review): debug=True on 0.0.0.0 exposes dev tooling to the network —
    # confirm this is intended for the deployment target.
    app.run(debug=True, port=7860, host="0.0.0.0")