glenn-jocher commited on
Commit
96a8446
·
unverified ·
1 Parent(s): 97a5227

Update labels_to_image_weights() (#1545)

Browse files
Files changed (2) hide show
  1. utils/general.py +5 -10
  2. utils/plots.py +1 -0
utils/general.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  import glob
4
  import logging
5
- import math
6
  import os
7
  import platform
8
  import random
@@ -12,7 +11,7 @@ import time
12
  from pathlib import Path
13
 
14
  import cv2
15
- import matplotlib
16
  import numpy as np
17
  import torch
18
  import torchvision
@@ -22,13 +21,10 @@ from utils.google_utils import gsutil_getsize
22
  from utils.metrics import fitness
23
  from utils.torch_utils import init_torch_seeds
24
 
25
- # Set printoptions
26
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
27
  np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
28
- matplotlib.rc('font', **{'size': 11})
29
-
30
- # Prevent OpenCV from multithreading (to use PyTorch DataLoader)
31
- cv2.setNumThreads(0)
32
 
33
 
34
  def set_logging(rank=-1):
@@ -121,9 +117,8 @@ def labels_to_class_weights(labels, nc=80):
121
 
122
 
123
  def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
124
- # Produces image weights based on class mAPs
125
- n = len(labels)
126
- class_counts = np.array([np.bincount(labels[i][:, 0].astype(np.int), minlength=nc) for i in range(n)])
127
  image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
128
  # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
129
  return image_weights
 
2
 
3
  import glob
4
  import logging
 
5
  import os
6
  import platform
7
  import random
 
11
  from pathlib import Path
12
 
13
  import cv2
14
+ import math
15
  import numpy as np
16
  import torch
17
  import torchvision
 
21
  from utils.metrics import fitness
22
  from utils.torch_utils import init_torch_seeds
23
 
24
+ # Settings
25
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
26
  np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
27
+ cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 
 
 
28
 
29
 
30
  def set_logging(rank=-1):
 
117
 
118
 
119
  def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
120
+ # Produces image weights based on class_weights and image contents
121
+ class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
 
122
  image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
123
  # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
124
  return image_weights
utils/plots.py CHANGED
@@ -20,6 +20,7 @@ from utils.general import xywh2xyxy, xyxy2xywh
20
  from utils.metrics import fitness
21
 
22
  # Settings
 
23
  matplotlib.use('Agg') # for writing to files only
24
 
25
 
 
20
  from utils.metrics import fitness
21
 
22
  # Settings
23
+ matplotlib.rc('font', **{'size': 11})
24
  matplotlib.use('Agg') # for writing to files only
25
 
26