glenn-jocher commited on
Commit
bc1fd13
1 Parent(s): 44cdcc7

gsutil cp hyp evolution bug fix (#876)

Browse files
Files changed (2) hide show
  1. utils/general.py +4 -1
  2. utils/google_utils.py +7 -0
utils/general.py CHANGED
@@ -22,6 +22,7 @@ from scipy.cluster.vq import kmeans
22
  from scipy.signal import butter, filtfilt
23
  from tqdm import tqdm
24
 
 
25
  from utils.torch_utils import init_seeds as init_torch_seeds
26
  from utils.torch_utils import is_parallel
27
 
@@ -854,7 +855,9 @@ def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
854
  print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
855
 
856
  if bucket:
857
- os.system('gsutil cp gs://%s/evolve.txt .' % bucket) # download evolve.txt
 
 
858
 
859
  with open('evolve.txt', 'a') as f: # append result
860
  f.write(c + b + '\n')
 
22
  from scipy.signal import butter, filtfilt
23
  from tqdm import tqdm
24
 
25
+ from utils.google_utils import gsutil_getsize
26
  from utils.torch_utils import init_seeds as init_torch_seeds
27
  from utils.torch_utils import is_parallel
28
 
 
855
  print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
856
 
857
  if bucket:
858
+ url = 'gs://%s/evolve.txt' % bucket
859
+ if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
860
+ os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
861
 
862
  with open('evolve.txt', 'a') as f: # append result
863
  f.write(c + b + '\n')
utils/google_utils.py CHANGED
@@ -4,12 +4,19 @@
4
 
5
  import os
6
  import platform
 
7
  import time
8
  from pathlib import Path
9
 
10
  import torch
11
 
12
 
 
 
 
 
 
 
13
  def attempt_download(weights):
14
  # Attempt to download pretrained weights if not found locally
15
  weights = weights.strip().replace("'", '')
 
4
 
5
  import os
6
  import platform
7
+ import subprocess
8
  import time
9
  from pathlib import Path
10
 
11
  import torch
12
 
13
 
14
+ def gsutil_getsize(url=''):
15
+ # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du
16
+ s = subprocess.check_output('gsutil du %s' % url, shell=True).decode('utf-8')
17
+ return eval(s.split(' ')[0]) if len(s) else 0 # bytes
18
+
19
+
20
  def attempt_download(weights):
21
  # Attempt to download pretrained weights if not found locally
22
  weights = weights.strip().replace("'", '')