import datetime
import json
import os
import random
import time

import gradio as gr
import numpy as np
import pandas as pd
import plotly.figure_factory as ff
import torch
from torch.distributions import Categorical
from stable_baselines3.common.vec_env import VecEnvWrapper

from compiled_jss.CPEnv import CompiledJssEnvCP
from MyDummyVecEnv import MyDummyVecEnv
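
# Thin VecEnv wrapper: it only records the torch device and forwards
# reset/step calls to the underlying vectorized environment.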
class VecPyTorch(VecEnvWrapper):
    def __init__(self, venv, device):
        super(VecPyTorch, self).__init__(venv)
        self.device = device

    def reset(self):
        return self.venv.reset()

    def step_async(self, actions):
        self.venv.step_async(actions)

    def step_wait(self):
        return self.venv.step_wait()

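# Factory returning a thunk, as expected by SB3-style vectorized-env
# constructors. Note: the seed argument is accepted but currently unused.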
def make_env(seed, instance):
    def thunk():
        _env = CompiledJssEnvCP(instance)
        return _env
    return thunk

def solve(file, num_workers, seed):
    seed = int(abs(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    with torch.inference_mode():
        device = torch.device('cpu')
        actor = torch.jit.load('actor.pt', map_location=device)
        actor.eval()
        start_time = time.time()
        fn_env = [make_env(0, file.name) for _ in range(num_workers)]
        async_envs = MyDummyVecEnv(fn_env, device)
        envs = VecPyTorch(async_envs, device)
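        # Each worker is an independent copy of the environment; workers sample
        # different trajectories (via the per-worker temperatures below) and the
        # best makespan found across all of them is kept.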
        current_solution_cost = float('inf')
        current_solution = ''
        obs = envs.reset()
        total_episode = 0
        while total_episode < envs.num_envs:
            logits = actor(obs['interval_rep'], obs['attention_interval_mask'], obs['job_resource_mask'],
                           obs['action_mask'], obs['index_interval'], obs['start_end_tokens'])
            # Per-worker temperature vector: with at least 4 workers, spread
            # temperatures over [0.5, 2.0) so low-temperature workers exploit
            # while high-temperature workers explore; otherwise sample at T=1.
            if num_workers >= 4:
                temperature = torch.arange(0.5, 2.0, step=(1.5 / num_workers), device=device)
            else:
                temperature = torch.ones(num_workers, device=device)
            logits = logits / temperature[:, None]
            probs = Categorical(logits=logits).probs
            # Sampling without replacement over the whole action set yields,
            # per worker, a probability-weighted ranking of all actions.
            actions = torch.multinomial(probs, probs.shape[1]).cpu().numpy()
            obs, reward, done, infos = envs.step(actions)
            total_episode += done.sum()
            for env_idx, info in enumerate(infos):
                if 'makespan' in info and int(info['makespan']) < current_solution_cost:
                    current_solution_cost = int(info['makespan'])
                    current_solution = json.loads(info['solution'])
        total_time = time.time() - start_time
        pretty_output = ""
        for job_id in range(len(current_solution)):
            pretty_output += f"Job {job_id}: {current_solution[job_id]}\n"
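        # Re-parse the uploaded instance file. Standard JSS instance format:
        #   line 1: <jobs_count> <machines_count>
        #   one line per job: <machine_id> <processing_time> pairs, one pair
        #   per operation, in processing order.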
        jobs_data = []
        file.seek(0)
        line_str: str = file.readline()
        line_cnt: int = 1
        jobs_count: int = 0
        machines_count: int = 0
        while line_str:
            data = []
            split_data = line_str.split()
            if line_cnt == 1:
                jobs_count, machines_count = int(split_data[0]), int(split_data[1])
            else:
                i = 0
                while i < len(split_data):
                    machine, op_time = int(split_data[i]), int(split_data[i + 1])
                    data.append((machine, op_time))
                    i += 2
                jobs_data.append(data)
            line_str = file.readline()
            line_cnt += 1
        # The solution holds each operation's start time; convert to int.
        current_solution = [[int(x) for x in y] for y in current_solution]
        df = []
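        # Build one row per operation for the Gantt chart. Plotly's
        # create_gantt expects datetimes, so start/finish times (in seconds)
        # are mapped onto the Unix epoch via fromtimestamp.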
        for job_id in range(jobs_count):
            for task_id in range(len(current_solution[job_id])):
                dict_op = dict()
                dict_op["Task"] = "Job {}".format(job_id)
                start_sec = current_solution[job_id][task_id]
                finish_sec = start_sec + jobs_data[job_id][task_id][1]
                dict_op["Start"] = datetime.datetime.fromtimestamp(start_sec)
                dict_op["Finish"] = datetime.datetime.fromtimestamp(finish_sec)
                dict_op["Resource"] = "Machine {}".format(jobs_data[job_id][task_id][0])
                df.append(dict_op)
        fig = None
        colors = [
            tuple([random.random() for _ in range(3)]) for _ in range(machines_count)
        ]
        if len(df) > 0:
            df = pd.DataFrame(df)
            fig = ff.create_gantt(
                df,
                index_col="Resource",
                colors=colors,
                show_colorbar=True,
                group_tasks=True,
            )
            fig.update_yaxes(autorange=True)
        return current_solution_cost, str(total_time) + " seconds", pretty_output, fig

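# Gradio UI: upload an instance, pick the number of sampling workers and a
# seed; outputs the best makespan, elapsed time, per-job start times, and a
# Gantt chart of the schedule.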
title = "Job-Shop Scheduling CP environment with RL dispatching"
description = """A Job-Shop Scheduling Reinforcement Learning based solver using an underlying CP model as an
environment. <br>
For fast inference, check out the cached examples below.<br>
Any Job-Shop Scheduling instance following the standard specification is compatible.
<a href='http://jobshop.jjvh.nl/index.php'>Check out this website for more instances</a>.<br>
Increasing the number of workers will provide better solutions, but increases solving time.
This behavior differs from the paper repository: here agents run sequentially, whereas the paper's code runs
them in parallel (a technical limitation of this platform). <br>
<br>
For large instances, we recommend running the approach locally outside this interface, as the interface adds
significant overhead and the resources available on this platform are limited (1 vCPU, no GPU).<br> """
article = "<p style='text-align: center'>Article Under Review</p>"
# list all non-hidden files in the 'instances' directory
examples = [['instances/' + f, 16, 0] for f in os.listdir('instances') if not f.startswith('.')]

iface = gr.Interface(fn=solve,
                     inputs=[gr.File(label="Instance File"),
                             gr.Slider(8, 32, value=16, label="Number of Workers", step=1),
                             gr.Number(0, label="Random Seed", precision=0)],
                     outputs=[gr.Text(label="Makespan"), gr.Text(label="Elapsed Time"), gr.Text(label="Solution"),
                              gr.Plot(label="Solution's Gantt Chart")],
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     allow_flagging="never")
iface.launch(enable_queue=True)