File size: 3,056 Bytes
4f8ad24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from functools import partial
from typing import Iterator, Union, List, Mapping, Literal

from PIL import Image
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags

from .base import ProcessAction, BaseAction
from ..model import ImageItem


def _deepdanbooru_tagging(image: Image.Image, use_real_name: bool = False,
                          general_threshold: float = 0.5, character_threshold: float = 0.5, **kwargs):
    """
    Tag ``image`` via ``get_deepdanbooru_tags`` and flatten the result into one mapping.

    :param image: Image to tag.
    :param use_real_name: Forwarded positionally to ``get_deepdanbooru_tags``.
    :param general_threshold: Forwarded positionally to ``get_deepdanbooru_tags``.
    :param character_threshold: Forwarded positionally to ``get_deepdanbooru_tags``.
    :param kwargs: Accepted and ignored, so every tagging method can be called
        with the same keyword set.
    :return: A new dict merging the second and third values returned by the
        tagger; on a key collision the third (character) value wins.
    """
    del kwargs  # extra options are deliberately discarded
    tagged = get_deepdanbooru_tags(image, use_real_name, general_threshold, character_threshold)
    # The first returned element is not used by this wrapper.
    _, general_tags, character_tags = tagged
    return {**general_tags, **character_tags}


def _wd14_tagging(image: Image.Image, model_name: str,
                  general_threshold: float = 0.35, character_threshold: float = 0.85, **kwargs):
    """
    Tag ``image`` via ``get_wd14_tags`` and flatten the result into one mapping.

    :param image: Image to tag.
    :param model_name: Forwarded positionally to ``get_wd14_tags``.
    :param general_threshold: Forwarded positionally to ``get_wd14_tags``.
    :param character_threshold: Forwarded positionally to ``get_wd14_tags``.
    :param kwargs: Accepted and ignored, so every tagging method can be called
        with the same keyword set.
    :return: A new dict merging the second and third values returned by the
        tagger; on a key collision the third (character) value wins.
    """
    del kwargs  # extra options are deliberately discarded
    tagged = get_wd14_tags(image, model_name, general_threshold, character_threshold)
    # The first returned element is not used by this wrapper.
    _, general_tags, character_tags = tagged
    return {**general_tags, **character_tags}


def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.7, **kwargs):
    """
    Tag ``image`` via ``get_mldanbooru_tags``.

    :param image: Image to tag.
    :param use_real_name: Forwarded positionally to ``get_mldanbooru_tags``.
    :param general_threshold: Forwarded positionally as the third argument of
        ``get_mldanbooru_tags``.
    :param kwargs: Accepted and ignored, so every tagging method can be called
        with the same keyword set.
    :return: Whatever ``get_mldanbooru_tags`` returns, unchanged.
    """
    del kwargs  # extra options are deliberately discarded
    return get_mldanbooru_tags(image, use_real_name, general_threshold)


# Registry mapping a public method name to the callable that performs the tagging.
# Every callable is invoked as ``fn(image=..., **kwargs)``.
_TAGGING_METHODS = {
    'deepdanbooru': _deepdanbooru_tagging,
    # The wd14 variants share one implementation and differ only in the bound model name.
    **{
        method_name: partial(_wd14_tagging, model_name=model_name)
        for method_name, model_name in (
            ('wd14_vit', 'ViT'),
            ('wd14_convnext', 'ConvNext'),
            ('wd14_convnextv2', 'ConvNextV2'),
            ('wd14_swinv2', 'SwinV2'),
        )
    },
    'mldanbooru': _mldanbooru_tagging,
}

# Static-typing view of the registry keys above; keep both in sync when adding a method.
TaggingMethodTyping = Literal[
    'deepdanbooru',
    'wd14_vit',
    'wd14_convnext',
    'wd14_convnextv2',
    'wd14_swinv2',
    'mldanbooru',
]


class TaggingAction(ProcessAction):
    """
    Action that attaches a ``'tags'`` entry to an item's metadata.

    The tags are produced by one of the callables registered in ``_TAGGING_METHODS``.
    """

    def __init__(self, method: TaggingMethodTyping = 'wd14_convnextv2', force: bool = False, **kwargs):
        """
        :param method: Key into ``_TAGGING_METHODS``; an unknown key raises ``KeyError``.
        :param force: When true, tag the item again even if its metadata already has ``'tags'``.
        :param kwargs: Extra keyword arguments passed to the tagging callable on every call.
        """
        self.method = _TAGGING_METHODS[method]
        self.force = force
        self.kwargs = kwargs

    def process(self, item: ImageItem) -> ImageItem:
        """
        Return ``item`` with tags attached.

        :param item: Item to tag.
        :return: ``item`` itself when it already carries ``'tags'`` and ``force`` is off;
            otherwise a new ``ImageItem`` with the same image and a copy of the metadata
            whose ``'tags'`` entry holds the fresh result.
        """
        already_tagged = 'tags' in item.meta
        if already_tagged and not self.force:
            # Existing tags are kept as-is; the tagger is not run.
            return item

        fresh_tags = self.method(image=item.image, **self.kwargs)
        # Build a new metadata dict so the incoming item's meta is not mutated.
        updated_meta = {**item.meta, 'tags': fresh_tags}
        return ImageItem(item.image, updated_meta)


class TagFilterAction(BaseAction):
    """
    Action that passes an item through only when all required tags score high enough.

    Items are tagged first with an internal ``TaggingAction`` (created with
    ``force=False``, so an existing ``'tags'`` entry in the metadata is reused),
    then each required tag's score is compared against its minimum.
    """

    def __init__(self, tags: Union[List[str], Mapping[str, float]],
                 method: TaggingMethodTyping = 'wd14_convnextv2', **kwargs):
        """
        :param tags: Either a list/tuple of tag names (each gets a minimum score of
            ``1e-6``), or a mapping of tag name to minimum score.
        :param method: Tagging method name handed to the internal ``TaggingAction``.
        :param kwargs: Extra keyword arguments handed to the internal ``TaggingAction``.
        :raises TypeError: If ``tags`` is neither a list/tuple nor a mapping.
        """
        if isinstance(tags, (list, tuple)):
            self.tags = {tag: 1e-6 for tag in tags}
        elif isinstance(tags, Mapping):
            # Any mapping is accepted, matching the annotation; it is copied so that
            # later changes to the caller's object do not affect this filter.
            self.tags = dict(tags)
        else:
            raise TypeError(f'Unknown type of tags - {tags!r}.')
        self.tagger = TaggingAction(method, force=False, **kwargs)

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Yield the tagged item if every required tag reaches its minimum score.

        A required tag missing from the item's tag mapping is treated as having
        score ``0.0`` instead of raising ``KeyError``.

        :param item: Item to check.
        :return: Iterator yielding the tagged item once, or nothing when rejected.
        """
        item = self.tagger(item)
        tags = item.meta['tags']

        # Stops at the first required tag that falls short, like the original loop.
        rejected = any(
            tags.get(tag, 0.0) < min_score
            for tag, min_score in self.tags.items()
        )
        if not rejected:
            yield item

    def reset(self):
        """Reset the internal tagging action."""
        self.tagger.reset()