Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
16bb8a1
1
Parent(s):
6c4b492
app.py
CHANGED
|
@@ -11,6 +11,8 @@ import os
|
|
| 11 |
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
| 12 |
# os.environ["OPENAI_LOGDIR"] = "./logs"
|
| 13 |
# os.environ["MPI_DISABLED"] = "1"
|
|
|
|
|
|
|
| 14 |
import torch
|
| 15 |
import torch.distributed as dist
|
| 16 |
import torchvision.transforms as transforms
|
|
@@ -516,7 +518,7 @@ setattr(diffusion, "settings", settings)
|
|
| 516 |
|
| 517 |
|
| 518 |
pretrained_dewarp_model = GeoTr_Seg_Inf()
|
| 519 |
-
settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth")
|
| 520 |
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
|
| 521 |
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
|
| 522 |
pretrained_dewarp_model.to(dist_util.dev())
|
|
@@ -525,19 +527,19 @@ pretrained_dewarp_model.eval()
|
|
| 525 |
if settings.env.use_line_mask:
|
| 526 |
pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
|
| 527 |
pretrained_seg_model = Seg()
|
| 528 |
-
settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth")
|
| 529 |
line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
|
| 530 |
pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
|
| 531 |
pretrained_line_seg_model.to(dist_util.dev())
|
| 532 |
pretrained_line_seg_model.eval()
|
| 533 |
|
| 534 |
-
settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth")
|
| 535 |
seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
|
| 536 |
pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
|
| 537 |
pretrained_seg_model.to(dist_util.dev())
|
| 538 |
pretrained_seg_model.eval()
|
| 539 |
|
| 540 |
-
settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt")
|
| 541 |
model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
|
| 542 |
logger.log(f"Model loaded with {settings.env.model_path}")
|
| 543 |
|
|
@@ -564,4 +566,4 @@ if __name__ == '__main__':
|
|
| 564 |
|
| 565 |
|
| 566 |
|
| 567 |
-
|
|
|
|
| 11 |
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
| 12 |
# os.environ["OPENAI_LOGDIR"] = "./logs"
|
| 13 |
# os.environ["MPI_DISABLED"] = "1"
|
| 14 |
+
# os.environ.getattribute("HF_TOKEN")
|
| 15 |
+
token = os.getenv("HF_TOKEN", None)
|
| 16 |
import torch
|
| 17 |
import torch.distributed as dist
|
| 18 |
import torchvision.transforms as transforms
|
|
|
|
| 518 |
|
| 519 |
|
| 520 |
pretrained_dewarp_model = GeoTr_Seg_Inf()
|
| 521 |
+
settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth", token=token)
|
| 522 |
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
|
| 523 |
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
|
| 524 |
pretrained_dewarp_model.to(dist_util.dev())
|
|
|
|
| 527 |
if settings.env.use_line_mask:
|
| 528 |
pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
|
| 529 |
pretrained_seg_model = Seg()
|
| 530 |
+
settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth", token=token)
|
| 531 |
line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
|
| 532 |
pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
|
| 533 |
pretrained_line_seg_model.to(dist_util.dev())
|
| 534 |
pretrained_line_seg_model.eval()
|
| 535 |
|
| 536 |
+
settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth", token=token)
|
| 537 |
seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
|
| 538 |
pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
|
| 539 |
pretrained_seg_model.to(dist_util.dev())
|
| 540 |
pretrained_seg_model.eval()
|
| 541 |
|
| 542 |
+
settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt", token=token)
|
| 543 |
model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
|
| 544 |
logger.log(f"Model loaded with {settings.env.model_path}")
|
| 545 |
|
|
|
|
| 566 |
|
| 567 |
|
| 568 |
|
| 569 |
+
|