ZhengPeng7 commited on
Commit
de0b7d0
·
1 Parent(s): 4edabb2

Upgrade the weights loading method to avoid duplicated loading.

Browse files
Files changed (1) hide show
  1. app.py +2 -10
app.py CHANGED
@@ -45,28 +45,20 @@ usage_to_weights_file = {
45
  }
46
 
47
  from transformers import AutoModelForImageSegmentation
48
- weights_path = 'General'
49
- birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file[weights_path])), trust_remote_code=True)
50
  birefnet.to(device)
51
  birefnet.eval()
52
- weights_path = weights_path
53
 
54
 
55
  @spaces.GPU
56
  def predict(image, resolution, weights_file):
57
- global birefnet
58
- global weights_path
59
- if weights_path != weights_file:
60
- print('*' * 10)
61
- print('\t1: ', weights_file, weights_path)
62
  # Load BiRefNet with chosen weights
63
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
64
  print('Change weights to:', _weights_file)
65
  birefnet = birefnet.from_pretrained(_weights_file)
66
  birefnet.to(device)
67
  birefnet.eval()
68
- weights_path = weights_file
69
- print('\t2: ', weights_file, weights_path)
70
 
71
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
72
  # Image is a RGB numpy array.
 
45
  }
46
 
47
  from transformers import AutoModelForImageSegmentation
48
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
 
49
  birefnet.to(device)
50
  birefnet.eval()
 
51
 
52
 
53
  @spaces.GPU
54
  def predict(image, resolution, weights_file):
55
+ if weights_file != 'General':
 
 
 
 
56
  # Load BiRefNet with chosen weights
57
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
58
  print('Change weights to:', _weights_file)
59
  birefnet = birefnet.from_pretrained(_weights_file)
60
  birefnet.to(device)
61
  birefnet.eval()
 
 
62
 
63
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
64
  # Image is a RGB numpy array.