Spaces:
Sleeping
Sleeping
import logging | |
import os | |
import pathlib | |
import shutil | |
import tempfile | |
from pathlib import Path | |
from collections import defaultdict | |
import gradio as gr | |
import pandas as pd | |
from gt4sd.properties.crystals import CRYSTALS_PROPERTY_PREDICTOR_FACTORY | |
# Library-style logging: emit nothing unless the host application configures it.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Input-file suffixes accepted per property key. Keys missing from this
# mapping are treated by `main` as accepting [".cif", ".zip"].
suffix_dict = dict(metal_nonmetal_classifier=[".csv"])
def create_temp_file(path: str) -> str:
    """Copy ``path`` into a fixed scratch folder and return the copy's path.

    The scratch folder is ``<system temp dir>/gt4sd_crystal``. It is created
    if missing and emptied (files and sub-directories) before the copy, so
    afterwards it holds only the copied file. The folder name is fixed, i.e.
    it is shared by every call rather than unique per call.

    Args:
        path: path of the file to copy.

    Returns:
        Path of the copy inside the scratch folder (same basename as ``path``).
    """
    temp_folder = os.path.join(tempfile.gettempdir(), "gt4sd_crystal")
    os.makedirs(temp_folder, exist_ok=True)

    # Clean up leftovers from the previous call. An unpacked `.zip` may have
    # created sub-directories, which `os.remove` cannot delete.
    for entry in os.listdir(temp_folder):
        entry_path = os.path.join(temp_folder, entry)
        logging.getLogger(__name__).debug("Removing %s", entry_path)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

    # `basename` honours the platform's path separator (unlike split("/")).
    temp_path = os.path.join(temp_folder, os.path.basename(path))
    shutil.copy2(path, temp_path)
    return temp_path
def main(property: str, data_file: str):
    """Run the selected crystal property predictor on an uploaded file.

    Args:
        property: human-readable property name as shown in the dropdown
            (e.g. "Metal Nonmetal Classifier"). It is turned into the factory
            key by replacing spaces with underscores and lower-casing.
        data_file: the uploaded file, given either as a path string or as an
            object whose ``.name`` attribute holds the path.

    Returns:
        ``pd.DataFrame`` built from the predictor's output.

    Raises:
        TypeError: if no file is given, or its suffix is not `.cif`, `.csv`
            or `.zip`.
        ValueError: if the suffix is not accepted for the chosen property, or
            a `.zip` contains no `.cif` files at its top level.
    """
    # Must run before touching any attribute of `data_file`.
    if data_file is None:
        raise TypeError("You have to pass an input file for the crystal model.")

    prop_name = property.replace(" ", "_").lower()
    # Suffixes accepted for this property; properties not listed in
    # `suffix_dict` take single `.cif` files or `.zip` archives of them.
    allowed = suffix_dict.get(prop_name, [".cif", ".zip"])

    # Accept both a plain path and a file wrapper exposing the path as `.name`.
    source_path = getattr(data_file, "name", data_file)

    # Copy into the scratch folder managed by `create_temp_file` (a fixed,
    # shared folder that it empties first), so the folder holds only this input.
    file_path = Path(create_temp_file(source_path))
    folder = file_path.parent
    suffix = file_path.suffix

    if suffix not in (".cif", ".csv", ".zip"):
        raise TypeError(
            "You have to pass a `.csv` (for `metal_nonmetal_classifier`),"
            " a `.cif` (for all other properties) or a `.zip` with multiple"
            f" `.cif` files. Not `{file_path.name}`."
        )
    if suffix not in allowed:
        raise ValueError(f"For this property, provide {allowed}, not `{suffix}`.")

    if suffix == ".csv":
        # CSV input is passed to the predictor as the file itself.
        input_path = file_path
    else:
        if suffix == ".zip":
            shutil.unpack_archive(file_path, folder)
            # Only the folder's top level is inspected for `.cif` files.
            if not any(name.endswith(".cif") for name in os.listdir(folder)):
                raise ValueError("No `.cif` files were found inside the `.zip`.")
        # `.cif` inputs (single or unpacked) are passed as the containing folder.
        input_path = folder

    algo, config = CRYSTALS_PROPERTY_PREDICTOR_FACTORY[prop_name]
    model = algo(config(algorithm_version="v0"))
    result = model(input=input_path)
    return pd.DataFrame(result)
if __name__ == "__main__":
    # Dropdown choices: factory keys (in reverse order) rendered as
    # "Title Case" labels; `main` maps a label back to its snake_case key.
    properties = list(CRYSTALS_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
    properties = [name.replace("_", " ").title() for name in properties]

    # Example inputs and Markdown texts live next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    # Example labels must match a dropdown choice exactly, so they use the
    # same `.title()` casing as `properties` ("Bulk Moduli", not "Bulk moduli").
    examples = [
        ["Formation Energy", metadata_root.joinpath("7206075.cif")],
        ["Bulk Moduli", metadata_root.joinpath("crystals.zip")],
        ["Metal Nonmetal Classifier", metadata_root.joinpath("metal.csv")],
        ["Bulk Moduli", metadata_root.joinpath("9000046.cif")],
    ]
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=main,
        title="Crystal properties",
        inputs=[
            # NOTE(review): confirm "Instability" is one of `properties`;
            # otherwise the default selection is not a valid dropdown choice.
            gr.Dropdown(properties, label="Property", value="Instability"),
            gr.File(
                file_types=[".cif", ".csv", ".zip"],
                label="Input file for crystal model",
            ),
        ],
        outputs=gr.DataFrame(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)