Spaces:
Sleeping
Sleeping
File size: 4,294 Bytes
c703bc8 19e399c c703bc8 19e399c 63d9b78 c703bc8 19e399c c703bc8 63d9b78 c703bc8 19e399c 8a6f0bc 19e399c 8a6f0bc 19e399c 63d9b78 19e399c 33d620a 19e399c 63d9b78 19e399c 63d9b78 19e399c 63d9b78 19e399c c703bc8 19e399c c703bc8 19e399c c703bc8 19e399c c703bc8 19e399c 33dfe7b 19e399c c703bc8 19e399c c703bc8 19e399c c703bc8 19e399c c703bc8 19e399c c703bc8 19e399c c703bc8 19e399c c703bc8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import logging
import os
import pathlib
import shutil
import tempfile
from pathlib import Path
from collections import defaultdict
import gradio as gr
import pandas as pd
from gt4sd.properties.crystals import CRYSTALS_PROPERTY_PREDICTOR_FACTORY
# Module-level logger. A NullHandler is attached so that, unless the application
# configures logging handlers, records from this module are not emitted by
# logging's last-resort handler.
logger = logging.getLogger(name=__name__)
logger.addHandler(logging.NullHandler())

# Upload suffixes accepted per property key. Properties missing from this
# mapping are treated by `main` as accepting `.cif` and `.zip` uploads.
suffix_dict = dict(metal_nonmetal_classifier=[".csv"])
def create_temp_file(path: str) -> str:
    """Copy ``path`` into the ``gt4sd_crystal`` scratch folder and return the copy.

    The scratch folder is ``<tempfile.gettempdir()>/gt4sd_crystal``. Every entry
    already in it (files and sub-directories) is deleted before the copy is made,
    so after this call it holds only the new copy. The folder is shared, not
    unique per call: concurrent calls will clobber each other's files.

    Args:
        path: Path of the file to copy. Only its final component is kept as the
            name of the copy.

    Returns:
        Path of the copy inside the scratch folder.
    """
    scratch_dir = os.path.join(tempfile.gettempdir(), "gt4sd_crystal")
    os.makedirs(scratch_dir, exist_ok=True)

    # Remove leftovers from the previous call. An extracted `.zip` may have left
    # sub-directories here, which os.remove() cannot delete. Symlinks are
    # unlinked rather than followed.
    for entry in os.listdir(scratch_dir):
        print("Removing", entry)
        entry_path = os.path.join(scratch_dir, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

    # os.path.basename uses the platform's path separators, unlike a hard-coded
    # split on "/".
    temp_path = os.path.join(scratch_dir, os.path.basename(path))
    shutil.copy2(path, temp_path)
    return temp_path
# Validate the staged upload's suffix for a property and return the model input.
def _resolve_input_path(prop_name: str, file_path: Path) -> Path:
    """Return the path to hand to the predictor for the staged upload ``file_path``.

    ``.csv`` uploads are passed as the file itself. ``.cif`` and ``.zip`` uploads
    are passed as the folder containing them; a ``.zip`` is first extracted into
    that folder.

    Raises:
        TypeError: if the suffix is not ``.cif``, ``.csv`` or ``.zip``.
        ValueError: if the suffix is not accepted for ``prop_name``, or a
            ``.zip`` yields no ``.cif`` files at the folder's top level.
    """
    folder = file_path.parent
    suffix = file_path.suffix
    # Properties without an entry in suffix_dict accept `.cif` / `.zip` uploads.
    allowed = suffix_dict.get(prop_name, [".cif", ".zip"])

    if suffix not in (".cif", ".csv", ".zip"):
        raise TypeError(
            "You have to pass a `.csv` (for `metal_nonmetal_classifier`),"
            " a `.cif` (for all other properties) or a `.zip` with multiple"
            f" `.cif` files. Not `{suffix or file_path.name}`."
        )
    if suffix not in allowed:
        raise ValueError(f"For this property, provide {allowed}, not `{suffix}`.")

    if suffix == ".csv":
        return file_path
    if suffix == ".zip":
        shutil.unpack_archive(file_path, folder)
        # Only the top level of the folder is inspected.
        if not any(name.endswith(".cif") for name in os.listdir(folder)):
            raise ValueError("No `.cif` files were found inside the `.zip`.")
    return folder


def main(property: str, data_file: object) -> pd.DataFrame:
    """Predict a crystal property for an uploaded file (Gradio callback).

    Args:
        property: Display name of the property, e.g. ``"Formation Energy"``. It
            is mapped to a factory key by replacing spaces with underscores and
            lower-casing (``"formation_energy"``).
        data_file: The upload: either an object exposing its file path as
            ``.name`` (such as a Gradio temp-file wrapper) or the path itself.
            ``None`` when nothing was uploaded.

    Returns:
        A ``pd.DataFrame`` built from the predictor's output.

    Raises:
        TypeError: if no file was uploaded, or its suffix is not ``.cif``,
            ``.csv`` or ``.zip``.
        ValueError: if the suffix is not accepted for the chosen property, or a
            ``.zip`` contains no ``.cif`` files.
    """
    # Guard first: reading `.name` on None would raise AttributeError and hide
    # this message.
    if data_file is None:
        raise TypeError("You have to pass an input file for the crystal model.")
    if isinstance(data_file, (str, os.PathLike)):
        source_path = os.fspath(data_file)
    else:
        source_path = data_file.name
    print(data_file, source_path)

    prop_name = property.replace(" ", "_").lower()

    # Stage the upload in the shared scratch folder. create_temp_file empties it
    # first, so folder-based inputs contain only this request's files.
    file_path = Path(create_temp_file(source_path))
    input_path = _resolve_input_path(prop_name, file_path)

    algo, config = CRYSTALS_PROPERTY_PREDICTOR_FACTORY[prop_name]
    # Only the algorithm version is configured; no other hyperparameters are set.
    model = algo(config(algorithm_version="v0"))
    result = model(input=input_path)
    return pd.DataFrame(result)
if __name__ == "__main__":
    # Dropdown choices: the factory keys in reverse order, converted to display
    # names (e.g. "metal_nonmetal_classifier" -> "Metal Nonmetal Classifier").
    # `main` maps them back with .replace(" ", "_").lower().
    properties = [
        key.replace("_", " ").title()
        for key in list(CRYSTALS_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
    ]

    # Example inputs and the article/description markdown live in `model_cards`
    # next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"
    # Example labels must match a dropdown choice exactly, so they use the same
    # title-casing as `properties` ("Bulk Moduli", not "Bulk moduli").
    examples = [
        ["Formation Energy", str(metadata_root / "7206075.cif")],
        ["Bulk Moduli", str(metadata_root / "crystals.zip")],
        ["Metal Nonmetal Classifier", str(metadata_root / "metal.csv")],
        ["Bulk Moduli", str(metadata_root / "9000046.cif")],
    ]
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=main,
        title="Crystal properties",
        inputs=[
            # NOTE(review): "Instability" is assumed to be one of `properties`;
            # confirm against CRYSTALS_PROPERTY_PREDICTOR_FACTORY's keys.
            gr.Dropdown(properties, label="Property", value="Instability"),
            gr.File(
                file_types=[".cif", ".csv", ".zip"],
                label="Input file for crystal model",
            ),
        ],
        outputs=gr.DataFrame(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)
|