mlbee / st_mlbee /gen_cmat.py
ffreemt
Add ruff.toml
b9d6157
raw
history blame
2.49 kB
"""Gen cmat for de/en text."""
# pylint: disable=invalid-name, too-many-branches
from typing import List, Optional
import more_itertools as mit
import numpy as np
# from logzero import logger
from loguru import logger
from tqdm import tqdm
# from model_pool import load_model_s
# from hf_model_s_cpu import model_s # load_model_s directly
from st_mlbee.load_model_s import load_model_s
# from st_mlbee.cos_matrix2 import cos_matrix2
from .cos_matrix2 import cos_matrix2
# Load the default sentence-encoding model once, at import time, so every
# gen_cmat call can reuse it. Per the original note the default is
# model-s mikeee/model_s_512; an alternative is
# load_model_s("model_s_512v2")  (mikeee/model-s-512v2).
try:
    model_s = load_model_s()
except Exception as exc:
    # Log with traceback before propagating: a failure here makes the
    # module unimportable, so the full context is worth keeping.
    logger.exception(exc)
    raise
def _n_batches(n_items: int, bsize: int) -> int:
    """Return how many chunks of size ``bsize`` are needed to cover ``n_items``."""
    return -(-n_items // bsize)  # ceil division without floats


def _encode_batches(texts: List[str], model, bsize: int, pbar) -> list:
    """Encode ``texts`` with ``model.encode`` in chunks of ``bsize``.

    Ticks ``pbar`` once per chunk and returns the concatenated per-text
    vectors as a flat list (one entry per item of ``texts``, assuming
    ``model.encode`` returns one vector per input string -- confirm).
    """
    vecs = []
    for chunk in mit.chunked(texts, bsize):
        try:
            vecs.extend(model.encode(chunk))
        except Exception as exc:
            logger.exception(exc)
            raise
        pbar.update()
    return vecs


def gen_cmat(
    text1: List[str],
    text2: List[str],
    bsize: int = 32,  # default batch_size of model.encode
    model=None,
) -> np.ndarray:
    """Generate the cosine-correlation matrix between two lists of texts.

    Args:
    ----
        text1: list of strings, typically ``'''...'''.splitlines()``;
            a single bare string is also accepted and treated as ``[text1]``.
        text2: list of strings, same conventions as ``text1``.
        bsize: batch size passed chunk-wise to ``model.encode``; default 32.
            Non-positive values fall back to 32.
        model: object exposing ``encode(list_of_str)``; defaults to the
            module-level ``model_s`` (presumably model-s mikeee/model_s_512).

    Example:
    -------
        gen_cmat('this is a test', 'another test')

    Returns:
    -------
        numpy array produced by ``cos_matrix2(vec2, vec1)``; presumably of
        shape (len(text2), len(text1)) -- confirm against cos_matrix2.

    Raises:
    ------
        Any exception from ``model.encode`` or ``cos_matrix2`` is logged
        (with traceback) and re-raised unchanged.
    """
    if model is None:
        model = model_s

    bsize = int(bsize)
    if bsize <= 0:
        bsize = 32

    # A bare string would otherwise be chunked character by character.
    if isinstance(text1, str):
        text1 = [text1]
    if isinstance(text2, str):
        text2 = [text2]

    total = _n_batches(len(text1), bsize) + _n_batches(len(text2), bsize)
    with tqdm(total=total) as pbar:
        vec1 = _encode_batches(text1, model, bsize, pbar)
        vec2 = _encode_batches(text2, model, bsize, pbar)

    try:
        # note the order vec2, vec1
        cmat = cos_matrix2(np.array(vec2), np.array(vec1))
    except Exception as exc:
        logger.exception(exc)
        raise
    return cmat