File size: 4,294 Bytes
c703bc8
19e399c
c703bc8
19e399c
 
 
63d9b78
c703bc8
 
 
19e399c
c703bc8
 
 
 
63d9b78
c703bc8
19e399c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a6f0bc
19e399c
8a6f0bc
19e399c
 
 
 
63d9b78
 
19e399c
33d620a
 
 
 
 
19e399c
 
63d9b78
 
 
 
19e399c
 
63d9b78
 
 
 
19e399c
 
63d9b78
 
 
 
19e399c
 
 
 
 
c703bc8
19e399c
 
 
 
c703bc8
19e399c
 
 
 
 
 
 
 
c703bc8
 
 
 
 
19e399c
 
c703bc8
 
 
 
19e399c
33dfe7b
 
 
 
19e399c
c703bc8
19e399c
c703bc8
19e399c
c703bc8
 
 
19e399c
 
c703bc8
19e399c
 
 
 
c703bc8
 
19e399c
c703bc8
 
19e399c
c703bc8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import logging
import os
import pathlib
import shutil
import tempfile
from pathlib import Path
from collections import defaultdict

import gradio as gr
import pandas as pd
from gt4sd.properties.crystals import CRYSTALS_PROPERTY_PREDICTOR_FACTORY

# Module-level logger; the NullHandler keeps the module silent unless the
# embedding application configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Maps a property key (lower-case, underscore-separated) to the list of file
# suffixes accepted for it. Properties without an entry fall back to
# [".cif", ".zip"] where this mapping is consulted in `main`.
suffix_dict = {"metal_nonmetal_classifier": [".csv"]}


def create_temp_file(path: str) -> str:
    """Copy ``path`` into an emptied scratch folder and return the copy's path.

    The scratch folder is ``<system temp dir>/gt4sd_crystal``. It is shared
    between calls, so every entry left over from a previous call (files as
    well as sub-directories) is deleted before the new file is copied in.

    Args:
        path: Path of the file to copy.

    Returns:
        Path of the copy inside the scratch folder; it keeps the base name
        of ``path``.
    """
    temp_folder = os.path.join(tempfile.gettempdir(), "gt4sd_crystal")
    os.makedirs(temp_folder, exist_ok=True)

    # Clean up directory
    for entry_name in os.listdir(temp_folder):
        print("Removing", entry_name)
        entry_path = os.path.join(temp_folder, entry_name)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            # An unpacked archive may have left sub-directories behind;
            # os.remove() raises on those, so remove the whole tree.
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

    # os.path.basename handles the platform's path separator, unlike
    # splitting on "/" by hand.
    temp_path = os.path.join(temp_folder, os.path.basename(path))
    shutil.copy2(path, temp_path)
    return temp_path


def main(property: str, data_file: object) -> pd.DataFrame:
    """Predict a crystal property for an uploaded file.

    Args:
        property: Human-readable property name (e.g. ``"Formation Energy"``).
            It is lower-cased and its spaces are replaced by underscores to
            obtain the key into ``CRYSTALS_PROPERTY_PREDICTOR_FACTORY``.
        data_file: The uploaded file, given either as an object exposing the
            file's path in a ``name`` attribute or as a plain path string.

    Returns:
        The model output wrapped in a ``pandas.DataFrame``.

    Raises:
        TypeError: If no file is passed, or its suffix is none of ``.cif``,
            ``.csv`` or ``.zip``.
        ValueError: If the suffix is not accepted for the chosen property, or
            if a ``.zip`` has no ``.cif`` file at its top level.
    """
    # Guard first: touching `data_file.name` on a missing upload would raise
    # an unrelated AttributeError instead of this message.
    if data_file is None:
        raise TypeError("You have to pass either an input file for the crystal model")

    # Accept both an object carrying the path in `.name` and a bare path.
    source_path = str(getattr(data_file, "name", data_file))
    print(data_file, source_path)

    prop_name = property.replace(" ", "_").lower()
    allowed_suffixes = suffix_dict.get(prop_name, [".cif", ".zip"])
    suffix = Path(source_path).suffix

    # Validate before copying so that a rejected upload does not wipe the
    # shared scratch folder.
    if suffix not in (".cif", ".csv", ".zip"):
        raise TypeError(
            "You have to pass a `.csv` (for `metal_nonmetal_classifier`),"
            " a `.cif` (for all other properties) or a `.zip` with multiple"
            f" `.cif` files. Not `{Path(source_path).name}`."
        )
    if suffix not in allowed_suffixes:
        raise ValueError(
            f"For this property, provide {allowed_suffixes}, not `{suffix}`."
        )

    # Copy the upload into the scratch folder (emptied by create_temp_file),
    # so that the folder holds nothing but this input.
    file_path = Path(create_temp_file(source_path))
    folder = file_path.parent

    if suffix == ".csv":
        # The CSV file itself is the model input.
        input_path = file_path
    else:
        if suffix == ".zip":
            # Unzip next to the archive; the `.cif` files are expected at
            # the top level of the archive.
            shutil.unpack_archive(file_path, folder)
            if not any(name.endswith(".cif") for name in os.listdir(folder)):
                raise ValueError("No `.cif` files were found inside the `.zip`.")
        # For `.cif` inputs the model is handed the containing folder.
        input_path = folder

    algo, config = CRYSTALS_PROPERTY_PREDICTOR_FACTORY[prop_name]
    # Pass hyperparameters if applicable
    kwargs = {"algorithm_version": "v0"}
    model = algo(config(**kwargs))

    result = model(input=input_path)
    return pd.DataFrame(result)


if __name__ == "__main__":
    # Script entry point: build the Gradio interface around `main` and launch it.

    # Preparation (retrieve all available algorithms): factory keys in
    # reversed order, turned into display labels ("band_gap" -> "Band Gap").
    properties = [
        key.replace("_", " ").title()
        for key in reversed(list(CRYSTALS_PROPERTY_PREDICTOR_FACTORY.keys()))
    ]

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # Example labels are spelled like the dropdown choices above (title case)
    # so that each example selects an existing entry.
    examples = [
        ["Formation Energy", str(metadata_root.joinpath("7206075.cif"))],
        ["Bulk Moduli", str(metadata_root.joinpath("crystals.zip"))],
        ["Metal Nonmetal Classifier", str(metadata_root.joinpath("metal.csv"))],
        ["Bulk Moduli", str(metadata_root.joinpath("9000046.cif"))],
    ]

    # Read the markdown with an explicit encoding instead of relying on the
    # platform default.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=main,
        title="Crystal properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Instability"),
            gr.File(
                file_types=[".cif", ".csv", ".zip"],
                label="Input file for crystal model",
            ),
        ],
        outputs=gr.DataFrame(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)