Spaces:
Build error
Build error
File size: 2,154 Bytes
844aef2 d7cdc67 844aef2 d7cdc67 844aef2 090de1f 844aef2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
"""Define uclas."""
# pylint: disable=invalid-name
from typing import List, Tuple, Union
import logzero
import numpy as np
from joblib import Memory
from logzero import logger
# On Windows 10, run `set PYTHONPATH=..\align-model-pool` first
# (presumably so that the model_pool imports below resolve).
from model_pool.fetch_check_aux import fetch_check_aux
from model_pool.load_model import load_model
from model_pool.model_s import load_model_s
from sklearn.metrics.pairwise import cosine_similarity
# logzero.loglevel(20)
# fetch_check_aux("/home/user")
# Best-effort call at import time: any failure is logged and the module
# keeps loading.  (Presumably this checks/fetches auxiliary files the models
# need -- confirm in model_pool.fetch_check_aux.)
try:
    fetch_check_aux()
except Exception as exc:
    logger.error(exc)

# Both models are loaded once, when the module is imported, and are shared
# by the cached wrappers defined below.
model_s = load_model_s()
clas = load_model("clas-l-user")

# joblib cache rooted at ./cachedir (relative to the working directory).
location = "./cachedir"
memory = Memory(location, verbose=0)
@memory.cache
def cached_clas(*args, **kw):
    """Forward the call to ``clas`` (the "clas-l-user" model) via the joblib cache.

    Positional and keyword arguments are passed through untouched; the
    ``memory.cache`` decorator stores the result so that a repeated call
    with the same arguments can be answered from the cache.
    """
    outcome = clas(*args, **kw)
    return outcome
# cached_clas = memory.cache(cached_clas)
@memory.cache
def encode(*args, **kw):
    """Forward the call to ``model_s.encode`` via the joblib cache.

    Positional and keyword arguments are passed through untouched; the
    ``memory.cache`` decorator stores the result so that a repeated call
    with the same arguments can be answered from the cache.
    """
    embeddings = model_s.encode(*args, **kw)
    return embeddings
def uclas(
    seq: str,
    labels: Union[List[str], np.ndarray, Tuple[str, ...]],
    thresh: float = 0.5,
    multi_label: bool = False,
) -> Tuple[str, Union[float, str]]:
    """Pick a label for ``seq``, falling back through three checks.

    The checks run in order and the first one that passes decides the result:

    1. the first score returned by the classifier is above ``thresh``
       (assumes the classifier lists its best label first -- confirm
       against the model's output order);
    2. the largest product of classifier score and cosine similarity
       (between the encoded ``seq`` and each encoded label) is above
       ``thresh``;
    3. the largest cosine similarity alone is above ``thresh / 2``.

    Args:
        seq: text to classify.
        labels: candidate labels; the result is looked up by position,
            so this must be an indexable sequence.
        thresh: acceptance threshold (the third check uses half of it).
        multi_label: forwarded unchanged to the classifier.

    Returns:
        ``(label, score)`` with the score rounded to two decimals, or
        ``("", "")`` when none of the checks passes.
    """
    # result = clas(seq, labels, multi_label=multi_label)
    result = cached_clas(seq, labels, multi_label=multi_label)
    ranked_labels = result.get("labels")
    ranked_scores = result.get("scores")

    # Check 1: classifier alone.
    first_label = ranked_labels[0]
    first_score = round(ranked_scores[0], 2)
    logger.debug("1 %s, %s", first_label, first_score)
    if ranked_scores[0] > thresh:
        return first_label, first_score

    # Re-order the classifier scores so they line up with ``labels``.
    score_by_label = dict(zip(ranked_labels, ranked_scores))
    clas_scores = np.array([score_by_label.get(label) for label in labels])

    # Similarity of ``seq`` to every label; one row, one column per label.
    csim = cosine_similarity(encode([seq]), encode(labels))

    # Check 2: classifier score weighted by similarity.
    weighted = clas_scores * csim
    weighted_idx = weighted.argmax()
    weighted_best = weighted.max()
    logger.debug("2 %s, %s", weighted_idx, round(weighted_best, 2))
    if weighted_best > thresh:
        return labels[weighted_idx], round(weighted_best, 2)

    # Check 3: similarity alone, against a halved threshold.
    half_thresh = thresh / 2
    sim_idx = csim.argmax()
    sim_best = csim.max()
    logger.debug("3 %s, %s, %s", sim_idx, round(sim_best, 2), half_thresh)
    accepted = sim_best > half_thresh
    logger.debug("T or F: %s", accepted)
    if accepted:
        return labels[sim_idx], round(sim_best, 2)

    return "", ""
|