import os
import torch
from tqdm import tqdm
from PTI.configs import paths_config, hyperparameters, global_config
from PTI.training.coaches.base_coach import BaseCoach
from PTI.utils.log_utils import log_images_from_w


class SingleIDCoach(BaseCoach):
    """PTI coach that tunes the generator around a single image.

    The first (filename, image) pair from the data loader is inverted to a
    pivot latent `w_pivot`, and the generator weights are then fine-tuned so
    that `G(w_pivot)` reconstructs that image (pivotal tuning).
    """

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def train(self):
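        # Set up the output directory tree for this run's PTI results.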
        w_path_dir = f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f"{w_path_dir}/{paths_config.pti_results_keyword}", exist_ok=True)

        use_ball_holder = True
        w_pivot = None

        # Take the first (filename, image) pair from the data loader; the
        # filename comes batched, so the image name is its first element.
        fname, image = next(iter(self.data_loader))
        image_name = fname[0]

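        # Reload a fresh copy of the pretrained generator and reset the optimizer.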
        self.restart_training()

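        # Directory where the pivot latent for this image is cached.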
        embedding_dir = f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
        os.makedirs(embedding_dir, exist_ok=True)

        # Reuse a previously computed pivot if requested; otherwise (or if
        # loading it failed) invert the image to obtain the pivot latent.
        if hyperparameters.use_last_w_pivots:
            w_pivot = self.load_inversions(w_path_dir, image_name)

        if w_pivot is None:
            w_pivot = self.calc_inversions(image, image_name)

        torch.save(w_pivot, f"{embedding_dir}/0.pt")
        w_pivot = w_pivot.to(global_config.device)

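        # Keep a counter for periodic wandb snapshots and move the target
        # image to the training device.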
        log_images_counter = 0
        real_images_batch = image.to(global_config.device)

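        # Pivotal tuning: optimize the generator weights so that G(w_pivot)
        # reconstructs the real image under an L2 + LPIPS loss, with optional
        # locality ("ball holder") regularization.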
        for i in tqdm(range(hyperparameters.max_pti_steps)):
            generated_images = self.forward(w_pivot)
            loss, l2_loss_val, loss_lpips = self.calc_loss(
                generated_images,
                real_images_batch,
                image_name,
                self.G,
                use_ball_holder,
                w_pivot,
            )

            self.optimizer.zero_grad()

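            # Stop early once the perceptual (LPIPS) loss drops below the threshold.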
            if loss_lpips <= hyperparameters.LPIPS_value_threshold:
                break

            loss.backward()
            self.optimizer.step()

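            # Enable locality regularization for the next step only every
            # `locality_regularization_interval` global steps.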
            use_ball_holder = (
                global_config.training_step
                % hyperparameters.locality_regularization_interval
                == 0
            )

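            # Periodically log the current reconstruction of w_pivot to wandb.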
            if (
                self.use_wandb
                and log_images_counter % global_config.image_rec_result_log_snapshot
                == 0
            ):
                log_images_from_w([w_pivot], self.G, [image_name])

            global_config.training_step += 1
            log_images_counter += 1

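        # Save the tuned generator and return it together with the pivot latent.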
        torch.save(
            self.G,
            f"{paths_config.checkpoints_dir}/model_{global_config.run_name}_{image_name}.pt",
        )
        return self.G, w_pivot
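

# Usage sketch -- assumptions: this module lives at
# PTI/training/coaches/single_id_coach.py and the fork keeps PTI's
# ImagesDataset under PTI.utils.ImagesDataset; adjust imports and paths to
# your setup.
#
#     from torch.utils.data import DataLoader
#     from torchvision import transforms
#     from PTI.utils.ImagesDataset import ImagesDataset
#     from PTI.training.coaches.single_id_coach import SingleIDCoach
#
#     dataset = ImagesDataset(
#         paths_config.input_data_path,
#         transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
#         ]),
#     )
#     dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
#     tuned_G, w_pivot = SingleIDCoach(dataloader, use_wandb=False).train()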