multimodalart HF staff commited on
Commit
b5965be
1 Parent(s): 24f1d8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -172,10 +172,20 @@ def train(*inputs):
172
  with zipfile.ZipFile('diffusers_model.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
173
  zipdir('output_model/', zipf)
174
  print("Training completed!")
175
- if os.path.exists("intraining.lock"): os.remove("intraining.lock")
176
- trained_file = open("hastrained.success", "w")
177
- trained_file.close()
178
- if(remove_attribution_after):
 
 
 
 
 
 
 
 
 
 
179
  hf_token = inputs[-5]
180
  model_name = inputs[-7]
181
  where_to_upload = inputs[-8]
@@ -184,14 +194,6 @@ def train(*inputs):
184
  headers = { "authorization" : f"Bearer {hf_token}"}
185
  body = {'flavor': 'cpu-basic'}
186
  requests.post(hardware_url, json = body, headers=headers)
187
- return [
188
- gr.update(visible=True, value=["diffusers_model.zip"]), #result
189
- gr.update(visible=True), #try_your_model
190
- gr.update(visible=True), #push_to_hub
191
- gr.update(visible=True), #convert_button
192
- gr.update(visible=False), #training_ongoing
193
- gr.update(visible=True) #completed_training
194
- ]
195
 
196
  def generate(prompt):
197
  torch.cuda.empty_cache()
 
172
  with zipfile.ZipFile('diffusers_model.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
173
  zipdir('output_model/', zipf)
174
  print("Training completed!")
175
+
176
+ if(not remove_attribution_after):
177
+ if os.path.exists("intraining.lock"): os.remove("intraining.lock")
178
+ trained_file = open("hastrained.success", "w")
179
+ trained_file.close()
180
+ return [
181
+ gr.update(visible=True, value=["diffusers_model.zip"]), #result
182
+ gr.update(visible=True), #try_your_model
183
+ gr.update(visible=True), #push_to_hub
184
+ gr.update(visible=True), #convert_button
185
+ gr.update(visible=False), #training_ongoing
186
+ gr.update(visible=True) #completed_training
187
+ ]
188
+ else:
189
  hf_token = inputs[-5]
190
  model_name = inputs[-7]
191
  where_to_upload = inputs[-8]
 
194
  headers = { "authorization" : f"Bearer {hf_token}"}
195
  body = {'flavor': 'cpu-basic'}
196
  requests.post(hardware_url, json = body, headers=headers)
 
 
 
 
 
 
 
 
197
 
198
  def generate(prompt):
199
  torch.cuda.empty_cache()