crystals / app.py
jannisborn's picture
update
ee3d7ae unverified
raw
history blame
4.3 kB
import logging
import os
import pathlib
import shutil
import tempfile
from pathlib import Path
from collections import defaultdict
import gradio as gr
import pandas as pd
from gt4sd.properties.crystals import CRYSTALS_PROPERTY_PREDICTOR_FACTORY
# Module-level logger. Attaching a NullHandler means this module never falls back
# to logging's last-resort stderr output when the app configures no handlers.
logger = logging.getLogger(name=__name__)
logger.addHandler(logging.NullHandler())

# Upload suffixes accepted per property key. Property keys missing from this
# mapping are given a fallback of [".cif", ".zip"] by main().
suffix_dict = dict(metal_nonmetal_classifier=[".csv"])
def create_temp_file(path: str) -> str:
    """Copy ``path`` into an emptied scratch folder and return the copy's path.

    The scratch folder is ``<system temp dir>/gt4sd_crystal``. It has a fixed
    name, so it is shared between calls. Every entry already inside it is
    removed before the copy, so afterwards it contains only the new file.

    Args:
        path: Path (``str`` or ``os.PathLike``) of the file to copy.

    Returns:
        Path of the copy inside the scratch folder. The copy keeps the basename
        of ``path``.
    """
    log = logging.getLogger(__name__)
    temp_folder = os.path.join(tempfile.gettempdir(), "gt4sd_crystal")
    os.makedirs(temp_folder, exist_ok=True)

    # Clean up leftovers from a previous call. These can include directories,
    # e.g. when a previously uploaded .zip contained folders. os.remove cannot
    # delete directories, so use rmtree for those (but never follow symlinks).
    for entry in os.listdir(temp_folder):
        entry_path = os.path.join(temp_folder, entry)
        log.debug("Removing %s", entry_path)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

    # basename is portable (handles OS-specific separators) and accepts PathLike.
    temp_path = os.path.join(temp_folder, os.path.basename(path))
    shutil.copy2(path, temp_path)
    return temp_path
def main(property: str, data_file):
    """Run the selected crystal property predictor on an uploaded file.

    Args:
        property: Property name as shown in the dropdown (e.g.
            "Formation Energy"). It is mapped to a factory key by replacing
            spaces with underscores and lower-casing.
        data_file: The uploaded file. This is either a plain path
            (``str``/``os.PathLike``) or an object whose ``.name`` attribute
            holds the path on disk, such as the tempfile wrapper ``gr.File``
            passes.

    Returns:
        A ``pd.DataFrame`` built from the predictor's result.

    Raises:
        TypeError: If no file is given, or its suffix is not ``.cif``,
            ``.csv`` or ``.zip``.
        ValueError: If the suffix is not accepted for the chosen property, or
            a ``.zip`` archive yields no top-level ``.cif`` file.
    """
    # Check for a missing upload before touching any attribute of data_file.
    if data_file is None:
        raise TypeError("You have to pass an input file for the crystal model.")
    # Path-likes must be checked first: pathlib.Path also has a `.name`
    # attribute, but it holds only the basename, not the full path.
    if isinstance(data_file, (str, os.PathLike)):
        source_path = os.fspath(data_file)
    else:
        source_path = data_file.name
    logger.debug("Received input file %s", source_path)

    prop_name = property.replace(" ", "_").lower()
    # Properties without an entry in suffix_dict take a single .cif file or a
    # .zip of .cif files.
    allowed_suffixes = suffix_dict.get(prop_name, [".cif", ".zip"])
    suffix = Path(source_path).suffix

    # Validate before copying, so a rejected upload does not wipe the scratch
    # folder.
    if suffix not in (".cif", ".csv", ".zip"):
        raise TypeError(
            "You have to pass a `.csv` (for `metal_nonmetal_classifier`),"
            " a `.cif` (for all other properties) or a `.zip` with multiple"
            f" `.cif` files. Not `{Path(source_path).name}`."
        )
    if suffix not in allowed_suffixes:
        raise ValueError(
            f"For this property, provide {allowed_suffixes}, not `{suffix}`."
        )

    # Copy into create_temp_file's scratch folder. The folder is emptied first,
    # so it holds only this input. It is a fixed, shared folder, not unique per
    # call.
    file_path = Path(create_temp_file(source_path))
    folder = file_path.parent

    if suffix == ".csv":
        # The CSV file itself is the model input.
        input_path = file_path
    else:
        if suffix == ".zip":
            shutil.unpack_archive(file_path, folder)
            if not any(name.endswith(".cif") for name in os.listdir(folder)):
                raise ValueError("No `.cif` files were found inside the `.zip`.")
        # .cif inputs (a single file or an extracted archive) are passed to the
        # model as the containing folder.
        input_path = folder

    algo, config = CRYSTALS_PROPERTY_PREDICTOR_FACTORY[prop_name]
    # Pass hyperparameters if applicable; currently only the algorithm version
    # is set.
    kwargs = {"algorithm_version": "v0"}
    model = algo(config(**kwargs))
    result = model(input=input_path)
    return pd.DataFrame(result)
if __name__ == "__main__":
    # Dropdown choices: the factory keys in reverse order, title-cased for
    # display (e.g. "metal_nonmetal_classifier" -> "Metal Nonmetal Classifier").
    properties = [
        key.replace("_", " ").title()
        for key in list(CRYSTALS_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
    ]

    # Example inputs and the markdown for the page live in model_cards/ next to
    # this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    # Example labels must match a dropdown choice exactly, i.e. the title-cased
    # factory key ("Bulk Moduli", not "Bulk moduli").
    examples = [
        ["Formation Energy", metadata_root.joinpath("7206075.cif")],
        ["Bulk Moduli", metadata_root.joinpath("crystals.zip")],
        ["Metal Nonmetal Classifier", metadata_root.joinpath("metal.csv")],
        ["Bulk Moduli", metadata_root.joinpath("9000046.cif")],
    ]
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=main,
        title="Crystal properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Instability"),
            gr.File(
                file_types=[".cif", ".csv", ".zip"],
                label="Input file for crystal model",
            ),
        ],
        outputs=gr.DataFrame(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)