sho commited on
Commit
d687de3
2 Parent(s): 76b1411 b4c2583

Merge pull request #1 from soulteary/main

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -56,21 +56,34 @@ miyazaki_model = Transformer()
56
  kon_model = Transformer()
57
 
58
  enable_gpu = torch.cuda.is_available()
59
- map_location = torch.device("cuda") if enable_gpu else "cpu"
 
 
 
 
 
 
 
60
 
61
  shinkai_model.load_state_dict(
62
- torch.load(shinkai_model_hfhub, map_location=map_location)
63
  )
64
  hosoda_model.load_state_dict(
65
- torch.load(hosoda_model_hfhub, map_location=map_location)
66
  )
67
  miyazaki_model.load_state_dict(
68
- torch.load(miyazaki_model_hfhub, map_location=map_location)
69
  )
70
  kon_model.load_state_dict(
71
- torch.load(kon_model_hfhub, map_location=map_location)
72
  )
73
 
 
 
 
 
 
 
74
  shinkai_model.eval()
75
  hosoda_model.eval()
76
  miyazaki_model.eval()
@@ -118,7 +131,8 @@ def inference(img, style):
118
 
119
  if enable_gpu:
120
  logger.info(f"CUDA found. Using GPU.")
121
- input_image = Variable(input_image).cuda()
 
122
  else:
123
  logger.info(f"CUDA not found. Using CPU.")
124
  input_image = Variable(input_image).float()
56
  kon_model = Transformer()
57
 
58
  enable_gpu = torch.cuda.is_available()
59
+
60
+ if enable_gpu:
61
+ # If you have multiple cards,
62
+ # you can assign to a specific card, eg: "cuda:0"("cuda") or "cuda:1"
63
+ # Use the first card by default: "cuda"
64
+ device = torch.device("cuda")
65
+ else:
66
+ device = "cpu"
67
 
68
  shinkai_model.load_state_dict(
69
+ torch.load(shinkai_model_hfhub, device)
70
  )
71
  hosoda_model.load_state_dict(
72
+ torch.load(hosoda_model_hfhub, device)
73
  )
74
  miyazaki_model.load_state_dict(
75
+ torch.load(miyazaki_model_hfhub, device)
76
  )
77
  kon_model.load_state_dict(
78
+ torch.load(kon_model_hfhub, device)
79
  )
80
 
81
+ if enable_gpu:
82
+ shinkai_model = shinkai_model.to(device)
83
+ hosoda_model = hosoda_model.to(device)
84
+ miyazaki_model = miyazaki_model.to(device)
85
+ kon_model = kon_model.to(device)
86
+
87
  shinkai_model.eval()
88
  hosoda_model.eval()
89
  miyazaki_model.eval()
131
 
132
  if enable_gpu:
133
  logger.info(f"CUDA found. Using GPU.")
134
+ # Allows to specify a card for calculation
135
+ input_image = Variable(input_image).to(device)
136
  else:
137
  logger.info(f"CUDA not found. Using CPU.")
138
  input_image = Variable(input_image).float()