import os
import time

import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf

from recommenders.models.sasrec.model import SASREC
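
# SASREC above is the TensorFlow implementation from Microsoft's
# `recommenders` package (https://github.com/microsoft/recommenders);
# the import assumes it is installed, e.g. via `pip install recommenders`.
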
class SASREC_Vessl(SASREC):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def predict_next(self, item_ids):
        """Score every item as a next-item candidate for one input sequence."""
        training = False

        # Right-align the input items into a fixed-length sequence,
        # zero-padding on the left (0 is the padding ID).
        seq = np.zeros([self.seq_max_len], dtype=np.int32)
        idx = self.seq_max_len - 1
        for i in item_ids[::-1]:
            seq[idx] = i
            idx -= 1
            if idx == -1:
                break
        input_seq = np.array([seq])

        # Every item ID (1..item_num) is a candidate for the next item.
        candidate = np.expand_dims(np.arange(1, self.item_num + 1, 1), axis=0)

        # Mask out the padding positions.
        mask = tf.expand_dims(tf.cast(tf.not_equal(input_seq, 0), tf.float32), -1)

        # Embed the sequence, add positional embeddings, and run the
        # self-attention encoder.
        seq_embeddings, positional_embeddings = self.embedding(input_seq)
        seq_embeddings += positional_embeddings
        seq_embeddings *= mask
        seq_attention = seq_embeddings
        seq_attention = self.encoder(seq_attention, training, mask)
        seq_attention = self.layer_normalization(seq_attention)  # (b, s, d)
        seq_emb = tf.reshape(
            seq_attention,
            [tf.shape(input_seq)[0] * self.seq_max_len, self.embedding_dim],
        )  # (b*s, d)

        # Score each candidate as the dot product between sequence and
        # item embeddings.
        candidate_emb = self.item_embedding_layer(candidate)  # (b, item_num, d)
        candidate_emb = tf.transpose(candidate_emb, perm=[0, 2, 1])  # (b, d, item_num)
        test_logits = tf.matmul(seq_emb, candidate_emb)
        test_logits = tf.reshape(
            test_logits,
            [tf.shape(input_seq)[0], self.seq_max_len, self.item_num],
        )

        # Keep only the logits at the last position: the next-item scores.
        test_logits = test_logits[:, -1, :]  # (1, item_num)
        predictions = np.array(test_logits)[0]
        return predictions
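
# Hypothetical usage, assuming trained weights have been restored (see
# load_model below): score next-item candidates for a short history of
# item IDs.
#
#     scores = model.predict_next(np.array([80, 70, 100, 1]))
#     scores.shape  # (item_num,)
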
def elapsed_time(fn, *args):
    """Run fn(*args) and return (wall-clock seconds as a string, fn's output)."""
    start = time.time()
    output = fn(*args)
    end = time.time()
    elapsed = f'{end - start:.2f}'
    return elapsed, output

def load_model():
    model_config = {
        "MAXLEN": 50,
        "NUM_BLOCKS": 2,       # number of transformer blocks
        "HIDDEN_UNITS": 100,   # number of units in the attention calculation
        "NUM_HEADS": 1,        # number of attention heads
        "DROPOUT_RATE": 0.2,   # dropout rate
        "L2_EMB": 0.0,         # L2 regularization coefficient
        "NUM_NEG_TEST": 100,   # number of negative examples per positive example
    }
    model = SASREC_Vessl(
        item_num=12101,  # should be changed according to the dataset
        seq_max_len=model_config.get("MAXLEN"),
        num_blocks=model_config.get("NUM_BLOCKS"),
        embedding_dim=model_config.get("HIDDEN_UNITS"),
        attention_dim=model_config.get("HIDDEN_UNITS"),
        attention_num_heads=model_config.get("NUM_HEADS"),
        dropout_rate=model_config.get("DROPOUT_RATE"),
        conv_dims=[100, 100],
        l2_reg=model_config.get("L2_EMB"),
        num_neg_test=model_config.get("NUM_NEG_TEST"),
    )
    # Restore trained weights if the TensorFlow checkpoint files are present.
    if os.path.isfile('best.index') and os.path.isfile('best.data-00000-of-00001'):
        model.load_weights('best').expect_partial()
    return model
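
# Note: in a long-running Streamlit app it would make sense to cache the
# loaded model across reruns (e.g. with the @st.cache_resource decorator
# available in recent Streamlit releases) so it is not rebuilt on every
# interaction; it is left uncached here to match the original script.
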
def postprocess_data(data):
    # argsort on the negated scores yields item indices in descending
    # score order; keep the top 5.
    predictions = -1 * data
    rec_items = predictions.argsort()[:5]
    dic_result = {
        "Rank": [i for i in range(1, 6)],
        "ItemID": list(rec_items + 1),  # indices are 0-based, item IDs start at 1
        "Similarity Score": -1 * predictions[rec_items],  # undo the negation
    }
    result = pd.DataFrame(dic_result)
    time.sleep(0.5)
    best_item = rec_items[0] + 1
    return result, best_item
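
# For example, with scores np.array([0.1, 0.9, 0.4]), (-scores).argsort()[:2]
# is [1, 2]: the indices of the two highest scores, in descending order.
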
def main():
    st.title("Self-Attentive Sequential Recommendation (SASRec)")
    elapsed, model = elapsed_time(load_model)
    st.write(f"Model loaded in {elapsed} seconds!")
    numbers = st.text_input(
        label="Please enter input items separated by commas (e.g. 80, 70, 100, 1)")
    if numbers:
        integer_numbers = np.array(list(map(int, numbers.split(","))))
        result = model.predict_next(integer_numbers)
        table, best_item = postprocess_data(result)
        st.table(table)
        st.write(f"The best item is {best_item}")

if __name__ == '__main__':
    main()
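
# To try the app locally (assuming this file is saved as app.py and the
# `best.*` checkpoint files sit next to it):
#
#     streamlit run app.py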