import sys
sys.path.append('../XGBoost_Prediction_Model/')

import warnings
warnings.filterwarnings("ignore")
import Predict
import torch
import numpy as np
import os
from os.path import isfile, isdir, join
from Magazine_Optimization import *
import time

# Root directory with one sub-directory per magazine. Each magazine directory holds
# a 'Jpg' / 'Jpg_*' sub-directory of page images and another sub-directory with the
# torch-saved inputs loaded in _load_magazine().
mypath = '../XGBoost_Prediction_Model/Magazine_Optimization_Demo/Magazines'
# Per-magazine results keyed by magazine directory name:
#   {'AG': [strategy, maximized_ad, original_ad],
#    'BG': [strategy, maximized_brand, original_brand, brand_under_AG_assignment],
#    'Time': seconds spent building the preference matrices and solving}.
results = {}


def _load_magazine(magazine_dir):
    """Collect the page-image paths and saved inputs of one magazine directory.

    A sub-directory whose first '_'-separated name token is 'Jpg' supplies the
    (sorted) image paths; any other sub-directory supplies the saved inputs.
    Sub-directories are visited in sorted order, so if several qualify, the last
    one wins deterministically (os.listdir order is arbitrary).

    Returns:
        (image_paths, inputs): ``image_paths`` is a list of full paths (empty if
        no image directory was found); ``inputs`` is a dict of the loaded objects,
        or None if no input directory was found for THIS magazine.
    """
    image_paths = []
    inputs = None
    for sub_name in sorted(os.listdir(magazine_dir)):
        sub_path = join(magazine_dir, sub_name)
        if not isdir(sub_path):
            continue
        if sub_name.split('_')[0] == 'Jpg':
            image_paths = [join(sub_path, name) for name in sorted(os.listdir(sub_path))]
        else:
            # NOTE(review): torch.load unpickles these files; newer torch releases
            # default to weights_only=True, which may reject non-tensor payloads
            # (e.g. the object converted with .astype below) -- confirm.
            inputs = {
                'Slots': torch.load(join(sub_path, 'Slots')).astype('int32'),
                'Sizes': torch.load(join(sub_path, 'surfaces')),
                'Product_Groups': torch.load(join(sub_path, 'Prod_Cat')),
                'Textboxes': torch.load(join(sub_path, 'Textboxes')),
                'Obj_and_Topics': torch.load(join(sub_path, 'Obj_and_Topics')),
                'Ad_embeddings': torch.load(join(sub_path, 'Ad_Emb')),
                'Ctpg_embeddings': torch.load(join(sub_path, 'Ctpg_Emb')),
            }
    return image_paths, inputs


def _is_selected(var):
    """Return True if the binary decision variable *var* is 1 in the solution.

    Solver values are floats that may deviate slightly from 1, so compare against
    0.5 rather than testing exact equality; a variable with no value (None) counts
    as not selected.
    """
    return var.varValue is not None and var.varValue > 0.5


def _decode_assignment(var_name):
    """Split an assignment variable name into (prefix, ad_index, page_index).

    The name is split on '_'; fields 1 and 2 are the 1-based ad / counterpage
    numbers (workers and jobs are labelled 1..N below), returned here 0-based.
    """
    parts = var_name.split('_')
    return parts[0], int(parts[1]) - 1, int(parts[2]) - 1


def _report_assignment(problem, assign_ids):
    """Print every selected assignment of a solved *problem*.

    Returns:
        (strategy, pairs): ``strategy`` is the printed lines joined as
        'line; line; ...; ', and ``pairs`` lists the selected 0-based
        (ad_index, page_index) tuples in solver variable order.
    """
    strategy = ''
    pairs = []
    for var in problem.variables():
        if not _is_selected(var):
            continue
        prefix, ad_idx, page_idx = _decode_assignment(var.name)
        line = prefix+' Ad '+str(assign_ids[ad_idx])+' to Counterpage '+str(assign_ids[page_idx])
        strategy += line+'; '
        pairs.append((ad_idx, page_idx))
        print(line)
    return strategy, pairs


