LittleApple-fp16's picture
Upload 88 files
4f8ad24
raw
history blame
5.72 kB
import logging
from enum import IntEnum
from typing import Iterator, Optional, List, Tuple
import numpy as np
from hbutils.string import plural_word
from hbutils.testing import disable_output
from imgutils.metrics import ccip_extract_feature, ccip_default_threshold, ccip_clustering, ccip_batch_differences
from .base import BaseAction
from ..model import ImageItem
class CCIPStatus(IntEnum):
    """
    Processing phase of a :class:`CCIPAction`.

    The numeric values are part of the public enum and must stay stable.
    """

    # Buffering incoming items; no clustering has been attempted yet.
    INIT = 1
    # The first clustering attempt failed; keep buffering and retry periodically.
    APPROACH = 2
    # A reference feature set exists; each incoming item is compared against it.
    EVAL = 3
    # An initial source was supplied and still has to be loaded as the reference set.
    INIT_WITH_SOURCE = 4
class CCIPAction(BaseAction):
    """
    Streaming filter that keeps only items whose CCIP feature agrees with a reference set.

    The reference set is obtained in one of two ways:

    * from ``init_source`` when it is given (status ``INIT_WITH_SOURCE``), or
    * by buffering incoming items, clustering their features and keeping the
      dominant cluster (statuses ``INIT`` and ``APPROACH``).

    Once a reference set exists (status ``EVAL``), every incoming item is compared
    against it; accepted items are yielded and their features join the reference set.

    :param init_source: Optional iterable of items used as the initial reference set.
        It is iterated every time the action runs in ``INIT_WITH_SOURCE`` status
        (i.e. also after :meth:`reset`), so a one-shot iterator only works once.
    :param min_val_count: Number of buffered items required before the first clustering attempt.
    :param step: Interval, in items, between clustering retries in ``APPROACH`` status and
        between re-checks of still-buffered items in ``EVAL`` status. Must be non-zero.
    :param ratio_threshold: Minimum share of all clustered (label != -1) items that a
        cluster must hold to be chosen.
    :param min_clu_dump_ratio: Minimum share of the chosen cluster's members that must
        pass :meth:`_compare_to_exists` against the cluster itself for it to be accepted.
    :param cmp_threshold: Minimum share of reference features a feature has to match.
    :param eps: Forwarded to ``ccip_clustering``.
    :param min_samples: Forwarded to ``ccip_clustering``.
    :param model: CCIP model name forwarded to the ``imgutils.metrics`` functions.
    :param threshold: Maximum feature difference that counts as a match. ``None`` selects
        ``ccip_default_threshold(model)``.
    """

    def __init__(self, init_source=None, *, min_val_count: int = 15, step: int = 5,
                 ratio_threshold: float = 0.6, min_clu_dump_ratio: float = 0.3, cmp_threshold: float = 0.5,
                 eps: Optional[float] = None, min_samples: Optional[int] = None,
                 model='ccip-caformer-24-randaug-pruned', threshold: Optional[float] = None):
        self.init_source = init_source

        self.min_val_count = min_val_count
        self.step = step
        self.ratio_threshold = ratio_threshold
        self.min_clu_dump_ratio = min_clu_dump_ratio
        self.cmp_threshold = cmp_threshold
        self.eps, self.min_samples = eps, min_samples
        self.model = model
        # Only ``None`` means "use the model default"; an explicit 0.0 is respected.
        self.threshold = threshold if threshold is not None else ccip_default_threshold(self.model)

        # Buffered items, whether each one was already yielded, and the known features.
        # ``feats[i]`` belongs to ``items[i]`` for every buffered item; features beyond
        # ``len(items)`` come from the initial source or from items accepted in EVAL status.
        self.items = []
        self.item_released = []
        self.feats = []
        self.status = self._initial_status()

    def _initial_status(self) -> CCIPStatus:
        """Return the status a fresh (or freshly reset) action starts in."""
        # ``is not None`` rather than truthiness: an empty or array-like source must be
        # treated the same way on construction and on reset.
        if self.init_source is not None:
            return CCIPStatus.INIT_WITH_SOURCE
        return CCIPStatus.INIT

    def _extract_feature(self, item: ImageItem):
        """Return the CCIP feature of ``item``, preferring one cached in ``item.meta``."""
        if 'ccip_feature' in item.meta:
            return item.meta['ccip_feature']
        return ccip_extract_feature(item.image, model=self.model)

    def _try_cluster(self) -> bool:
        """
        Cluster the buffered features and, on success, keep only the dominant cluster.

        :return: ``True`` when a cluster was chosen and accepted (``items``, ``feats`` and
            ``item_released`` are replaced accordingly), ``False`` otherwise (state untouched).
        """
        with disable_output():
            clu_ids = ccip_clustering(self.feats, method='optics', model=self.model,
                                      eps=self.eps, min_samples=self.min_samples)

        # Count cluster sizes; label -1 is excluded from the statistics.
        clu_counts = {}
        for id_ in clu_ids:
            if id_ != -1:
                clu_counts[id_] = clu_counts.get(id_, 0) + 1
        clu_total = sum(clu_counts.values())

        # First cluster (in order of appearance) that holds a large enough share.
        required = clu_total * self.ratio_threshold
        chosen_id = next((id_ for id_, count in clu_counts.items() if count >= required), None)
        if chosen_id is None:
            return False

        member_feats = [feat for i, feat in enumerate(self.feats) if clu_ids[i] == chosen_id]
        # Share of cluster members that are consistent with the cluster as a whole.
        clu_dump_ratio = np.array([
            self._compare_to_exists(feat, base_set=member_feats)
            for feat in member_feats
        ]).astype(float).mean()
        if clu_dump_ratio < self.min_clu_dump_ratio:
            return False

        self.items = [item for i, item in enumerate(self.items) if clu_ids[i] == chosen_id]
        self.item_released = [False] * len(self.items)
        self.feats = member_feats
        return True

    def _compare_to_exists(self, feat, base_set=None) -> bool:
        """
        Tell whether ``feat`` matches a large enough share of the reference features.

        :param feat: Feature to check.
        :param base_set: Reference features to compare against; ``None`` selects ``self.feats``.
        :return: ``True`` when the share of references whose difference to ``feat`` is at most
            ``self.threshold`` reaches ``self.cmp_threshold``. ``False`` when there are no
            references at all.
        """
        references = self.feats if base_set is None else base_set
        # Row 0 of the pairwise matrix, without the self-comparison in column 0.
        diffs = ccip_batch_differences([feat, *references], model=self.model)[0, 1:]
        if diffs.size == 0:
            # The mean of an empty array is NaN (plus a RuntimeWarning); NaN never reaches
            # the threshold, so the answer is False either way.
            return False

        matches = diffs <= self.threshold
        return bool(matches.astype(float).mean() >= self.cmp_threshold)

    def _dump_items(self) -> Iterator[ImageItem]:
        """Yield every buffered item that was not yielded yet and matches the reference set now."""
        for i, item in enumerate(self.items):
            if self.item_released[i]:
                continue
            if self._compare_to_exists(self.feats[i]):
                self.item_released[i] = True
                yield item

    def _eval_iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Process one item in ``EVAL`` status.

        The item is yielded when it matches the reference set, and its feature is then
        added to that set. After every ``step`` features beyond the buffered ones, the
        still-unreleased buffered items are re-checked, since the grown reference set
        may accept them now.
        """
        feat = self._extract_feature(item)
        if self._compare_to_exists(feat):
            self.feats.append(feat)
            yield item

            if (len(self.feats) - len(self.items)) % self.step == 0:
                yield from self._dump_items()

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Feed one item into the action and yield whatever items it releases in response.

        :param item: Incoming item.
        :return: Iterator over released items; it may be empty, or contain items received earlier.
        :raises ValueError: When ``self.status`` holds an unknown value.
        """
        if self.status == CCIPStatus.INIT_WITH_SOURCE:
            # Load the anchors: they seed the reference set and are passed through unfiltered.
            cnt = 0
            logging.info('Existing anchor detected.')
            for item_ in self.init_source:
                self.feats.append(self._extract_feature(item_))
                yield item_
                cnt += 1
            logging.info(f'{plural_word(cnt, "items")} loaded from anchor.')

            self.status = CCIPStatus.EVAL
            yield from self._eval_iter(item)

        elif self.status == CCIPStatus.INIT:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            if len(self.items) >= self.min_val_count:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()
                else:
                    self.status = CCIPStatus.APPROACH

        elif self.status == CCIPStatus.APPROACH:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            # Retry clustering once every ``step`` newly buffered items.
            if (len(self.items) - self.min_val_count) % self.step == 0:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()

        elif self.status == CCIPStatus.EVAL:
            yield from self._eval_iter(item)

        else:
            raise ValueError(f'Unknown status for {self.__class__.__name__} - {self.status!r}.')

    def reset(self):
        """Drop all buffered items and features and return to the initial status."""
        self.items.clear()
        self.item_released.clear()
        self.feats.clear()
        self.status = self._initial_status()