KhadijaAsehnoune12
commited on
Commit
•
492b5e5
1
Parent(s):
a332ed4
Update app.py
Browse files
app.py
CHANGED
@@ -25,49 +25,24 @@ id2label = {
|
|
25 |
}
|
26 |
|
27 |
def remove_background(image):
|
28 |
-
# Convert the image to RGBA
|
29 |
image = image.convert("RGBA")
|
30 |
-
|
31 |
-
# Remove the background
|
32 |
image_np = np.array(image)
|
33 |
output_np = rembg.remove(image_np)
|
34 |
-
|
35 |
-
# Create a white background image
|
36 |
white_bg = Image.new("RGBA", image.size, "WHITE")
|
37 |
-
|
38 |
-
# Composite the original image over the white background
|
39 |
output_image = Image.alpha_composite(white_bg, Image.fromarray(output_np))
|
40 |
-
|
41 |
-
# Convert back to RGB
|
42 |
output_image = output_image.convert("RGB")
|
43 |
-
|
44 |
return output_image
|
45 |
|
46 |
|
47 |
def predict(image):
|
48 |
-
# Remove the background
|
49 |
image = remove_background(image)
|
50 |
-
|
51 |
-
# Preprocess the image
|
52 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
53 |
-
|
54 |
-
# Forward pass through the model
|
55 |
outputs = model(**inputs)
|
56 |
-
|
57 |
-
# Get the logits
|
58 |
logits = outputs.logits
|
59 |
-
|
60 |
-
# Calculate confidence scores with softmax
|
61 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
62 |
-
|
63 |
-
# Get the index of the most probable class
|
64 |
predicted_class_idx = probs.argmax().item()
|
65 |
-
|
66 |
-
# Get the label and confidence score of the most probable class
|
67 |
predicted_label = id2label[str(predicted_class_idx)]
|
68 |
-
confidence_score = probs[predicted_class_idx].item() * 100
|
69 |
-
|
70 |
-
# Return the label and confidence score
|
71 |
return f"{predicted_label}: {confidence_score:.2f}%"
|
72 |
|
73 |
# Create the Gradio interface
|
|
|
25 |
}
|
26 |
|
27 |
def remove_background(image):
|
|
|
28 |
image = image.convert("RGBA")
|
|
|
|
|
29 |
image_np = np.array(image)
|
30 |
output_np = rembg.remove(image_np)
|
|
|
|
|
31 |
white_bg = Image.new("RGBA", image.size, "WHITE")
|
|
|
|
|
32 |
output_image = Image.alpha_composite(white_bg, Image.fromarray(output_np))
|
|
|
|
|
33 |
output_image = output_image.convert("RGB")
|
|
|
34 |
return output_image
|
35 |
|
36 |
|
37 |
def predict(image):
|
|
|
38 |
image = remove_background(image)
|
|
|
|
|
39 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
|
|
|
|
40 |
outputs = model(**inputs)
|
|
|
|
|
41 |
logits = outputs.logits
|
|
|
|
|
42 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
|
|
|
|
43 |
predicted_class_idx = probs.argmax().item()
|
|
|
|
|
44 |
predicted_label = id2label[str(predicted_class_idx)]
|
45 |
+
confidence_score = probs[predicted_class_idx].item() * 100
|
|
|
|
|
46 |
return f"{predicted_label}: {confidence_score:.2f}%"
|
47 |
|
48 |
# Create the Gradio interface
|