# Igor Santana
# rnn model sent from github to huggingface
# 9c58361
import re
import os
import yaml
import pickle
import argparse
import pandas as pd
import numpy as np
import multiprocessing as mp
import project.evaluation.run as r
from os.path import exists
from datetime import datetime
from project.data.preprocess import preprocess, remove_sessions
from project.models.embeddings import embeddings
from project.evaluation.run import cross_validation
if __name__ == '__main__':
    # Script entry point: read the YAML configuration named on the command
    # line, pre-process the configured dataset, make sure the embedding ids
    # file exists on disk, and then run cross-validation.

    # Lower the verbosity of TensorFlow's native logging.
    # NOTE(review): this runs after the project imports above; if those
    # already imported TensorFlow the setting may come too late -- confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    parser = argparse.ArgumentParser(description='RNN Embeddings')
    # The configuration file is mandatory: without it there is nothing to
    # run, so let argparse report a clear usage error instead of failing
    # later inside open(None).
    parser.add_argument('--config', help='Configuration file', type=str,
                        required=True)
    args = parser.parse_args()

    # Use a context manager so the file handle is closed as soon as the
    # configuration has been parsed.
    with open(args.config) as config_file:
        conf = yaml.safe_load(config_file)
    print('The configuration file "%s" was read.' % args.config)

    print('Pre-process started for dataset "%s"' % conf['evaluation']['dataset'])
    preprocess(conf)

    # Read the dataset name only after preprocess(), as the original did.
    ds = conf['evaluation']['dataset']
    df = pd.read_csv('dataset/{}/session_listening_history.csv'.format(ds), sep=',')

    # Only build the embeddings when the ids file is not already on disk.
    # NOTE(review): assumes embeddings() writes emb_path -- confirm.
    emb_path = 'tmp/{}/models/ids.npy'.format(ds)
    if not exists(emb_path):
        embeddings(df, conf)

    ids = np.load(emb_path)
    cross_validation(df, conf, ids)