Commit
·
5a483b6
1
Parent(s):
8b4c6bd
hf token auth
Browse files- app.py +3 -0
- tsr/system.py +3 -3
app.py
CHANGED
|
@@ -12,6 +12,8 @@ from PIL import Image
|
|
| 12 |
from tsr.system import TSR
|
| 13 |
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
|
| 14 |
|
|
|
|
|
|
|
| 15 |
if torch.cuda.is_available():
|
| 16 |
device = "cuda:0"
|
| 17 |
else:
|
|
@@ -21,6 +23,7 @@ model = TSR.from_pretrained(
|
|
| 21 |
"stabilityai/TripoSR",
|
| 22 |
config_name="config.yaml",
|
| 23 |
weight_name="model.ckpt",
|
|
|
|
| 24 |
)
|
| 25 |
model.to(device)
|
| 26 |
|
|
|
|
| 12 |
from tsr.system import TSR
|
| 13 |
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
|
| 14 |
|
| 15 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 16 |
+
|
| 17 |
if torch.cuda.is_available():
|
| 18 |
device = "cuda:0"
|
| 19 |
else:
|
|
|
|
| 23 |
"stabilityai/TripoSR",
|
| 24 |
config_name="config.yaml",
|
| 25 |
weight_name="model.ckpt",
|
| 26 |
+
token=HF_TOKEN
|
| 27 |
)
|
| 28 |
model.to(device)
|
| 29 |
|
tsr/system.py
CHANGED
|
@@ -50,17 +50,17 @@ class TSR(BaseModule):
|
|
| 50 |
|
| 51 |
@classmethod
|
| 52 |
def from_pretrained(
|
| 53 |
-
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
| 54 |
):
|
| 55 |
if os.path.isdir(pretrained_model_name_or_path):
|
| 56 |
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
| 57 |
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
| 58 |
else:
|
| 59 |
config_path = hf_hub_download(
|
| 60 |
-
repo_id=pretrained_model_name_or_path, filename=config_name
|
| 61 |
)
|
| 62 |
weight_path = hf_hub_download(
|
| 63 |
-
repo_id=pretrained_model_name_or_path, filename=weight_name
|
| 64 |
)
|
| 65 |
|
| 66 |
cfg = OmegaConf.load(config_path)
|
|
|
|
| 50 |
|
| 51 |
@classmethod
|
| 52 |
def from_pretrained(
|
| 53 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str, token=None
|
| 54 |
):
|
| 55 |
if os.path.isdir(pretrained_model_name_or_path):
|
| 56 |
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
| 57 |
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
| 58 |
else:
|
| 59 |
config_path = hf_hub_download(
|
| 60 |
+
repo_id=pretrained_model_name_or_path, filename=config_name, token=token
|
| 61 |
)
|
| 62 |
weight_path = hf_hub_download(
|
| 63 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name, token=token
|
| 64 |
)
|
| 65 |
|
| 66 |
cfg = OmegaConf.load(config_path)
|