Spaces:
Runtime error
Runtime error
from typing import Dict, Iterator, Literal | |
import numpy as np | |
from imgutils.metrics import lpips_difference, lpips_extract_feature | |
from .base import BaseAction | |
from ..model import ImageItem | |
class FeatureBucket:
    """
    Bounded store of LPIPS features used to detect near-duplicate images.

    Each stored feature is paired with the aspect ratio of its image. A new
    feature is only compared (via ``lpips_difference``) against stored features
    whose ratio is close to its own according to :func:`numpy.isclose`.

    :param threshold: Maximum LPIPS difference (inclusive) at which two
        features are considered duplicates.
    :param capacity: Number of most recent entries kept after a trim. The store
        may grow to ``2 * capacity`` entries before it is trimmed back, so the
        trimming cost is amortized. A capacity of ``0`` (or less) keeps nothing.
    :param rtol: Relative tolerance passed to :func:`numpy.isclose` when
        comparing ratios.
    :param atol: Absolute tolerance passed to :func:`numpy.isclose` when
        comparing ratios.
    """

    def __init__(self, threshold: float = 0.45, capacity: int = 500, rtol=1.e-5, atol=1.e-8):
        self.threshold = threshold
        self.rtol, self.atol = rtol, atol
        # ``features[i]`` and ``ratios[i]`` describe the same image; both
        # containers are always appended to and trimmed together.
        self.features = []
        self.ratios = np.array([], dtype=float)
        self.capacity = capacity

    def check_duplicate(self, feat, ratio: float) -> bool:
        """
        Return ``True`` if a stored feature with a close ratio is within ``threshold`` of ``feat``.

        Only stored entries whose ratio satisfies
        ``np.isclose(stored_ratio, ratio, rtol=self.rtol, atol=self.atol)`` are
        passed to ``lpips_difference``; all others are skipped without a
        comparison.

        :param feat: Feature to test (presumably the output of
            ``lpips_extract_feature`` -- confirm against callers).
        :param ratio: Aspect ratio of the image ``feat`` was extracted from.
        :return: ``True`` on the first match found, otherwise ``False``.
        """
        close = np.isclose(self.ratios, ratio, rtol=self.rtol, atol=self.atol)
        for index in np.where(close)[0]:
            if lpips_difference(self.features[index.item()], feat) <= self.threshold:
                return True
        return False

    def add(self, feat, ratio: float) -> None:
        """
        Store ``feat`` together with its ``ratio``.

        Once the store holds ``2 * capacity`` entries it is cut back to the most
        recent ``capacity`` entries (none when ``capacity <= 0``).

        :param feat: Feature to remember.
        :param ratio: Aspect ratio of the image ``feat`` was extracted from.
        """
        self.features.append(feat)
        self.ratios = np.append(self.ratios, ratio)
        if len(self.features) >= self.capacity * 2:
            # Slice from an explicit start index instead of ``[-capacity:]``:
            # with ``capacity == 0`` that would be ``[-0:] == [0:]``, keeping
            # everything, so the store would never shrink.
            keep = max(self.capacity, 0)
            start = len(self.features) - keep
            self.features = self.features[start:]
            self.ratios = self.ratios[start:]
FilterSimilarModeTyping = Literal['all', 'group']


class FilterSimilarAction(BaseAction):
    """
    Drop images that are perceptually similar (by LPIPS) to an image already kept.

    :param mode: ``'all'`` compares every image against one shared bucket;
        ``'group'`` keeps a separate bucket per ``item.meta.get('group_id')``.
    :param threshold: Maximum LPIPS difference (inclusive) treated as a duplicate.
    :param capacity: Per-bucket capacity, forwarded to ``FeatureBucket``.
    :param rtol: Relative tolerance for aspect-ratio matching, forwarded to
        ``FeatureBucket``.
    :param atol: Absolute tolerance for aspect-ratio matching, forwarded to
        ``FeatureBucket``.
    :raises ValueError: If ``mode`` is not ``'all'`` or ``'group'``.
    """

    def __init__(self, mode: FilterSimilarModeTyping = 'all', threshold: float = 0.45,
                 capacity: int = 500, rtol=5.e-2, atol=2.e-2):
        # Validate eagerly so a misconfigured action fails at construction time
        # instead of on the first item (after an LPIPS feature was extracted).
        if mode not in ('all', 'group'):
            raise ValueError(f'Unknown mode for filter similar action - {mode!r}.')
        self.mode = mode
        self.threshold, self.rtol, self.atol = threshold, rtol, atol
        self.capacity = capacity
        # Keyed by ``item.meta.get('group_id')``; this may be ``None`` when an
        # item carries no group id, in which case such items share one bucket.
        self.buckets: Dict[str, FeatureBucket] = {}
        self.global_bucket = self._new_bucket()

    def _new_bucket(self) -> FeatureBucket:
        """Create an empty bucket configured with this action's settings."""
        return FeatureBucket(self.threshold, self.capacity, self.rtol, self.atol)

    def _get_bin(self, group_id):
        """
        Return the bucket that an item with ``group_id`` should be checked against.

        :raises ValueError: If ``self.mode`` is not a supported mode (e.g. it
            was reassigned after construction).
        """
        if self.mode == 'all':
            return self.global_bucket
        elif self.mode == 'group':
            if group_id not in self.buckets:
                self.buckets[group_id] = self._new_bucket()
            return self.buckets[group_id]
        else:
            raise ValueError(f'Unknown mode for filter similar action - {self.mode!r}.')

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Yield ``item`` unless it duplicates an image already seen in its bucket.

        Kept items are remembered, so later near-duplicates of them are dropped.
        """
        image = item.image  # assumes a PIL-like image exposing ``height``/``width``
        # The aspect ratio pre-filters candidates so LPIPS is only compared
        # against stored images of a similar shape.
        ratio = image.height / image.width
        feat = lpips_extract_feature(image)
        bucket = self._get_bin(item.meta.get('group_id'))
        if not bucket.check_duplicate(feat, ratio):
            bucket.add(feat, ratio)
            yield item

    def reset(self):
        """Forget every stored feature, in both the per-group and the global buckets."""
        self.buckets.clear()
        self.global_bucket = self._new_bucket()