Commit
·
b305b72
1
Parent(s):
d6ffbb8
Add guideline to use BiRefNet. Remove codes of model.
Browse files
README.md
CHANGED
@@ -31,9 +31,12 @@ import matplotlib.pyplot as plt
|
|
31 |
import torch
|
32 |
from torchvision import transforms
|
33 |
|
|
|
|
|
|
|
34 |
# Input Data
|
35 |
transform_image = transforms.Compose([
|
36 |
-
transforms.Resize((
|
37 |
transforms.ToTensor(),
|
38 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
39 |
])
|
@@ -42,7 +45,7 @@ image = Image.open(imagepath)
|
|
42 |
input_images = transform_image(image).unsqueeze(0).to('cuda')
|
43 |
|
44 |
# Load Model
|
45 |
-
device = '
|
46 |
torch.set_float32_matmul_precision(['high', 'highest'][0])
|
47 |
model = BiRefNet.from_pretrained('zhengpeng7/birefnet')
|
48 |
model.to(device)
|
@@ -55,7 +58,7 @@ with torch.no_grad():
|
|
55 |
pred = preds[0].squeeze()
|
56 |
|
57 |
# Show Results
|
58 |
-
plt.imshow(pred, cmap='gray')
|
59 |
plt.show()
|
60 |
|
61 |
```
|
|
|
31 |
import torch
|
32 |
from torchvision import transforms
|
33 |
|
34 |
+
from models.birefnet import BiRefNet
|
35 |
+
|
36 |
+
|
37 |
# Input Data
|
38 |
transform_image = transforms.Compose([
|
39 |
+
transforms.Resize((1024, 1024)),
|
40 |
transforms.ToTensor(),
|
41 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
42 |
])
|
|
|
45 |
input_images = transform_image(image).unsqueeze(0).to('cuda')
|
46 |
|
47 |
# Load Model
|
48 |
+
device = 'cuda'
|
49 |
torch.set_float32_matmul_precision(['high', 'highest'][0])
|
50 |
model = BiRefNet.from_pretrained('zhengpeng7/birefnet')
|
51 |
model.to(device)
|
|
|
58 |
pred = preds[0].squeeze()
|
59 |
|
60 |
# Show Results
|
61 |
+
plt.imshow(transforms.ToPILImage()(pred).resize(image.size), cmap='gray')
|
62 |
plt.show()
|
63 |
|
64 |
```
|