Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torch.nn.utils import clip_grad_norm_ | |
| from just_time_windows.Actor.actor import Actor | |
def train_batch(
    actor,
    baseline,
    batch,
    optimizer,
    gradient_clipping=True,
    comparison_model=None,
    compute_cost_ratio=True,
):
    """Run a single optimisation step of ``actor`` on ``batch``.

    The actor is switched to training mode and evaluated on the batch. The
    baseline is evaluated on the same batch with gradients disabled, after
    calling its ``greedy_search()``. The loss is the mean of
    ``(actor_cost - baseline_cost).detach() * log_probs``, where the costs
    are the ``'total_time'`` entries of the two outputs and ``log_probs``
    is the actor's ``'log_probs'`` entry.

    Args:
        actor: Callable model exposing ``device``, ``train_mode()``,
            ``train()`` and ``apply_normalization``; calling it on a batch
            must return a mapping with ``'total_time'`` and ``'log_probs'``.
        baseline: Callable model exposing ``greedy_search()``; its output
            must contain ``'total_time'``.
        batch: Input handed unchanged to every model.
        optimizer: Optimizer whose ``param_groups`` hold the parameters
            to update.
        gradient_clipping: When truthy, clip the gradient 2-norm of each
            parameter group to at most 1 before stepping.
        comparison_model: Optional reference model used for the cost ratio.
            When ``None`` a fresh ``Actor(model=None,
            num_neighbors_action=1, ...)`` is built on every call
            (NOTE(review): presumably a learning-free heuristic -- confirm
            against the ``Actor`` class).
        compute_cost_ratio: When falsy, skip the comparison and return
            ``None``.

    Returns:
        The summed actor ``'total_time'`` divided by the summed comparison
        ``'total_time'`` as a Python float, or ``None`` when
        ``compute_cost_ratio`` is falsy.
    """
    device = actor.device

    # Put the actor into training mode before the forward pass.
    actor.train_mode()
    actor.train()

    rollout = actor(batch)
    actor_cost = rollout['total_time']
    log_probs = rollout['log_probs']

    # The baseline only supplies a reference cost; no gradients are needed.
    with torch.no_grad():
        baseline.greedy_search()
        baseline_cost = baseline(batch)['total_time']

    # Detach the cost difference so gradients flow through log_probs only.
    advantage = (actor_cost - baseline_cost).detach()
    loss = (advantage * log_probs).mean()

    optimizer.zero_grad()
    loss.backward()

    if gradient_clipping:
        # Each parameter group is clipped on its own, not jointly.
        for param_group in optimizer.param_groups:
            with_grad = [
                param for param in param_group['params']
                if param.grad is not None
            ]
            if not with_grad:
                continue
            clip_grad_norm_(with_grad, max_norm=1, norm_type=2)

    optimizer.step()

    if not compute_cost_ratio:
        return None

    if comparison_model is None:
        normalize = actor.apply_normalization
        comparison_model = Actor(
            model=None,
            num_neighbors_action=1,
            normalize=normalize,
            device=device,
        )

    with torch.no_grad():
        reference_cost = comparison_model(batch)['total_time']

    reference_total = reference_cost.sum().item()
    actor_total = actor_cost.sum().item()
    return actor_total / reference_total