dmitriitochilkin
commited on
Commit
•
5a483b6
1
Parent(s):
8b4c6bd
hf token auth
Browse files- app.py +3 -0
- tsr/system.py +3 -3
app.py
CHANGED
@@ -12,6 +12,8 @@ from PIL import Image
|
|
12 |
from tsr.system import TSR
|
13 |
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
|
14 |
|
|
|
|
|
15 |
if torch.cuda.is_available():
|
16 |
device = "cuda:0"
|
17 |
else:
|
@@ -21,6 +23,7 @@ model = TSR.from_pretrained(
|
|
21 |
"stabilityai/TripoSR",
|
22 |
config_name="config.yaml",
|
23 |
weight_name="model.ckpt",
|
|
|
24 |
)
|
25 |
model.to(device)
|
26 |
|
|
|
12 |
from tsr.system import TSR
|
13 |
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
|
14 |
|
15 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
16 |
+
|
17 |
if torch.cuda.is_available():
|
18 |
device = "cuda:0"
|
19 |
else:
|
|
|
23 |
"stabilityai/TripoSR",
|
24 |
config_name="config.yaml",
|
25 |
weight_name="model.ckpt",
|
26 |
+
token=HF_TOKEN
|
27 |
)
|
28 |
model.to(device)
|
29 |
|
tsr/system.py
CHANGED
@@ -50,17 +50,17 @@ class TSR(BaseModule):
|
|
50 |
|
51 |
@classmethod
|
52 |
def from_pretrained(
|
53 |
-
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
54 |
):
|
55 |
if os.path.isdir(pretrained_model_name_or_path):
|
56 |
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
57 |
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
58 |
else:
|
59 |
config_path = hf_hub_download(
|
60 |
-
repo_id=pretrained_model_name_or_path, filename=config_name
|
61 |
)
|
62 |
weight_path = hf_hub_download(
|
63 |
-
repo_id=pretrained_model_name_or_path, filename=weight_name
|
64 |
)
|
65 |
|
66 |
cfg = OmegaConf.load(config_path)
|
|
|
50 |
|
51 |
@classmethod
|
52 |
def from_pretrained(
|
53 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str, token=None
|
54 |
):
|
55 |
if os.path.isdir(pretrained_model_name_or_path):
|
56 |
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
57 |
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
58 |
else:
|
59 |
config_path = hf_hub_download(
|
60 |
+
repo_id=pretrained_model_name_or_path, filename=config_name, token=token
|
61 |
)
|
62 |
weight_path = hf_hub_download(
|
63 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name, token=token
|
64 |
)
|
65 |
|
66 |
cfg = OmegaConf.load(config_path)
|