import os

import numpy as np
import torch
from torch import nn
from sentence_transformers import SentenceTransformer

from regressor import WRegressor

# Name of the SentenceTransformer checkpoint to load; must be set in the environment.
ENCODER = os.getenv("ENCODER")


class NextUsRegressor(nn.Module):
    """Embeds input text with a SentenceTransformer and regresses one value per text."""

    def __init__(self):
        super().__init__()
        self.embedder = SentenceTransformer(ENCODER)
        self.regressor = WRegressor()

    def forward(self, txts):
        # Accept either a single string or a list of strings.
        if isinstance(txts, str):
            txts = [txts]
        # encode() returns a numpy array of shape (n_texts, embedding_dim).
        embedded = self.embedder.encode(txts)
        embedded_tensor = torch.as_tensor(embedded, dtype=torch.float32)
        regressed = self.regressor(embedded_tensor)
        # Return one rounded value per input text (newline-separated),
        # not just the 0-th element.
        vals = regressed.flatten().tolist()
        return "\n".join(str(round(v, 4)) for v in vals)
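

# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal smoke test, assuming the ENCODER environment variable names a valid
# SentenceTransformer checkpoint (e.g. "all-MiniLM-L6-v2" -- an example choice,
# not from the original) and that WRegressor from regressor.py maps embedding
# tensors to one output per row.
if __name__ == "__main__":
    model = NextUsRegressor()
    model.eval()
    with torch.no_grad():
        # Both a single string and a list of strings are accepted.
        print(model("an example sentence"))
        print(model(["first sentence", "second sentence"]))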