File size: 2,495 Bytes
71e47a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import traceback
from importlib import import_module
from pathlib import Path
from typing import Tuple

from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
from modules import shared
from modules.logging_colors import logger


def _get_available_pipeline_modules() -> list:
    """Return the names of the pipeline packages found next to this file.

    A pipeline package is a directory located directly inside the
    ``pipelines`` folder (a sibling of this module) that contains a
    ``pipelines.py`` entry.

    Returns:
        The matching directory names, sorted alphabetically. An empty list
        is returned when the ``pipelines`` folder does not exist.
    """
    pipeline_path = Path(__file__).parent / 'pipelines'
    if not pipeline_path.is_dir():
        # Nothing to discover; let the caller report "no pipeline found"
        # instead of failing here with a raw FileNotFoundError.
        return []

    # iterdir() yields entries in arbitrary order. Sorting keeps the result
    # (and therefore the order in which callers try the modules) stable
    # across filesystems and runs.
    return sorted(
        entry.name
        for entry in pipeline_path.iterdir()
        if entry.is_dir() and (entry / 'pipelines.py').exists()
    )


def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
    """Find and build the multimodal pipeline to use.

    Every pipeline package reported by ``_get_available_pipeline_modules`` is
    imported from ``extensions.multimodal.pipelines.<name>.pipelines``;
    packages that fail to import are logged and skipped. The imported modules
    are then queried in turn and the first non-``None`` pipeline wins:

    * if ``shared.args.multimodal_pipeline`` is set, each module's
      ``get_pipeline(name, params)`` is asked for that pipeline;
    * otherwise each module's ``get_pipeline_from_model_name(model, params)``
      is asked, with the lower-cased ``shared.args.model``.

    Args:
        params: Passed through unchanged to the module factory functions.

    Returns:
        A ``(pipeline, module_name)`` tuple, where ``module_name`` is the
        name of the pipeline package that produced the pipeline.

    Raises:
        RuntimeError: If no module returned a pipeline. The message lists the
            ``available_pipelines`` advertised by the imported modules.
    """
    pipeline_modules = {}
    for name in _get_available_pipeline_modules():
        try:
            pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
        except Exception:
            # A broken pipeline package must not prevent the others from
            # loading. `Exception` (rather than a bare `except`) lets
            # KeyboardInterrupt / SystemExit propagate.
            logger.warning(f'Failed to get multimodal pipelines from {name}')
            logger.warning(traceback.format_exc())

    requested = shared.args.multimodal_pipeline
    if requested is not None:
        factory_name, lookup_key = 'get_pipeline', requested
    else:
        # `or ''` guards against an unset model so that the failure below is
        # the descriptive RuntimeError rather than an AttributeError.
        factory_name, lookup_key = 'get_pipeline_from_model_name', (shared.args.model or '').lower()

    for name, module in pipeline_modules.items():
        factory = getattr(module, factory_name, None)
        if factory is None:
            continue
        pipeline = factory(lookup_key, params)
        if pipeline is not None:
            return (pipeline, name)

    # Nothing matched: collect what the modules advertise for the error text.
    available = []
    for module in pipeline_modules.values():
        available += getattr(module, 'available_pipelines', [])

    if requested is not None:
        log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{requested}", available pipelines are: {available}.'
    else:
        log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'

    message = f'{log} Please specify a correct pipeline, or disable the extension'
    logger.critical(message)
    raise RuntimeError(message)