Commit
•
b5965be
1
Parent(s):
24f1d8c
Update app.py
Browse files
app.py
CHANGED
@@ -172,10 +172,20 @@ def train(*inputs):
|
|
172 |
with zipfile.ZipFile('diffusers_model.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
|
173 |
zipdir('output_model/', zipf)
|
174 |
print("Training completed!")
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
hf_token = inputs[-5]
|
180 |
model_name = inputs[-7]
|
181 |
where_to_upload = inputs[-8]
|
@@ -184,14 +194,6 @@ def train(*inputs):
|
|
184 |
headers = { "authorization" : f"Bearer {hf_token}"}
|
185 |
body = {'flavor': 'cpu-basic'}
|
186 |
requests.post(hardware_url, json = body, headers=headers)
|
187 |
-
return [
|
188 |
-
gr.update(visible=True, value=["diffusers_model.zip"]), #result
|
189 |
-
gr.update(visible=True), #try_your_model
|
190 |
-
gr.update(visible=True), #push_to_hub
|
191 |
-
gr.update(visible=True), #convert_button
|
192 |
-
gr.update(visible=False), #training_ongoing
|
193 |
-
gr.update(visible=True) #completed_training
|
194 |
-
]
|
195 |
|
196 |
def generate(prompt):
|
197 |
torch.cuda.empty_cache()
|
|
|
172 |
with zipfile.ZipFile('diffusers_model.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
|
173 |
zipdir('output_model/', zipf)
|
174 |
print("Training completed!")
|
175 |
+
|
176 |
+
if(not remove_attribution_after):
|
177 |
+
if os.path.exists("intraining.lock"): os.remove("intraining.lock")
|
178 |
+
trained_file = open("hastrained.success", "w")
|
179 |
+
trained_file.close()
|
180 |
+
return [
|
181 |
+
gr.update(visible=True, value=["diffusers_model.zip"]), #result
|
182 |
+
gr.update(visible=True), #try_your_model
|
183 |
+
gr.update(visible=True), #push_to_hub
|
184 |
+
gr.update(visible=True), #convert_button
|
185 |
+
gr.update(visible=False), #training_ongoing
|
186 |
+
gr.update(visible=True) #completed_training
|
187 |
+
]
|
188 |
+
else:
|
189 |
hf_token = inputs[-5]
|
190 |
model_name = inputs[-7]
|
191 |
where_to_upload = inputs[-8]
|
|
|
194 |
headers = { "authorization" : f"Bearer {hf_token}"}
|
195 |
body = {'flavor': 'cpu-basic'}
|
196 |
requests.post(hardware_url, json = body, headers=headers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
def generate(prompt):
|
199 |
torch.cuda.empty_cache()
|