hanquansanren commited on
Commit
16bb8a1
·
1 Parent(s): 6c4b492
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -11,6 +11,8 @@ import os
11
  os.environ["CUDA_VISIBLE_DEVICES"] = "7"
12
  # os.environ["OPENAI_LOGDIR"] = "./logs"
13
  # os.environ["MPI_DISABLED"] = "1"
 
 
14
  import torch
15
  import torch.distributed as dist
16
  import torchvision.transforms as transforms
@@ -516,7 +518,7 @@ setattr(diffusion, "settings", settings)
516
 
517
 
518
  pretrained_dewarp_model = GeoTr_Seg_Inf()
519
- settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth")
520
  reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
521
  # reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
522
  pretrained_dewarp_model.to(dist_util.dev())
@@ -525,19 +527,19 @@ pretrained_dewarp_model.eval()
525
  if settings.env.use_line_mask:
526
  pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
527
  pretrained_seg_model = Seg()
528
- settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth")
529
  line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
530
  pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
531
  pretrained_line_seg_model.to(dist_util.dev())
532
  pretrained_line_seg_model.eval()
533
 
534
- settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth")
535
  seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
536
  pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
537
  pretrained_seg_model.to(dist_util.dev())
538
  pretrained_seg_model.eval()
539
 
540
- settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt")
541
  model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
542
  logger.log(f"Model loaded with {settings.env.model_path}")
543
 
@@ -564,4 +566,4 @@ if __name__ == '__main__':
564
 
565
 
566
 
567
-
 
11
  os.environ["CUDA_VISIBLE_DEVICES"] = "7"
12
  # os.environ["OPENAI_LOGDIR"] = "./logs"
13
  # os.environ["MPI_DISABLED"] = "1"
14
+ # os.environ.getattribute("HF_TOKEN")
15
+ token = os.getenv("HF_TOKEN", None)
16
  import torch
17
  import torch.distributed as dist
18
  import torchvision.transforms as transforms
 
518
 
519
 
520
  pretrained_dewarp_model = GeoTr_Seg_Inf()
521
+ settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth", token=token)
522
  reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
523
  # reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
524
  pretrained_dewarp_model.to(dist_util.dev())
 
527
  if settings.env.use_line_mask:
528
  pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
529
  pretrained_seg_model = Seg()
530
+ settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth", token=token)
531
  line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
532
  pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
533
  pretrained_line_seg_model.to(dist_util.dev())
534
  pretrained_line_seg_model.eval()
535
 
536
+ settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth", token=token)
537
  seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
538
  pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
539
  pretrained_seg_model.to(dist_util.dev())
540
  pretrained_seg_model.eval()
541
 
542
+ settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt", token=token)
543
  model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
544
  logger.log(f"Model loaded with {settings.env.model_path}")
545
 
 
566
 
567
 
568
 
569
+