Spaces:
Runtime error
Runtime error
File size: 607 Bytes
4a285f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import os
import torch
from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB
def model(model_name='u2net'):
    """Build a U^2-Net variant and load its pretrained weights.

    The checkpoint is read from
    ``<current working directory>/saved_models/<model_name>/<model_name>.pth``.

    Args:
        model_name: ``'u2net'`` for the full-size model (~173.6 MB checkpoint)
            or ``'u2netp'`` for the small variant (~4.7 MB checkpoint).

    Returns:
        The network with weights loaded. It is moved to the GPU when CUDA is
        available and put in evaluation mode.

    Raises:
        ValueError: if ``model_name`` is not ``'u2net'`` or ``'u2netp'``.
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # NOTE(review): the (3, 1) arguments presumably mean 3 input channels and
    # 1 output channel; confirm against model.U2NET / model.U2NETP.
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NETP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        # Without this branch `net` would be unbound below, and the caller
        # would get an opaque UnboundLocalError.
        raise ValueError(
            f"unknown model_name {model_name!r}; expected 'u2net' or 'u2netp'"
        )

    model_path = os.path.join(
        os.getcwd(), 'saved_models', model_name, model_name + '.pth'
    )

    # Choose the target device once. map_location remaps tensors that were
    # saved from a GPU, so a CUDA-trained checkpoint also loads on a
    # CPU-only host.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net
|