Spaces:
Runtime error
Runtime error
import logging | |
from enum import IntEnum | |
from typing import Iterator, Optional, List, Tuple | |
import numpy as np | |
from hbutils.string import plural_word | |
from hbutils.testing import disable_output | |
from imgutils.metrics import ccip_extract_feature, ccip_default_threshold, ccip_clustering, ccip_batch_differences | |
from .base import BaseAction | |
from ..model import ImageItem | |
class CCIPStatus(IntEnum):
    """Phases of the state machine driven by ``CCIPAction.iter``."""

    # No initial source was given; incoming items are buffered until
    # ``min_val_count`` of them are available for a first clustering attempt.
    INIT = 1
    # The first clustering attempt failed; buffering continues and clustering
    # is retried every ``step`` items.
    APPROACH = 2
    # A reference feature set exists; every incoming item is compared to it.
    EVAL = 3
    # An initial source was given; its features are loaded on the first call
    # to ``iter`` before switching to ``EVAL``.
    INIT_WITH_SOURCE = 4
class CCIPAction(BaseAction):
    """
    Streaming filter that only lets through items whose CCIP feature matches a reference set.

    The reference set is either taken from ``init_source`` or bootstrapped from the stream
    itself: items are buffered, clustered with ``ccip_clustering``, and the dominant cluster
    becomes the reference. See :class:`CCIPStatus` for the phases involved.

    :param init_source: Optional iterable of items used as the initial reference set. When it is
        not ``None`` (even if empty) the action starts in ``INIT_WITH_SOURCE``. It is iterated
        again after every :meth:`reset`, so it should be re-iterable.
    :param min_val_count: Number of buffered items required before the first clustering attempt.
    :param step: Interval, in items, between clustering retries while in ``APPROACH``, and
        between re-checks of the still-unreleased buffered items while in ``EVAL``.
    :param ratio_threshold: Minimum share of the clustered (label ``!= -1``) samples that one
        cluster must hold to be chosen as the reference cluster.
    :param min_clu_dump_ratio: Minimum fraction of the chosen cluster's members that must pass
        :meth:`_compare_to_exists` against the cluster itself for the cluster to be accepted.
    :param cmp_threshold: Minimum fraction of reference features a feature has to match.
    :param eps: Forwarded to ``ccip_clustering``.
    :param min_samples: Forwarded to ``ccip_clustering``.
    :param model: CCIP model name passed to every ``imgutils.metrics`` call.
    :param threshold: Largest difference that still counts as a match. ``None`` selects
        ``ccip_default_threshold(model)``.
    """

    def __init__(self, init_source=None, *, min_val_count: int = 15, step: int = 5,
                 ratio_threshold: float = 0.6, min_clu_dump_ratio: float = 0.3, cmp_threshold: float = 0.5,
                 eps: Optional[float] = None, min_samples: Optional[int] = None,
                 model='ccip-caformer-24-randaug-pruned', threshold: Optional[float] = None):
        self.init_source = init_source

        self.min_val_count = min_val_count
        self.step = step
        self.ratio_threshold = ratio_threshold
        self.min_clu_dump_ratio = min_clu_dump_ratio
        self.cmp_threshold = cmp_threshold
        self.eps, self.min_samples = eps, min_samples
        self.model = model
        # Test against None rather than truthiness so an explicit 0.0 is honoured.
        self.threshold = ccip_default_threshold(self.model) if threshold is None else threshold

        # items[i] / item_released[i] / feats[i] describe the same buffered item.
        # ``feats`` may be longer than ``items``: features accepted during EVAL (and
        # features loaded from ``init_source``) are appended to ``feats`` only.
        self.items = []
        self.item_released = []
        self.feats = []
        self.status = self._initial_status()

    def _initial_status(self) -> CCIPStatus:
        """Return the status a fresh (or freshly reset) action starts in."""
        # Shared by __init__ and reset() so both agree on what "has a source" means.
        if self.init_source is not None:
            return CCIPStatus.INIT_WITH_SOURCE
        return CCIPStatus.INIT

    def _extract_feature(self, item: ImageItem):
        """Return the CCIP feature of ``item``, preferring one cached in ``item.meta``."""
        if 'ccip_feature' in item.meta:
            return item.meta['ccip_feature']
        return ccip_extract_feature(item.image, model=self.model)

    def _try_cluster(self) -> bool:
        """
        Try to derive the reference set from the buffered items.

        On success ``items``, ``item_released`` and ``feats`` are narrowed to the members of
        the chosen cluster (all marked as not yet released).

        :return: ``True`` if a dominant, self-consistent cluster was found, otherwise ``False``.
        """
        with disable_output():
            clu_ids = ccip_clustering(self.feats, method='optics', model=self.model,
                                      eps=self.eps, min_samples=self.min_samples)

        # Count members per cluster; label -1 is excluded (presumably the noise label).
        clu_counts = {}
        for id_ in clu_ids:
            if id_ != -1:
                clu_counts[id_] = clu_counts.get(id_, 0) + 1

        # First cluster (in order of first appearance) holding a large enough share.
        clu_total = sum(clu_counts.values())
        min_count = clu_total * self.ratio_threshold
        chosen_id = next((id_ for id_, count in clu_counts.items() if count >= min_count), None)
        if chosen_id is None:
            return False

        in_cluster = [id_ == chosen_id for id_ in clu_ids]
        feats = [feat for feat, kept in zip(self.feats, in_cluster) if kept]

        # Fraction of cluster members that are themselves accepted by the cluster.
        clu_dump_ratio = np.array([
            self._compare_to_exists(feat, base_set=feats)
            for feat in feats
        ]).astype(float).mean()
        if clu_dump_ratio >= self.min_clu_dump_ratio:
            self.items = [item for item, kept in zip(self.items, in_cluster) if kept]
            self.item_released = [False] * len(self.items)
            self.feats = feats
            return True
        return False

    def _compare_to_exists(self, feat, base_set=None) -> bool:
        """
        Check whether ``feat`` matches enough features of a reference set.

        :param feat: Feature to test.
        :param base_set: Reference features; ``None`` means ``self.feats``.
        :return: ``True`` if the fraction of references whose difference to ``feat`` is at most
            ``self.threshold`` reaches ``self.cmp_threshold``. ``False`` for an empty reference set.
        """
        references = self.feats if base_set is None else base_set
        if len(references) == 0:
            # Nothing to compare with; avoids a NaN mean over an empty array.
            return False

        # Row 0 of the pairwise matrix, without the self-comparison in column 0.
        diffs = ccip_batch_differences([feat, *references], model=self.model)[0, 1:]
        matches = diffs <= self.threshold
        return bool(matches.astype(float).mean() >= self.cmp_threshold)

    def _dump_items(self) -> Iterator[ImageItem]:
        """Yield every buffered item that has not been released yet and now matches ``self.feats``."""
        # zip stops at len(self.items); any extra entries of self.feats have no buffered item.
        for i, (item, feat) in enumerate(zip(self.items, self.feats)):
            if not self.item_released[i] and self._compare_to_exists(feat):
                self.item_released[i] = True
                yield item

    def _eval_iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Evaluate one incoming item against the reference set.

        An accepted item is yielded and its feature is added to the reference set. Whenever the
        number of features beyond the buffered items is a multiple of ``step``, the unreleased
        buffered items are re-checked as well.
        """
        feat = self._extract_feature(item)
        if not self._compare_to_exists(feat):
            return

        self.feats.append(feat)
        yield item

        if (len(self.feats) - len(self.items)) % self.step == 0:
            yield from self._dump_items()

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Feed one item into the action and yield whatever items become available.

        :param item: Incoming item.
        :return: Iterator of released items; may be empty, and may contain items other than ``item``.
        :raises ValueError: If ``self.status`` is not a known :class:`CCIPStatus`.
        """
        if self.status == CCIPStatus.INIT_WITH_SOURCE:
            cnt = 0
            logging.info('Existing anchor detected.')
            for item_ in self.init_source:
                self.feats.append(self._extract_feature(item_))
                yield item_
                cnt += 1
            logging.info(f'{plural_word(cnt, "items")} loaded from anchor.')

            self.status = CCIPStatus.EVAL
            yield from self._eval_iter(item)

        elif self.status == CCIPStatus.INIT:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            if len(self.items) >= self.min_val_count:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()
                else:
                    self.status = CCIPStatus.APPROACH

        elif self.status == CCIPStatus.APPROACH:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            # Retry clustering only every ``step`` items after the first failed attempt.
            if (len(self.items) - self.min_val_count) % self.step == 0:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()

        elif self.status == CCIPStatus.EVAL:
            yield from self._eval_iter(item)

        else:
            raise ValueError(f'Unknown status for {self.__class__.__name__} - {self.status!r}.')

    def reset(self):
        """Drop all buffered items and features and return to the initial status."""
        self.items.clear()
        self.item_released.clear()
        self.feats.clear()
        # Same rule as __init__: a provided source counts even when it is empty or falsy.
        self.status = self._initial_status()