File size: 7,527 Bytes
44459bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""Shared utils for the predict command."""

import json
from datetime import datetime
from pathlib import Path

import typer
from folding_studio_data_models import BatchPublication, MessageStatus, Publication
from rich.markdown import Markdown

from folding_studio.console import console
from folding_studio.utils.data_model import BatchInputFile, SimpleInputFile


def validate_source_path(path: Path) -> Path:
    """Validate the prediction source path.

    A directory source must be non-empty and contain only simple-prediction
    files; a file source must carry a simple- or batch-prediction suffix.

    Args:
        path (Path): Source path.

    Raises:
        typer.BadParameter: If the source is an empty directory.
        typer.BadParameter: If the source is a directory containing unsupported files.
        typer.BadParameter: If the source is an unsupported file.

    Returns:
        Path: The source.
    """
    simple_suffixes = tuple(member.value for member in SimpleInputFile)
    batch_suffixes = tuple(member.value for member in BatchInputFile)

    if path.is_dir():
        entries = list(path.iterdir())
        if not entries:
            raise typer.BadParameter(f"The source directory `{path}` is empty.")

        for entry in entries:
            # Sub-directories are ignored; only plain files are checked.
            if entry.is_file() and entry.suffix not in simple_suffixes:
                raise typer.BadParameter(
                    f"The source directory '{path}' contains unsupported files. "
                    f"Only {simple_suffixes} files are supported."
                )
    elif path.suffix not in simple_suffixes + batch_suffixes:
        raise typer.BadParameter(
            f"The source file '{path}' is not supported. "
            f"Only {simple_suffixes + batch_suffixes} files are supported."
        )
    return path


def validate_model_subset(model_subset: list[int]) -> list[int]:
    """Validate the model_subset argument.

    Args:
        model_subset (list[int]): List of model subset requested.

    Raises:
        typer.BadParameter: If more than 5 model ids are specified.
        typer.BadParameter: If model ids not between 1 and 5 (included).

    Returns:
        list[int]: List of model subset requested.
    """
    # An empty selection is valid and passed through unchanged.
    if not model_subset:
        return model_subset

    if len(model_subset) > 5:
        raise typer.BadParameter(
            f"--model_subset accept 5 model ids at most but `{len(model_subset)}` were specified."
        )

    # Every id must fall inside the supported [1, 5] range.
    if any(not 1 <= model_id <= 5 for model_id in model_subset):
        raise typer.BadParameter(
            "Model subset id out of supported range. --model_subset accepts ids between 1 and 5 (included)."
        )
    return model_subset


def print_instructions_simple(response_json: dict, metadata_file: Path | None) -> None:
    """Print pretty instructions after successful call to predict endpoint.

    Args:
        response_json (dict): Server json response.
        metadata_file (Path | None): File path where job submission metadata are
            written. If None, a timestamped default filename is generated.

    Raises:
        ValueError: If the publication status is not one of the known
            `MessageStatus` values handled here.
    """
    pub = Publication.model_validate(response_json)
    experiment_id = pub.message.experiment_id

    if pub.status == MessageStatus.NOT_PUBLISHED_DONE:
        console.print(
            f"The results of your experiment {experiment_id} were found in the cache."
        )
        console.print("Use the following command to download the prediction results:")
        # Fence is closed so the markdown code block renders correctly
        # (the original left the ``` fence unterminated).
        md = f"```shell\nfolding experiment results {experiment_id}\n```"
        console.print(Markdown(md))
    elif pub.status == MessageStatus.NOT_PUBLISHED_PENDING:
        console.print(
            f"Your experiment [bold]{experiment_id}[/bold] is [bold green]still running.[/bold green]"
        )
        console.print(
            "Use the following command to check on its status at a later time."
        )
        md = f"```shell\nfolding experiment status {experiment_id}\n```"
        console.print(Markdown(md))
    elif pub.status == MessageStatus.PUBLISHED:
        console.print("[bold green]Experiment submitted successfully ![/bold green]")
        console.print(f"The experiment id is [bold]{experiment_id}[/bold]")

        if not metadata_file:
            # Default to a timestamped file; keep the Path type consistent
            # with the annotation instead of assigning a bare str.
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            metadata_file = Path(f"simple_prediction_{timestamp}.json")
        with open(metadata_file, "w") as f:
            json.dump(response_json, f, indent=4)

        console.print(
            f"Prediction job metadata written to [bold]{metadata_file}[/bold]"
        )
        console.print("You can query your experiment status with the command:")
        md = f"```shell\nfolding experiment status {experiment_id}\n```"
        console.print(Markdown(md))
    else:
        raise ValueError(f"Unknown publication status: {pub.status}")


def print_instructions_batch(response_json: dict, metadata_file: Path | None) -> None:
    """Print pretty instructions after successful call to batch predict endpoint.

    Args:
        response_json (dict): Server json response.
        metadata_file (Path | None): File path where job submission metadata are
            written. If None, a timestamped default filename is generated.
    """
    pub = BatchPublication.model_validate(response_json)
    non_cached_exps = [
        non_cached_pub.message.experiment_id for non_cached_pub in pub.publications
    ]
    cached_exps = [
        cached_pub.message.experiment_id for cached_pub in pub.cached_publications
    ]
    # Cached experiments whose results are already complete (downloadable).
    done_exps = [
        cached_pub.message.experiment_id
        for cached_pub in pub.cached_publications
        if cached_pub.status == MessageStatus.NOT_PUBLISHED_DONE
    ]

    if not metadata_file:
        # Default to a timestamped file; keep the Path type consistent
        # with the annotation instead of assigning a bare str.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        metadata_file = Path(f"batch_prediction_{timestamp}.json")
    with open(metadata_file, "w") as f:
        json.dump(response_json, f, indent=4)
    console.print(f"Batch prediction job metadata written to {metadata_file}")
    console.print("This file contains your experiments ids.")

    if pub.cached:
        console.print(
            "The results of your experiments were [bold]all found in the cache.[/bold]"
        )
        console.print("The experiment ids are:")
        console.print(f"{cached_exps}")
        console.print(
            "Use the `folding experiment status id` command to check on their status. For example:"
        )
        # Fences are closed so the markdown code blocks render correctly
        # (the original left every ``` fence unterminated).
        md = f"```shell\nfolding experiment status {cached_exps[0]}\n```"
        console.print(Markdown(md))
    else:
        console.print(
            "[bold green]Batch prediction job submitted successfully ![/bold green]"
        )

        console.print(
            f"The following experiments have been [bold]submitted[/bold] (see [bold]{metadata_file}[/bold] for the full list):"
        )
        console.print(non_cached_exps)
        console.print(
            "For example, you can query an experiment status with the command:"
        )
        md = f"```shell\nfolding experiment status {non_cached_exps[0]}\n```"
        console.print(Markdown(md))

        if done_exps:
            console.print(
                f"The results of the following experiments [bold]were found in the cache[/bold] (see [bold]{metadata_file}[/bold] for the full list):"
            )
            console.print(done_exps)
            console.print(
                "Use the `folding experiment results id` command to download the prediction results. For example:"
            )
            # Bug fix: the example must reference a DONE experiment
            # (the original used cached_exps[0], which may not be done).
            md = f"```shell\nfolding experiment results {done_exps[0]}\n```"
            console.print(Markdown(md))