KhadijaAsehnoune12 commited on
Commit
492b5e5
1 Parent(s): a332ed4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -26
app.py CHANGED
@@ -25,49 +25,24 @@ id2label = {
25
  }
26
 
27
  def remove_background(image):
28
- # Convert the image to RGBA
29
  image = image.convert("RGBA")
30
-
31
- # Remove the background
32
  image_np = np.array(image)
33
  output_np = rembg.remove(image_np)
34
-
35
- # Create a white background image
36
  white_bg = Image.new("RGBA", image.size, "WHITE")
37
-
38
- # Composite the original image over the white background
39
  output_image = Image.alpha_composite(white_bg, Image.fromarray(output_np))
40
-
41
- # Convert back to RGB
42
  output_image = output_image.convert("RGB")
43
-
44
  return output_image
45
 
46
 
47
  def predict(image):
48
- # Remove the background
49
  image = remove_background(image)
50
-
51
- # Preprocess the image
52
  inputs = feature_extractor(images=image, return_tensors="pt")
53
-
54
- # Forward pass through the model
55
  outputs = model(**inputs)
56
-
57
- # Get the logits
58
  logits = outputs.logits
59
-
60
- # Calculate confidence scores with softmax
61
  probs = torch.nn.functional.softmax(logits, dim=-1)[0]
62
-
63
- # Get the index of the most probable class
64
  predicted_class_idx = probs.argmax().item()
65
-
66
- # Get the label and confidence score of the most probable class
67
  predicted_label = id2label[str(predicted_class_idx)]
68
- confidence_score = probs[predicted_class_idx].item() * 100 # Multiply by 100 to get a percentage
69
-
70
- # Return the label and confidence score
71
  return f"{predicted_label}: {confidence_score:.2f}%"
72
 
73
  # Create the Gradio interface
 
25
  }
26
 
27
  def remove_background(image):
 
28
  image = image.convert("RGBA")
 
 
29
  image_np = np.array(image)
30
  output_np = rembg.remove(image_np)
 
 
31
  white_bg = Image.new("RGBA", image.size, "WHITE")
 
 
32
  output_image = Image.alpha_composite(white_bg, Image.fromarray(output_np))
 
 
33
  output_image = output_image.convert("RGB")
 
34
  return output_image
35
 
36
 
37
  def predict(image):
 
38
  image = remove_background(image)
 
 
39
  inputs = feature_extractor(images=image, return_tensors="pt")
 
 
40
  outputs = model(**inputs)
 
 
41
  logits = outputs.logits
 
 
42
  probs = torch.nn.functional.softmax(logits, dim=-1)[0]
 
 
43
  predicted_class_idx = probs.argmax().item()
 
 
44
  predicted_label = id2label[str(predicted_class_idx)]
45
+ confidence_score = probs[predicted_class_idx].item() * 100
 
 
46
  return f"{predicted_label}: {confidence_score:.2f}%"
47
 
48
  # Create the Gradio interface