# persian-reverse-dict / pipeline.py
# from scipy.special import softmax
import json

import numpy as np
import tensorflow as tf
from hazm import Normalizer, word_tokenize
from scipy.spatial import distance
from transformers import Pipeline


class PreTrainedPipeline():
    def __init__(self, path):
        # Locations of the exported SavedModel and its lookup tables.
        self.model_dir = path + "/saved_model"
        self.t2id_path = path + "/t2id.json"
        self.id2h_path = path + "/id2h.json"
        self.stopwords_path = path + "/stopwords.txt"
        self.comparison_matrix_path = path + "/comparison_matrix.npz"
        # token -> id and id -> headword mappings, plus the stopword list.
        self.t2id = json.load(open(self.t2id_path, encoding="utf8"))
        self.id2h = json.load(open(self.id2h_path, encoding="utf8"))
        self.stopwords = set(line.strip() for line in open(self.stopwords_path, encoding="utf8"))
        # Pre-computed embeddings of the dictionary headwords (one row per id).
        self.comparisons = np.load(self.comparison_matrix_path)['arr_0']
        self.model = tf.saved_model.load(self.model_dir)

    def __call__(self, inputs: str):
        # Preprocess the input sentence: normalize, tokenize, drop stopwords.
        sentence = Normalizer().normalize(inputs)
        tokens = word_tokenize(sentence)
        tokens = [t for t in tokens if t not in self.stopwords]
        # Map the first 20 tokens to ids; unknown tokens fall back to 'UNK'.
        input_ids = np.zeros((1, 20), dtype=np.int32)
        for i, token in enumerate(tokens):
            if i >= 20:
                break
            input_ids[0, i] = self.t2id.get(token, self.t2id['UNK'])
        # Call the model on the input ids to get a 300-dimensional embedding.
        embeddings = self.model(tf.constant(input_ids, dtype=tf.int32)).numpy()
        # Postprocess: rank dictionary headwords by cosine distance to the embedding.
        similarities = distance.cdist(embeddings.reshape((1, 300)), self.comparisons, "cosine")[0]
        top_indices = similarities.argsort()[:10]
        top_words = [self.id2h[str(idx)] for idx in top_indices]
        # Return the four closest headwords; the scores are placeholders.
        return [
            [
                {'label': top_words[0], 'score': 0},
                {'label': top_words[1], 'score': 0},
                {'label': top_words[2], 'score': 0},
                {'label': top_words[3], 'score': 0},
            ]
        ]
# return [
# [ # Sample output, call the model here TODO
# {'label': 'POSITIVE', 'score': 0.05},
# {'label': 'NEGATIVE', 'score': 0.03},
# {'label': 'معنی', 'score': 0.92},
# {'label': f'{inputs}', 'score': 0},
# ]
# ]
# def RevDict(sent, flag, model):
#     """
#     This function receives a sentence from the user and returns the top 10 (flag=0) or top 100 (flag=1) predictions.
#     The input sentence is normalized and stop words are removed.
#     """
# normalizer = Normalizer()
# X_Normalized = normalizer.normalize(sent)
# X_Tokens = word_tokenize(X_Normalized)
# stopwords = [normalizer.normalize(x.strip()) for x in codecs.open(r"stopwords.txt",'r','utf-8').readlines()]
# X_Tokens = [t for t in X_Tokens if t not in stopwords]
# preprocessed = [' '.join(X_Tokens)][0]
# sent_ids = sent2id([preprocessed])
# output=np.array((model.predict(sent_ids.reshape((1,20))).tolist()[0]))
# distances=distance.cdist(output.reshape((1,300)), comparison_matrix, "cosine")[0]
# min_index_100 = distances.argsort()[:100]
# min_index_10 = distances.argsort()[:10]
# temp=[]
# if flag == 0:
# for i in range(10):
# temp.append(id2h[str(min_index_10[i])])
# elif flag == 1:
# for i in range(100):
# temp.append(id2h[str(min_index_100[i])])
# for i in range(len(temp)):
# print(temp[i])
# def sent2id(sents):
# sents_id=np.zeros((len(sents),20))
# for j in tqdm(range(len(sents))):
# for i,word in enumerate(sents[j].split()):
# try:
# sents_id[j,i] = t2id[word]
# except:
# sents_id[j,i] = t2id['UNK']
# if i==19:
# break
# return sents_id
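

# Example local usage: this assumes the files referenced in __init__
# (saved_model/, t2id.json, id2h.json, stopwords.txt, comparison_matrix.npz)
# are available under the given path; "." below is only a placeholder, point
# it at the downloaded model repository.
if __name__ == "__main__":
    pipe = PreTrainedPipeline(".")
    # Definition-style query meaning "a place where books are kept".
    results = pipe("جایی که کتاب نگهداری می‌شود")
    for item in results[0]:
        print(item['label'], item['score'])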