Spaces:
Runtime error
Runtime error
Commit
•
cfadc6d
1
Parent(s):
c9feef4
Update trainer.py
Browse files- trainer.py +12 -7
trainer.py
CHANGED
@@ -20,12 +20,12 @@ from utils import save_model_card
|
|
20 |
sys.path.append('Tune-A-Video')
|
21 |
|
22 |
URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
|
23 |
-
|
|
|
24 |
|
25 |
class Trainer:
|
26 |
def __init__(self, hf_token: str | None = None):
|
27 |
self.hf_token = hf_token
|
28 |
-
self.api = HfApi(token=hf_token)
|
29 |
self.model_uploader = ModelUploader(hf_token)
|
30 |
|
31 |
self.checkpoint_dir = pathlib.Path('checkpoints')
|
@@ -42,10 +42,10 @@ class Trainer:
|
|
42 |
cwd=org_dir)
|
43 |
return model_dir.as_posix()
|
44 |
|
45 |
-
def join_model_library_org(
|
46 |
subprocess.run(
|
47 |
shlex.split(
|
48 |
-
f'curl -X POST -H "Authorization: Bearer {
|
49 |
))
|
50 |
|
51 |
def run(
|
@@ -70,7 +70,10 @@ class Trainer:
|
|
70 |
delete_existing_repo: bool,
|
71 |
upload_to: str,
|
72 |
remove_gpu_after_training: bool,
|
|
|
73 |
) -> str:
|
|
|
|
|
74 |
if not torch.cuda.is_available():
|
75 |
raise gr.Error('CUDA is not available.')
|
76 |
if training_video is None:
|
@@ -143,14 +146,16 @@ class Trainer:
|
|
143 |
repo_name=output_model_name,
|
144 |
upload_to=upload_to,
|
145 |
private=use_private_repo,
|
146 |
-
delete_existing_repo=delete_existing_repo
|
|
|
147 |
print(upload_message)
|
148 |
message = message + '\n' + upload_message
|
149 |
|
150 |
if remove_gpu_after_training:
|
151 |
space_id = os.getenv('SPACE_ID')
|
152 |
if space_id:
|
153 |
-
|
154 |
-
|
|
|
155 |
|
156 |
return message
|
|
|
20 |
sys.path.append('Tune-A-Video')
|
21 |
|
22 |
URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
|
23 |
+
ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
|
24 |
+
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
|
25 |
|
26 |
class Trainer:
|
27 |
def __init__(self, hf_token: str | None = None):
|
28 |
self.hf_token = hf_token
|
|
|
29 |
self.model_uploader = ModelUploader(hf_token)
|
30 |
|
31 |
self.checkpoint_dir = pathlib.Path('checkpoints')
|
|
|
42 |
cwd=org_dir)
|
43 |
return model_dir.as_posix()
|
44 |
|
45 |
+
def join_model_library_org(token) -> None:
|
46 |
subprocess.run(
|
47 |
shlex.split(
|
48 |
+
f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
|
49 |
))
|
50 |
|
51 |
def run(
|
|
|
70 |
delete_existing_repo: bool,
|
71 |
upload_to: str,
|
72 |
remove_gpu_after_training: bool,
|
73 |
+
input_token: str,
|
74 |
) -> str:
|
75 |
+
if SPACE_ID != ORIGINAL_SPACE_ID:
|
76 |
+
raise gr.Error('This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU')
|
77 |
if not torch.cuda.is_available():
|
78 |
raise gr.Error('CUDA is not available.')
|
79 |
if training_video is None:
|
|
|
146 |
repo_name=output_model_name,
|
147 |
upload_to=upload_to,
|
148 |
private=use_private_repo,
|
149 |
+
delete_existing_repo=delete_existing_repo,
|
150 |
+
input_token=input_token)
|
151 |
print(upload_message)
|
152 |
message = message + '\n' + upload_message
|
153 |
|
154 |
if remove_gpu_after_training:
|
155 |
space_id = os.getenv('SPACE_ID')
|
156 |
if space_id:
|
157 |
+
api = HfApi(token=hf_token if self.hf_token else input_token)
|
158 |
+
api.request_space_hardware(repo_id=space_id,
|
159 |
+
hardware='cpu-basic')
|
160 |
|
161 |
return message
|