for mag_name in sorted(os.listdir(mypath)):
    mag_dir = join(mypath, mag_name)
    if not isdir(mag_dir):
        continue
    print('Currently processing Magazine '+mag_name+'......')
    dir_list, inputs = _load_magazine(mag_dir)
    if inputs is None or not dir_list:
        # Without this magazine's own images and inputs there is nothing
        # consistent to optimize; never fall back to another magazine's data.
        print('Skipping Magazine '+mag_name+': missing page images or input files.')
        print()
        continue

    start = time.time()
    # Uniform costs; np.arange(len(dir_list), 0, -1) is an alternative scheme.
    Costs = np.ones(len(dir_list))
    Ad_Gaze, Brand_Gaze, Double_Page_Ad_Attention, Double_Page_Brand_Attention, Assign_ids = Preference_Matrix(
        dir_list, inputs['Slots'], inputs['Product_Groups'], inputs['Sizes'],
        Ad_embeddings=inputs['Ad_embeddings'], Ctpg_embeddings=inputs['Ctpg_embeddings'],
        Textboxes=inputs['Textboxes'], Obj_and_Topics=inputs['Obj_and_Topics'],
        Costs=Costs, Method='XGBoost')

    # Assignment problem: workers (ads) and jobs (counterpages) are labelled 1..N.
    N = Ad_Gaze.shape[0]
    workers = list(range(1, N + 1))
    jobs = list(range(1, N + 1))

    # The solver receives costs (max - gaze), so the optimal total gaze is
    # recovered below as N*max - objective.
    max_ad_attention = np.max(Ad_Gaze)
    max_brand_attention = np.max(Brand_Gaze)
    Ad_Gaze_cost = max_ad_attention - Ad_Gaze
    Brand_Gaze_cost = max_brand_attention - Brand_Gaze

    Prob_solved_Ad = Assignment_Problem(Ad_Gaze_cost, workers, jobs)
    Prob_solved_Brand = Assignment_Problem(Brand_Gaze_cost, workers, jobs)
    assigning_time = time.time() - start

    # Diagonal sums: each ad on the counterpage with the same index
    # (presumably the magazine's original layout).
    original_ad = np.trace(Ad_Gaze)
    original_brand = np.trace(Brand_Gaze)

    print('If based on maximizing Overall Ad Attention: ')
    strategy_AG, ad_pairs = _report_assignment(Prob_solved_Ad, Assign_ids)
    # Brand-gaze cost incurred if the ad-optimal assignment is used.
    BG_under_AG_cost = sum(Brand_Gaze_cost[a, p] for a, p in ad_pairs)

    m_ad = N*max_ad_attention - value(Prob_solved_Ad.objective) + sum(Double_Page_Ad_Attention)
    print("Maximized Ad Attention = ", m_ad, " sec.")
    print("Maximized Average Ad attention on each Ad = ", m_ad/(N + len(Double_Page_Ad_Attention)), " sec.")
    print("Original Ad Attention = ", original_ad, " sec.")
    print()

    print('If based on maximizing Overall Brand Attention: ')
    strategy_BG, _ = _report_assignment(Prob_solved_Brand, Assign_ids)

    m_brand = N*max_brand_attention - value(Prob_solved_Brand.objective) + sum(Double_Page_Brand_Attention)
    BG_under_AG_assignment = N*max_brand_attention - BG_under_AG_cost + sum(Double_Page_Brand_Attention)
    print("Maximized Brand Attention = ", m_brand, " sec.")
    print("New Brand Gaze under AG assignment = ", BG_under_AG_assignment, " sec.")
    print("Maximized Average Brand attention on each Ad = ", m_brand/(N + len(Double_Page_Brand_Attention)), " sec.")
    print("Original Brand Attention = ", original_brand, " sec.")
    print('End of Magazine '+mag_name+'......')

    results[mag_name] = {'AG': [strategy_AG, m_ad, original_ad],
                         'BG': [strategy_BG, m_brand, original_brand, BG_under_AG_assignment],
                         'Time': assigning_time}
    print()
    print()

print()
# Persist all per-magazine results before printing the human-readable summary.
torch.save(results, '../XGBoost_Prediction_Model/Magazine_Optimization_Demo/results')
# Alternative output file (presumably for runs using a CNN predictor instead of XGBoost):
# torch.save(results, '../XGBoost_Prediction_Model/Magazine_Optimization_Demo/results_CNN')
print('Summary: ')
# Each entry: {'AG': [strategy, maximized, original],
#              'BG': [strategy, maximized, original, under_AG_assignment],
#              'Time': seconds}.
for mag_name, entry in results.items():
    ad_strategy, ad_best, ad_original = entry['AG']
    bg_strategy, bg_best, bg_original, bg_under_ag = entry['BG']

    print('Magazine ' + mag_name + ': ')
    print('Total Time used: ', entry['Time'])

    print('Ad Gaze: ')
    print('Strategy: ' + ad_strategy)
    print('max Attention: ', ad_best)
    print('original Attention: ', ad_original)
    print('Improvement: ', (ad_best - ad_original) / ad_original * 100)
    print('------------------------')

    # The brand "Improvement" measures the brand attention obtained under the
    # ad-optimal (AG) assignment against the original, not the BG maximum.
    print('Brand Gaze: ')
    print('Strategy: ' + bg_strategy)
    print('max Attention: ', bg_best)
    print('original Attention: ', bg_original)
    print('Attention under AG Assignment: ', bg_under_ag)
    print('Improvement: ', (bg_under_ag - bg_original) / bg_original * 100)
    print('------------------------')
    print()