File size: 6,907 Bytes
d051ea8
33f121d
d051ea8
 
f146e60
33f121d
e1720ec
33f121d
 
 
 
d051ea8
68eff7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1720ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68eff7b
e1720ec
68eff7b
e1720ec
 
 
 
 
 
 
 
 
68eff7b
 
e1720ec
 
 
 
 
 
 
 
 
 
 
 
 
 
68eff7b
 
 
 
 
 
e1720ec
68eff7b
e1720ec
 
 
 
 
 
 
 
 
 
68eff7b
 
 
e1720ec
 
 
 
68eff7b
e1720ec
 
d051ea8
 
 
 
e1720ec
 
 
 
f146e60
 
33f121d
 
 
 
 
 
 
 
f146e60
e1720ec
 
 
 
 
 
33f121d
 
 
e1720ec
33f121d
e1720ec
 
33f121d
e1720ec
33f121d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f146e60
33f121d
 
 
 
 
 
f146e60
d051ea8
 
33f121d
 
e1720ec
33f121d
 
f146e60
 
 
 
 
 
 
 
 
 
 
 
33f121d
f146e60
33f121d
 
d051ea8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
"""
Command-line interface for SlideDeck AI.
"""
import argparse
import sys
import shutil
from typing import Any

from slidedeckai.core import SlideDeckAI
from slidedeckai.global_config import GlobalConfig


def group_models_by_provider(models: list[str]) -> dict[str, list[str]]:
    """
    Group model names by their provider.

    Each name is stripped of surrounding whitespace *before* it is matched
    against ``GlobalConfig.PROVIDER_REGEX``. This means names recovered from
    help text or error messages, which may carry leading spaces, are still
    recognised. Blank entries and names that do not match the regex are
    skipped.

    Args:
        models (list[str]): List of model names.

    Returns:
        dict[str, list[str]]: Dictionary mapping provider codes (the regex's
        first capture group) to lists of stripped model names. Each list is
        in sorted order.
    """
    provider_models: dict[str, list[str]] = {}
    # Strip first so that both the anchored regex match and the sort see the
    # clean name (previously a leading space made `.match` fail).
    cleaned = sorted(name for name in (m.strip() for m in models) if name)
    for model in cleaned:
        if match := GlobalConfig.PROVIDER_REGEX.match(model):
            provider_models.setdefault(match.group(1), []).append(model)

    return provider_models


def format_models_as_bullets(models: list[str]) -> str:
    """
    Render model names as a provider-grouped bulleted list.

    Each provider section starts with a blank line and a ``provider:``
    heading. Every model under it is written as an indented ``  • name``
    bullet. Providers and the models within each provider both appear in
    sorted order.

    Args:
        models (list[str]): List of model names.

    Returns:
        str: The newline-joined listing, or an empty string if no model was
        grouped under a provider.
    """
    grouped = group_models_by_provider(models)
    out: list[str] = []
    # Provider keys are unique, so sorting the items orders by provider only.
    for provider, names in sorted(grouped.items()):
        out.append(f'\n{provider}:')
        out.extend(f'  • {name}' for name in sorted(names))
    return '\n'.join(out)


class CustomHelpFormatter(argparse.HelpFormatter):
    """
    Help formatter with compact option invocations and grouped model lists.

    Options that take a value are shown once as ``-x, --long METAVAR``. The
    ``--model`` option always shows a fixed ``MODEL`` placeholder. Text
    beginning with ``Model choices:`` or ``choose from`` is re-rendered as a
    provider-grouped bullet list.
    """
    def _format_action_invocation(self, action: Any) -> str:
        # Positionals and flag-only options keep argparse's default rendering.
        takes_value = bool(action.option_strings) and action.nargs != 0
        if not takes_value:
            return super()._format_action_invocation(action)

        flags = ', '.join(action.option_strings)

        # For --model with choices, show a fixed placeholder instead of the
        # choice list argparse would otherwise derive.
        if action.choices and '--model' in action.option_strings:
            return flags + ' MODEL'

        fallback_metavar = self._get_default_metavar_for_optional(action)
        return f'{flags} {self._format_args(action, fallback_metavar)}'

    def _split_lines(self, text: str, width: int) -> list[str]:
        from_error = text.startswith('choose from')
        if not (from_error or text.startswith('Model choices:')):
            return super()._split_lines(text, width)

        # Recover raw model names either from an argparse-style
        # "choose from 'a', 'b'" fragment or from a newline-separated block
        # whose first line is the heading.
        if from_error:
            raw_names = text.replace('choose from', '').split(',')
            models = [name.strip("' ") for name in raw_names]
        else:
            models = text.split('\n')[1:]

        return [
            'Available models:',
            '------------------------',  # Fixed-length separator
            *format_models_as_bullets(models).split('\n'),
        ]


