File size: 2,864 Bytes
a1b524b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
diff --git a/mapper/latent_mappers.py b/mapper/latent_mappers.py
index 56b9c55..f0dd005 100644
--- a/mapper/latent_mappers.py
+++ b/mapper/latent_mappers.py
@@ -19,7 +19,7 @@ class ModulationModule(Module):
 
     def forward(self, x, embedding, cut_flag):
         x = self.fc(x)
-        x = self.norm(x) 	
+        x = self.norm(x)
         if cut_flag == 1:
             return x
         gamma = self.gamma_function(embedding.float())
@@ -39,20 +39,20 @@ class SubHairMapper(Module):
     def forward(self, x, embedding, cut_flag=0):
         x = self.pixelnorm(x)
         for modulation_module in self.modulation_module_list:
-        	x = modulation_module(x, embedding, cut_flag)        
+        	x = modulation_module(x, embedding, cut_flag)
         return x
 
-class HairMapper(Module): 
+class HairMapper(Module):
     def __init__(self, opts):
         super(HairMapper, self).__init__()
         self.opts = opts
-        self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=opts.device)
         self.transform = transforms.Compose([transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
         self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
         self.hairstyle_cut_flag = 0
         self.color_cut_flag = 0
 
-        if not opts.no_coarse_mapper: 
+        if not opts.no_coarse_mapper:
             self.course_mapping = SubHairMapper(opts, 4)
         if not opts.no_medium_mapper:
             self.medium_mapping = SubHairMapper(opts, 4)
@@ -70,13 +70,13 @@ class HairMapper(Module):
         elif hairstyle_tensor.shape[1] != 1:
             hairstyle_embedding = self.gen_image_embedding(hairstyle_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
         else:
-            hairstyle_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+            hairstyle_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
         if color_text_inputs.shape[1] != 1:
             color_embedding = self.clip_model.encode_text(color_text_inputs).unsqueeze(1).repeat(1, 18, 1).detach()
         elif color_tensor.shape[1] != 1:
             color_embedding = self.gen_image_embedding(color_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
         else:
-            color_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+            color_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
 
 
         if (hairstyle_text_inputs.shape[1] == 1) and (hairstyle_tensor.shape[1] == 1):
@@ -106,4 +106,4 @@ class HairMapper(Module):
             x_fine = torch.zeros_like(x_fine)
 
         out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
-        return out
\ No newline at end of file
+        return out