Spaces:
Sleeping
Sleeping
"""Gen cmat for de/en text.""" | |
# pylint: disable=invalid-name, too-many-branches | |
from typing import List, Optional | |
import more_itertools as mit | |
import numpy as np | |
# from logzero import logger | |
from loguru import logger | |
from tqdm import tqdm | |
# from model_pool import load_model_s | |
# from hf_model_s_cpu import model_s # load_model_s directly | |
from st_mlbee.load_model_s import load_model_s | |
# from st_mlbee.cos_matrix2 import cos_matrix2 | |
from .cos_matrix2 import cos_matrix2 | |
# Load the default sentence-embedding model once, at import time, so that
# every gen_cmat() call that does not pass ``model`` reuses this instance.
# load_model_s() without arguments selects the default model-s
# (mikeee/model_s_512); load_model_s("model_s_512v2") would select
# mikeee/model-s-512v2 instead.
try:
    model_s = load_model_s()
except Exception as exc:
    # Log the failure together with its traceback, then let the import
    # fail: gen_cmat() has no default model to fall back on without it.
    logger.exception(exc)
    raise
def gen_cmat(
    text1: List[str],
    text2: List[str],
    bsize: int = 32,  # default batch_size of model.encode
    model=None,
) -> np.ndarray:
    """Gen corr matrix for two lists of texts.

    Both inputs are encoded batch by batch with ``model.encode`` and the
    collected vectors are passed to ``cos_matrix2``.

    Args:
    ----
        text1: list of strings, typically '''...'''.splitlines(); a bare
            string is accepted and treated as a one-element list
        text2: list of strings, typically '''...'''.splitlines(); a bare
            string is accepted and treated as a one-element list
        bsize: batch size, default 32; values <= 0 fall back to 32
        model: object exposing ``encode(list_of_str)``, used for encoding
            the texts; default is the module-level model-s
            (mikeee/model_s_512)

    Example inputs:
        text1 = 'this is a test'
        text2 = 'another test'

    Returns:
    -------
        numpy array of cmat, i.e. the value of
        ``cos_matrix2(np.array(vec2), np.array(vec1))`` -- note that the
        vectors of ``text2`` come first.

    Raises:
    ------
        Exception: anything raised by ``model.encode`` or ``cos_matrix2``
            is logged and re-raised unchanged.
    """
    if model is None:
        model = model_s

    bsize = int(bsize)
    if bsize <= 0:
        bsize = 32

    # Accept a single string for either argument.
    if isinstance(text1, str):
        text1 = [text1]
    if isinstance(text2, str):
        # Must rebind text2 (not text1): otherwise text1 is overwritten and
        # text2 stays a str that gets chunked character by character.
        text2 = [text2]

    vec1: list = []
    vec2: list = []

    # One progress tick per encoded batch: ceil(len / bsize) for each input.
    tot = sum(
        len(texts) // bsize + bool(len(texts) % bsize)
        for texts in (text1, text2)
    )

    with tqdm(total=tot) as pbar:
        for texts, vecs in ((text1, vec1), (text2, vec2)):
            for chunk in mit.chunked(texts, bsize):
                try:
                    vec = model.encode(chunk)
                except Exception as exc:
                    logger.error(exc)
                    raise
                vecs.extend(vec)
                pbar.update()

    try:
        # note the order vec2, vec1
        cmat = cos_matrix2(np.array(vec2), np.array(vec1))
    except Exception as exc:
        logger.exception(exc)
        raise

    return cmat