class CustomArgumentParser(argparse.ArgumentParser):
    """
    Custom argument parser that formats error messages better.

    An invalid ``--model`` choice is reported as a provider-grouped bullet
    list instead of argparse's single ``(choose from ...)`` line. Every other
    error, and any invalid-choice message that lacks the expected
    ``(choose from`` part, is delegated to ``argparse.ArgumentParser.error``.
    """
    def error(self, message: str) -> None:
        """
        Custom error handler that formats model choices better.

        Args:
            message (str): The error message produced by argparse.

        Raises:
            SystemExit: Always, with exit status 2.
        """
        marker = '(choose from'
        # rfind: the user's invalid value is echoed *before* the choice list
        # and could itself contain the marker text.
        start = message.rfind(marker)
        # Guard on start != -1: previously a missing marker made find() return
        # -1, and message[-1:] was then parsed as if it were a model list.
        if 'invalid choice' in message and '--model' in message and start != -1:
            # Drop only the single closing paren; rstrip(')') could also eat a
            # trailing ')' belonging to the last model name.
            choices_str = message[start + len(marker):].strip().removesuffix(')')
            models = [m.strip("' ") for m in choices_str.split(',')]

            error_lines = ['Error: Invalid model choice. Available models:']
            error_lines.extend(format_models_as_bullets(models).split('\n'))

            # Diagnostics go to stderr, as the base error() does for usage.
            self.print_help(sys.stderr)
            self.exit(2, '\n' + '\n'.join(error_lines) + '\n')

        super().error(message)


def format_models_list() -> str:
    """
    Build the output for ``--list-models``.

    Returns:
        str: A ``Supported SlideDeck AI models:`` heading followed by every
        key of ``GlobalConfig.VALID_MODELS`` as a provider-grouped bullet
        list.
    """
    all_model_keys = [*GlobalConfig.VALID_MODELS.keys()]
    listing = format_models_as_bullets(all_model_keys)
    return 'Supported SlideDeck AI models:\n' + listing


def format_model_help() -> str:
    """
    Build the model listing embedded in the ``--model`` help text.

    Returns:
        str: Every key of ``GlobalConfig.VALID_MODELS`` as a provider-grouped
        bullet list.
    """
    model_keys = [key for key in GlobalConfig.VALID_MODELS.keys()]
    return format_models_as_bullets(model_keys)


def _build_parser() -> CustomArgumentParser:
    """
    Build the SlideDeck AI argument parser.

    The parser has a top-level ``-l/--list-models`` flag and a ``generate``
    sub-command. Both the parser and the sub-command use
    ``CustomHelpFormatter``.

    Returns:
        CustomArgumentParser: The fully configured parser.
    """
    parser = CustomArgumentParser(
        description='Generate slide decks with SlideDeck AI.',
        formatter_class=CustomHelpFormatter
    )
    subparsers = parser.add_subparsers(dest='command')

    # Top-level flag to list supported models
    parser.add_argument(
        '-l',
        '--list-models',
        action='store_true',
        help='List supported model keys and exit.',
    )

    # 'generate' command
    parser_generate = subparsers.add_parser(
        'generate',
        help='Generate a new slide deck.',
        formatter_class=CustomHelpFormatter
    )

    parser_generate.add_argument(
        '--model',
        required=True,
        choices=GlobalConfig.VALID_MODELS.keys(),
        help=(
            'Model name to use. Must be one of the supported models in the'
            ' `[provider-code]model_name` format.' + format_model_help()
        ),
        metavar='MODEL'
    )
    parser_generate.add_argument(
        '--topic',
        required=True,
        help='The topic of the slide deck.',
    )
    parser_generate.add_argument(
        '--api-key',
        help=(
            'The API key for the LLM provider. Alternatively, set the appropriate API key'
            ' in the environment variable.'
        ),
    )
    parser_generate.add_argument(
        '--template-id',
        type=int,
        default=0,
        help='The index of the PowerPoint template to use.',
    )
    parser_generate.add_argument(
        '--output-path',
        help='The path to save the generated .pptx file.',
    )

    # Note: the 'launch' command has been intentionally disabled.

    return parser


def main():
    """
    The main function for the CLI.

    The behavior depends on the arguments:

    - No arguments: print the help text.
    - ``-l/--list-models``: print the supported models.
    - ``generate``: create a slide deck with ``SlideDeckAI``, optionally move
      it to ``--output-path``, and report where it was saved.
    """
    parser = _build_parser()

    # If no arguments are provided, show help and exit
    if len(sys.argv) == 1:
        parser.print_help()
        return

    args = parser.parse_args()

    # If --list-models flag was provided, print models and exit
    if args.list_models:
        print(format_models_list())
        return

    if args.command != 'generate':
        # Arguments were given but no sub-command (e.g. only "--"). Show the
        # help instead of exiting silently without doing anything.
        parser.print_help()
        return

    slide_generator = SlideDeckAI(
        model=args.model,
        topic=args.topic,
        api_key=args.api_key,
        template_idx=args.template_id,
    )

    pptx_path = slide_generator.generate()

    if args.output_path:
        # shutil.move returns the real destination. When --output-path is an
        # existing directory, the file lands inside it, so report that path
        # rather than the raw argument.
        pptx_path = shutil.move(str(pptx_path), args.output_path)

    print(f'\n🤖 Slide deck saved to: {pptx_path}')


# Run the CLI when this file is executed directly as a script.
if __name__ == '__main__':
    main()