|
import torch |
|
from torch import nn |
|
from sentence_transformers import SentenceTransformer |
|
from regressor import * |
|
import numpy as np |
|
import os |
|
|
|
ENCODER = os.getenv("ENCODER") |
|
|
|
class NextUsRegressor(nn.Module):
    """Text scorer: sentence-embeds input texts, runs a regression head,
    and returns a human-readable tab-separated report string.
    """

    def __init__(self):
        super().__init__()
        # Sentence encoder selected via the ENCODER env var (module constant).
        # NOTE(review): if ENCODER is unset this is None and SentenceTransformer
        # will fail at construction — consider validating it earlier.
        self.embedder = SentenceTransformer(ENCODER)
        # Regression head from the local `regressor` module (wildcard import).
        self.regressor = WRegressor()

    def forward(self, txts):
        """Embed and score one text or a list of texts.

        Args:
            txts: a single string or a list of strings.

        Returns:
            A newline-joined string with one line per input text, each
            formatted as "<score rounded to 4 places>\t<first 20 chars of text>".
        """
        # isinstance (not type(...) ==) so str subclasses are handled too.
        if isinstance(txts, str):
            txts = [txts]
        # encode() accepts a plain list of strings; no np.array wrapper needed.
        embedded = self.embedder.encode(txts)
        embedded_tensor = torch.tensor(embedded, dtype=torch.float32)
        regressed = self.regressor(embedded_tensor)
        # Flatten to a plain Python list of floats, one score per input text.
        vals = regressed.flatten().tolist()
        lines = [f"{round(v, 4)}\t{t[:20]}" for t, v in zip(txts, vals)]
        return "\n".join(lines)
|
|
|
|