fancyfeast committed
Commit b72bef3
Parent(s): df9e86f

Prepare images correctly
app.py CHANGED
@@ -4,15 +4,45 @@ import huggingface_hub
 from PIL import Image
 import torch.amp.autocast_mode
 from pathlib import Path
+import torch
+import torchvision.transforms.functional as TVF
 
 
 MODEL_REPO = "fancyfeast/joytag"
 
 
+def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
+    # Pad image to square
+    image_shape = image.size
+    max_dim = max(image_shape)
+    pad_left = (max_dim - image_shape[0]) // 2
+    pad_top = (max_dim - image_shape[1]) // 2
+
+    padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
+    padded_image.paste(image, (pad_left, pad_top))
+
+    # Resize image
+    if max_dim != target_size:
+        padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
+
+    # Convert to tensor
+    image_tensor = TVF.pil_to_tensor(padded_image) / 255.0
+
+    # Normalize
+    image_tensor = TVF.normalize(image_tensor, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
+
+    return image_tensor
+
+
 @torch.no_grad()
 def predict(image: Image.Image):
-
-
+    image_tensor = prepare_image(image, model.image_size)
+    batch = {
+        'image': image_tensor.unsqueeze(0),
+    }
+
+    with torch.amp.autocast_mode.autocast('cpu', enabled=True):
+        preds = model(batch)
     tag_preds = preds['tags'].sigmoid().cpu()
 
     return {top_tags[i]: tag_preds[i] for i in range(len(top_tags))}
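For context, here is a minimal standalone sketch of the preprocessing this commit introduces. prepare_image mirrors the diff; the synthetic test image and the 448-pixel target size are illustrative placeholders (the app itself reads the real size from model.image_size), and the mean/std constants are the standard CLIP image-normalization values.

# Standalone sketch of the preprocessing added in this commit.
# prepare_image is copied from the diff; the stand-in image and the
# 448 target size are placeholders, not values taken from the repo.
from PIL import Image
import torch
import torchvision.transforms.functional as TVF


def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
    # Pad to a white square so the aspect ratio is preserved
    image_shape = image.size
    max_dim = max(image_shape)
    pad_left = (max_dim - image_shape[0]) // 2
    pad_top = (max_dim - image_shape[1]) // 2

    padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
    padded_image.paste(image, (pad_left, pad_top))

    # Resize to the model's input resolution
    if max_dim != target_size:
        padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)

    # uint8 PIL image -> float tensor in [0, 1], then CLIP-style normalization
    image_tensor = TVF.pil_to_tensor(padded_image) / 255.0
    image_tensor = TVF.normalize(image_tensor,
                                 mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
    return image_tensor


# A non-square stand-in image: padding should center it on a white square.
image = Image.new('RGB', (640, 360), (128, 64, 32))
tensor = prepare_image(image, 448)
print(tensor.shape)   # torch.Size([3, 448, 448])
print(tensor.dtype)   # torch.float32

Padding to a square before resizing (rather than resizing directly) keeps the subject's proportions intact, which is what the commit message "Prepare images correctly" refers to.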