# Author: narugo
# Commit: dfcc607 ("dev(narugo): better metrics")
# Original file size: 2.84 kB
import json
from functools import lru_cache
import numpy as np
import pandas as pd
from PIL import Image
from autofaiss import build_index
from hfutils.operate import get_hf_fs
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.metrics import ccip_batch_extract_features, ccip_batch_differences, ccip_default_threshold
# Hugging Face dataset repository read by this module: the index
# embeddings and tag metadata, plus per-tag reference images and features.
SRC_REPO: str = "deepghs/character_index"

# Filesystem handle (from hfutils' get_hf_fs) used below to read text files
# from the dataset repository; created once, at import time.
hf_fs = get_hf_fs()
@lru_cache()
def _make_index():
    """Load the character index data from ``SRC_REPO`` and build a search index.

    The result is memoised by ``lru_cache``, so the downloads and the index
    build only happen on the first successful call.

    :return: ``((index, index_infos), tag_infos)`` where ``(index, index_infos)``
        is the pair returned by ``autofaiss.build_index`` and ``tag_infos`` is a
        NumPy array built from the parsed ``index/tag_infos.json``.
    """
    # Tag metadata; presumably aligned by position with the rows of
    # embeddings.npy (query results index into it) -- verify against the dataset.
    raw_tag_infos = json.loads(hf_fs.read_text(f'datasets/{SRC_REPO}/index/tag_infos.json'))
    tag_infos = np.array(raw_tag_infos)

    embeddings_file = hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename='index/embeddings.npy',
    )
    embeddings = np.load(embeddings_file)

    # save_on_disk=False: presumably keeps the built index in memory only
    # (see the autofaiss build_index documentation).
    index, index_infos = build_index(embeddings, save_on_disk=False)
    return (index, index_infos), tag_infos
def gender_predict(p, margin: float = 0.1):
    """Classify a character's gender from its ``'boy'`` / ``'girl'`` scores.

    :param p: Mapping with numeric ``'boy'`` and ``'girl'`` entries.
    :param margin: Minimum lead one score must have over the other for a
        confident prediction. Defaults to ``0.1``, the previously hard-coded value.
    :return: ``'male'`` if ``p['boy']`` exceeds ``p['girl']`` by at least
        ``margin``, ``'female'`` if ``p['girl']`` exceeds ``p['boy']`` by at least
        ``margin``, otherwise ``'not_sure'``.
    """
    boy, girl = p['boy'], p['girl']
    if boy - girl >= margin:
        return 'male'
    if girl - boy >= margin:
        return 'female'
    return 'not_sure'
def query_character(image: Image.Image, count: int = 5, order_by: str = 'same_ratio', threshold: float = 0.7):
    """Find the indexed characters most similar to the character in ``image``.

    The image is embedded with CCIP, the ``count`` nearest entries are fetched
    from the index built by ``_make_index``, and each candidate is re-scored
    against its reference features downloaded from ``SRC_REPO``.

    Each result record has the columns ``id``, ``tag``, ``gender``,
    ``copyright``, ``index_score`` (score returned by the index search, treated
    as higher-is-better), ``mean_diff`` (mean CCIP difference to the candidate's
    reference features; lower is more similar) and ``same_ratio`` (fraction of
    those differences below ``ccip_default_threshold()``).

    :param image: Query image.
    :param count: Number of nearest index entries to retrieve.
    :param order_by: Record column used to rank and filter the results.
        ``'mean_diff'`` is ranked ascending (lower is better); every other
        column is ranked descending. Ties are broken by ``index_score``
        (descending).
    :param threshold: Cut-off on the ``order_by`` column. Rows with
        ``mean_diff <= threshold`` are kept when ranking by ``'mean_diff'``;
        otherwise rows with a value ``>= threshold`` are kept.
    :return: ``(ret_images, df_records)``: a list of ``(image, caption)`` pairs
        in ranked order, and the ranked, filtered DataFrame of records.
    """
    columns = ['id', 'tag', 'gender', 'copyright', 'index_score', 'mean_diff', 'same_ratio']
    (index, _), tag_infos = _make_index()

    query = ccip_batch_extract_features([image])
    assert query.shape == (1, 768)
    # Normalise the single query row to unit length before searching.
    query = query / np.linalg.norm(query)
    all_dists, all_indices = index.search(query, k=count)
    dists, indices = all_dists[0], all_indices[0]

    # Loop-invariant: the "same character" cut-off for CCIP differences.
    same_threshold = ccip_default_threshold()
    images, records = {}, []
    for dist, idx in zip(dists, indices):
        # faiss pads results with -1 when fewer than `count` entries exist;
        # tag_infos[-1] would otherwise silently yield the last entry.
        if idx < 0:
            continue
        info = tag_infos[idx]
        current_image = load_image(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
        ))
        # Reference features for this tag; iterated row-wise below, so
        # presumably shaped (n_refs, 768) -- confirm against the dataset.
        feats = np.load(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{info["hprefix"]}/{info["short_tag"]}/feat.npy'
        ))
        # Row 0 of the pairwise matrix, minus the self-difference: the query
        # against every reference feature.
        diffs = ccip_batch_differences([query[0], *feats])[0, 1:]
        images[info['tag']] = current_image
        records.append({
            'id': info['id'],
            'tag': info['tag'],
            'gender': gender_predict(info['gender']),
            'copyright': info['copyright'],
            'index_score': dist,
            'mean_diff': diffs.mean(),
            'same_ratio': (diffs < same_threshold).mean(),
        })

    # Explicit columns keep sorting/filtering valid even with zero records.
    df_records = pd.DataFrame(records, columns=columns)
    lower_is_better = order_by == 'mean_diff'
    if order_by == 'index_score':
        sort_by, ascending = ['index_score'], [False]
    else:
        sort_by, ascending = [order_by, 'index_score'], [lower_is_better, False]
    df_records = df_records.sort_values(by=sort_by, ascending=ascending)
    if lower_is_better:
        df_records = df_records[df_records[order_by] <= threshold]
    else:
        df_records = df_records[df_records[order_by] >= threshold]

    ret_images = [
        (images[row_item['tag']], f'{row_item["tag"]} ({row_item[order_by]:.3f})')
        for row_item in df_records.to_dict('records')
    ]
    return ret_images, df_records