vrp-shanghai-transformer / inference.py
a-ragab-h-m's picture
Create inference.py
17b50e6 verified
raw
history blame
2.3 kB
import torch
from torch.utils.data import DataLoader
import json
import os
from nets.model import Model
from Actor.actor import Actor
from dataloader import VRP_Dataset
# --- Load run configuration ---
# The model below is rebuilt from these hyperparameters before the checkpoint
# is loaded into it, so they must match the saved weights.
# NOTE(review): params_saved.json is presumably written by the training run.
with open('/data/params_saved.json', 'r', encoding='utf-8') as f:
    params = json.load(f)

# --- Select device and data location ---
device = params['device']
dataset_path = params['dataset_path']

# --- Locate the trained checkpoint ---
# Fail fast with a clear message before any dataset/model construction work.
model_path = "/data/model_state_dict.pt"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model not found at {model_path}")
# --- Build the dataset used for this inference run (a single instance) ---
# NOTE(review): the original comment described this as random test data, but a
# `path` is also passed; whether VRP_Dataset samples or reads from disk is not
# visible here — confirm.
inference_dataset = VRP_Dataset(
    size=1,
    num_nodes=params["num_nodes"],
    num_depots=params["num_depots"],
    path=dataset_path,
    device=device,
)

# The model's input width comes from the dataset instead of being hard-coded,
# so the network is sized to whatever VRP_Dataset reports.
input_size = inference_dataset.model_input_length()
# --- Rebuild the architecture and load the trained weights ---
model = Model(
    input_size=input_size,
    embedding_size=params["embedding_size"],
    decoder_input_size=params["decoder_input_size"],
)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. This file is a plain state_dict, so passing weights_only=True
# (torch >= 1.13) would be safer once the project's torch version is confirmed.
model.load_state_dict(torch.load(model_path, map_location=device))
# load_state_dict copies values into the model's existing parameters, so
# map_location alone does not move the model. Move it explicitly so it lives on
# the same device as the dataset tensors (created with device=device above).
# This is a no-op if Actor also moves the model.
model.to(device)
# Switch off training-only behaviour (e.g. dropout) for inference.
model.eval()
# --- Set up the two policies being compared ---

# Learned policy: wraps the trained model, configured from the saved params.
actor = Actor(
    model=model,
    num_movers=params["num_movers"],
    num_neighbors_encoder=params["num_neighbors_encoder"],
    num_neighbors_action=params["num_neighbors_action"],
    device=device,
    normalize=False,
)
actor.eval_mode()

# Baseline policy: no learned model.
# NOTE(review): nearest_neighbors() presumably switches this Actor to a
# nearest-neighbour heuristic; its exact behaviour is defined in Actor — confirm.
nn_actor = Actor(
    model=None,
    num_movers=1,
    num_neighbors_action=1,
    device=device,
)
nn_actor.nearest_neighbors()
# --- Run inference and compare against the nearest-neighbour baseline ---
dataloader = DataLoader(
    inference_dataset, batch_size=1, collate_fn=inference_dataset.collate
)

# Costs are accumulated over every batch, rather than keeping only the last
# batch's values. For the current single-instance dataset the result is
# identical (0.0 + x == x).
total_time = 0.0
nn_time = 0.0
num_batches = 0
with torch.no_grad():
    for batch in dataloader:
        # Re-select greedy decoding before every call, as before; whether Actor
        # resets its decoding mode between calls is not visible here.
        actor.greedy_search()
        actor_output = actor(batch)
        total_time += actor_output['total_time'].item()
        nn_output = nn_actor(batch)
        nn_time += nn_output['total_time'].item()
        num_batches += 1

# Without this guard an empty dataset surfaced later as a confusing NameError
# (or, with the accumulators above, as a silent 0-cost report).
if num_batches == 0:
    raise RuntimeError("Inference dataset produced no batches")

print("\n===== INFERENCE RESULT =====")
print(f"Actor Model Total Cost: {total_time:.4f}")
print(f"Nearest Neighbor Cost : {nn_time:.4f}")
if nn_time != 0:
    print(f"Improvement over NN : {(nn_time - total_time) / nn_time * 100:.2f}%")
else:
    # A zero baseline cost would otherwise raise ZeroDivisionError.
    print("Improvement over NN : n/a (baseline cost is zero)")