File size: 2,154 Bytes
844aef2
 
 
 
 
 
d7cdc67
 
844aef2
d7cdc67
 
 
 
 
844aef2
090de1f
 
 
 
 
 
 
844aef2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Define uclas."""
# pylint: disable=invalid-name

from typing import List, Tuple, Union

import logzero
import numpy as np
from joblib import Memory
from logzero import logger
# On Windows 10, run `set PYTHONPATH=..\align-model-pool` so model_pool can be imported.
from model_pool.fetch_check_aux import fetch_check_aux
from model_pool.load_model import load_model
from model_pool.model_s import load_model_s
from sklearn.metrics.pairwise import cosine_similarity

# logzero.loglevel(20)

# fetch_check_aux("/home/user")
# Best-effort call: any failure is logged and the import of this module
# carries on to the model loading below.
try:
    fetch_check_aux()
except Exception as exc:
    logger.error(exc)

# Both models are loaded once, when this module is imported.
model_s = load_model_s()
clas = load_model("clas-l-user")

# joblib cache stored under the relative directory ./cachedir.
location = "./cachedir"
memory = Memory(location, verbose=0)


@memory.cache
def cached_clas(*args, **kw):
    """Call the module-level ``clas`` model through the joblib cache.

    Every positional and keyword argument is forwarded unchanged to
    ``clas`` and its result is returned as-is.
    """
    outcome = clas(*args, **kw)
    return outcome


# cached_clas = memory.cache(cached_clas)


@memory.cache
def encode(*args, **kw):
    """Call ``model_s.encode`` through the joblib cache.

    Every positional and keyword argument is forwarded unchanged to
    ``model_s.encode`` and its result is returned as-is.
    """
    encoded = model_s.encode(*args, **kw)
    return encoded


def uclas(
    seq: str,
    labels: Union[List[str], np.ndarray, Tuple[str, ...]],
    thresh: float = 0.5,
    multi_label: bool = False,
) -> Tuple[str, Union[float, str]]:
    """Pick a label for seq using three successive checks.

    The checks run in order and the first one that passes decides the result:

    1. the first score returned by ``cached_clas`` is above ``thresh``
       (presumably the best-scoring label; confirm the model's output order);
    2. the largest product of a label's clas score and its cosine
       similarity to ``seq`` is above ``thresh``;
    3. the largest cosine similarity alone is above ``thresh / 2``.

    Args:
        seq: text to classify.
        labels: candidate labels.
        thresh: cutoff for checks 1 and 2; check 3 uses half of it.
        multi_label: forwarded to ``cached_clas``.

    Returns:
        ``(label, score)`` with the score rounded to two decimals, or
        ``("", "")`` when none of the checks passes.
    """
    outcome = cached_clas(seq, labels, multi_label=multi_label)

    # Check 1: the first entry of the classifier output.
    top_label = outcome.get("labels")[0]
    raw_top_score = outcome.get("scores")[0]
    top_score = round(raw_top_score, 2)
    logger.debug("1 %s, %s", top_label, top_score)
    if raw_top_score > thresh:
        return top_label, top_score

    # Re-order the classifier scores so they line up with ``labels``.
    score_by_label = dict(zip(outcome.get("labels"), outcome.get("scores")))
    clas_scores = np.array([score_by_label.get(name) for name in labels])

    seq_vec = encode([seq])
    label_vecs = encode(labels)
    csim = cosine_similarity(seq_vec, label_vecs)

    # Check 2: classifier score weighted by cosine similarity.
    # argmax() returns a flat index; this assumes csim has a single row
    # (one input sequence), so the flat index is the position in ``labels``.
    weighted = clas_scores * csim
    weighted_peak = weighted.max()
    logger.debug("2 %s, %s", weighted.argmax(), round(weighted_peak, 2))
    if weighted_peak > thresh:
        return labels[weighted.argmax()], round(weighted_peak, 2)

    # Check 3: cosine similarity alone, against half the threshold.
    csim_peak = csim.max()
    half_thresh = thresh / 2
    logger.debug("3 %s, %s, %s", csim.argmax(), round(csim_peak, 2), half_thresh)

    passes = csim_peak > half_thresh
    logger.debug("T or F: %s", passes)
    if not passes:
        return "", ""
    return labels[csim.argmax()], round(csim_peak, 2)