import torch
from torch.utils.data import DataLoader
import json
import os

from nets.model import Model
from Actor.actor import Actor
from dataloader import VRP_Dataset

# --- Load the saved configuration ---
with open('/data/params_saved.json', 'r') as f:
    params = json.load(f)

# --- Set the device ---
device = params['device']
dataset_path = params['dataset_path']
input_size = None  # determined after the dataset is loaded

# --- Locate the trained model ---
model_path = "/data/model_state_dict.pt"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model not found at {model_path}")

# --- Build a small random dataset for inference ---
inference_dataset = VRP_Dataset(
    size=1,
    num_nodes=params['num_nodes'],
    num_depots=params['num_depots'],
    path=dataset_path,
    device=device
)
input_size = inference_dataset.model_input_length()

# --- Load the model ---
model = Model(
    input_size=input_size,
    embedding_size=params["embedding_size"],
    decoder_input_size=params["decoder_input_size"]
)
model.load_state_dict(torch.load(model_path, map_location=device))

# --- Initialize the Actor and the nearest-neighbor baseline actor ---
actor = Actor(
    model=model,
    num_movers=params['num_movers'],
    num_neighbors_encoder=params['num_neighbors_encoder'],
    num_neighbors_action=params['num_neighbors_action'],
    device=device,
    normalize=False
)
actor.eval_mode()

nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
nn_actor.nearest_neighbors()

# --- Run inference on a single batch ---
dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)

for batch in dataloader:
    with torch.no_grad():
        actor.greedy_search()
        actor_output = actor(batch)
        total_time = actor_output['total_time'].item()

        nn_output = nn_actor(batch)
        nn_time = nn_output['total_time'].item()

    print("\n===== INFERENCE RESULT =====")
    print(f"Actor Model Total Cost: {total_time:.4f}")
    print(f"Nearest Neighbor Cost : {nn_time:.4f}")
    print(f"Improvement over NN   : {(nn_time - total_time) / nn_time * 100:.2f}%")