behnamsa committed on
Commit
fcb30fd
1 Parent(s): 0d4e2c5

Add pipeline

Files changed (1)
  1. pipeline.py +64 -0
pipeline.py ADDED
# from scipy.special import softmax
import tensorflow as tf


class PreTrainedPipeline:
    def __init__(self):
        # Symmetric encoder-decoder ("hourglass") stack over 300-d sentence embeddings.
        sequence_input = tf.keras.Input(shape=(300,), name="input")
        # Note: the string "LeakyReLU" is not a registered Keras activation,
        # so the callable tf.nn.leaky_relu is used instead.
        x = tf.keras.layers.Dense(2048, activation=tf.nn.leaky_relu)(sequence_input)
        x = tf.keras.layers.Dense(1024, activation=tf.nn.leaky_relu)(x)
        x = tf.keras.layers.Dense(512, activation=tf.nn.leaky_relu)(x)
        x = tf.keras.layers.Dense(128, activation=tf.nn.leaky_relu)(x)
        x = tf.keras.layers.Dense(512, activation=tf.nn.leaky_relu)(x)
        x = tf.keras.layers.Dense(1024, activation=tf.nn.leaky_relu)(x)
        x = tf.keras.layers.Dense(2048, activation=tf.nn.leaky_relu)(x)
        outputs = tf.keras.layers.Dense(300, activation="tanh")(x)

        # Keep the model on the instance so callers can reach it after construction.
        self.model = tf.keras.Model(sequence_input, outputs)
        self.model.compile(optimizer="Adamax", loss="cosine_similarity")

    def __call__(self):
        return {
            "text": "Hi!!!"
        }

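# A minimal usage sketch, assuming the pipeline is driven through self.model
# (the names `pipe` and `vec` below are illustrative, not part of this file):
#
#   import numpy as np
#   pipe = PreTrainedPipeline()
#   vec = np.random.rand(1, 300).astype("float32")
#   reconstructed = pipe.model.predict(vec)  # -> array of shape (1, 300)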
# def RevDict(sent, flag, model):
#     """
#     Receives a sentence from the user and returns the top-10 (flag=0) or
#     top-100 (flag=1) predictions. The input sentence is normalized and
#     stop words are removed before prediction.
#     """
#     normalizer = Normalizer()
#     X_Normalized = normalizer.normalize(sent)
#     X_Tokens = word_tokenize(X_Normalized)
#     stopwords = [normalizer.normalize(x.strip())
#                  for x in codecs.open(r"stopwords.txt", 'r', 'utf-8').readlines()]
#     X_Tokens = [t for t in X_Tokens if t not in stopwords]
#     preprocessed = ' '.join(X_Tokens)
#     sent_ids = sent2id([preprocessed])
#     output = np.array(model.predict(sent_ids.reshape((1, 20))).tolist()[0])
#     distances = distance.cdist(output.reshape((1, 300)), comparison_matrix, "cosine")[0]
#     min_index_100 = distances.argsort()[:100]
#     min_index_10 = distances.argsort()[:10]
#
#     temp = []
#     if flag == 0:
#         for i in range(10):
#             temp.append(id2h[str(min_index_10[i])])
#     elif flag == 1:
#         for i in range(100):
#             temp.append(id2h[str(min_index_100[i])])
#
#     for i in range(len(temp)):
#         print(temp[i])

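# Usage sketch (hypothetical: assumes `Normalizer`, `word_tokenize`, `codecs`,
# `np`, `distance` (scipy.spatial.distance), a trained `model`, the
# `comparison_matrix` of target-word embeddings, and the `id2h`
# index-to-headword lookup have all been loaded first):
#
#   RevDict("a sentence describing the target word", flag=0, model=model)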
# def sent2id(sents):
#     # Map each sentence to a fixed-length sequence of 20 vocabulary ids:
#     # zero-padded, truncated after the 20th token, with out-of-vocabulary
#     # tokens mapped to the 'UNK' id.
#     sents_id = np.zeros((len(sents), 20))
#     for j in tqdm(range(len(sents))):
#         for i, word in enumerate(sents[j].split()):
#             try:
#                 sents_id[j, i] = t2id[word]
#             except KeyError:
#                 sents_id[j, i] = t2id['UNK']
#             if i == 19:
#                 break
#     return sents_id
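# Usage sketch (hypothetical: assumes `t2id` is a token-to-id dict containing
# an 'UNK' entry, and that numpy and tqdm are imported):
#
#   ids = sent2id(["this is a sample sentence"])  # -> ndarray of shape (1, 20)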