Spaces:
Build error
Build error
bugfix: fix cuda: + pytorch latest version, input type and weight type should be the same
Browse files
app.py
CHANGED
@@ -56,21 +56,34 @@ miyazaki_model = Transformer()
|
|
56 |
kon_model = Transformer()
|
57 |
|
58 |
enable_gpu = torch.cuda.is_available()
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
shinkai_model.load_state_dict(
|
62 |
-
torch.load(shinkai_model_hfhub,
|
63 |
)
|
64 |
hosoda_model.load_state_dict(
|
65 |
-
torch.load(hosoda_model_hfhub,
|
66 |
)
|
67 |
miyazaki_model.load_state_dict(
|
68 |
-
torch.load(miyazaki_model_hfhub,
|
69 |
)
|
70 |
kon_model.load_state_dict(
|
71 |
-
torch.load(kon_model_hfhub,
|
72 |
)
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
shinkai_model.eval()
|
75 |
hosoda_model.eval()
|
76 |
miyazaki_model.eval()
|
@@ -118,7 +131,8 @@ def inference(img, style):
|
|
118 |
|
119 |
if enable_gpu:
|
120 |
logger.info(f"CUDA found. Using GPU.")
|
121 |
-
|
|
|
122 |
else:
|
123 |
logger.info(f"CUDA not found. Using CPU.")
|
124 |
input_image = Variable(input_image).float()
|
|
|
56 |
kon_model = Transformer()
|
57 |
|
58 |
enable_gpu = torch.cuda.is_available()
|
59 |
+
|
60 |
+
if enable_gpu:
|
61 |
+
# If you have multiple cards,
|
62 |
+
# you can assign to a specific card, eg: "cuda:0"("cuda") or "cuda:1"
|
63 |
+
# Use the first card by default: "cuda"
|
64 |
+
device = torch.device("cuda")
|
65 |
+
else:
|
66 |
+
device = "cpu"
|
67 |
|
68 |
shinkai_model.load_state_dict(
|
69 |
+
torch.load(shinkai_model_hfhub, device)
|
70 |
)
|
71 |
hosoda_model.load_state_dict(
|
72 |
+
torch.load(hosoda_model_hfhub, device)
|
73 |
)
|
74 |
miyazaki_model.load_state_dict(
|
75 |
+
torch.load(miyazaki_model_hfhub, device)
|
76 |
)
|
77 |
kon_model.load_state_dict(
|
78 |
+
torch.load(kon_model_hfhub, device)
|
79 |
)
|
80 |
|
81 |
+
if enable_gpu:
|
82 |
+
shinkai_model = shinkai_model.to(device)
|
83 |
+
hosoda_model = hosoda_model.to(device)
|
84 |
+
miyazaki_model = miyazaki_model.to(device)
|
85 |
+
kon_model = kon_model.to(device)
|
86 |
+
|
87 |
shinkai_model.eval()
|
88 |
hosoda_model.eval()
|
89 |
miyazaki_model.eval()
|
|
|
131 |
|
132 |
if enable_gpu:
|
133 |
logger.info(f"CUDA found. Using GPU.")
|
134 |
+
# Allows to specify a card for calculation
|
135 |
+
input_image = Variable(input_image).to(device)
|
136 |
else:
|
137 |
logger.info(f"CUDA not found. Using CPU.")
|
138 |
input_image = Variable(input_image).float()
|