Abdu07 commited on
Commit
7efb51a
·
verified ·
1 Parent(s): 4375fb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -34
app.py CHANGED
@@ -24,55 +24,51 @@ class MultiTaskModel(nn.Module):
24
  return obj_logits, bin_logits
25
 
26
  ########################################
27
- # 2. Reconstruct the Model and Load Weights
28
  ########################################
29
- # IMPORTANT: The checkpoint was saved with a single object class,
30
- # so we set num_obj_classes to 1.
31
- num_obj_classes = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
 
 
33
  device = torch.device("cpu")
34
 
35
  resnet = models.resnet50(pretrained=False)
36
  resnet.fc = nn.Identity()
37
  feature_dim = 2048
 
 
38
  model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
39
  model.to(device)
40
 
41
  repo_id = "Abdu07/multitask-model"
42
- filename = "DualSight.pth"
43
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
44
  state_dict = torch.load(weights_path, map_location="cpu")
45
  model.load_state_dict(state_dict)
46
  model.eval()
47
 
48
  ########################################
49
- # 3. Load Label Mapping and Define Transforms
50
- ########################################
51
- # Attempt to load the mapping from JSON.
52
- # If the mapping contains more than one label, we override it with a single-label mapping
53
- try:
54
- with open("obj_label_mapping.json", "r") as f:
55
- obj_label_to_idx = json.load(f)
56
- if len(obj_label_to_idx) != 1:
57
- obj_label_to_idx = {"Detected Object": 0}
58
- except Exception as e:
59
- print("Error loading mapping, using default mapping. Error:", e)
60
- obj_label_to_idx = {"Detected Object": 0}
61
-
62
- idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
63
-
64
- bin_label_names = ["AI-Generated", "Real"]
65
-
66
- val_transforms = transforms.Compose([
67
- transforms.Resize(256),
68
- transforms.CenterCrop(224),
69
- transforms.ToTensor(),
70
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
71
- std=[0.229, 0.224, 0.225])
72
- ])
73
-
74
- ########################################
75
- # 4. Define the Inference Function
76
  ########################################
77
  def predict_image(img: Image.Image) -> str:
78
  img = img.convert("RGB")
@@ -86,13 +82,13 @@ def predict_image(img: Image.Image) -> str:
86
  return f"Prediction: {obj_name} ({bin_name})"
87
 
88
  ########################################
89
- # 5. Create Gradio UI
90
  ########################################
91
  demo = gr.Interface(
92
  fn=predict_image,
93
  inputs=gr.Image(type="pil"),
94
  outputs="text",
95
- title="Multi-Task Image Classifier",
96
  description="Upload an image to receive two predictions:\n1) The primary object in the image,\n2) Whether the image is AI-generated or Real."
97
  )
98
 
 
24
  return obj_logits, bin_logits
25
 
26
  ########################################
27
+ # 2. Load the Label Mapping and Set num_obj_classes
28
  ########################################
29
+ # Load the saved mapping from JSON
30
+ with open("obj_label_mapping.json", "r") as f:
31
+ obj_label_to_idx = json.load(f)
32
+ # Use the mapping as-is; do not override it.
33
+ num_obj_classes = len(obj_label_to_idx)
34
+ # Create the inverse mapping
35
+ idx_to_obj_label = {v: k for k, v in obj_label_to_idx.items()}
36
+
37
+ bin_label_names = ["AI-Generated", "Real"]
38
+
39
+ ########################################
40
+ # 3. Define Validation Transforms
41
+ ########################################
42
+ val_transforms = transforms.Compose([
43
+ transforms.Resize(256),
44
+ transforms.CenterCrop(224),
45
+ transforms.ToTensor(),
46
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
47
+ std=[0.229, 0.224, 0.225])
48
+ ])
49
 
50
+ ########################################
51
+ # 4. Reconstruct the Model and Load Weights
52
+ ########################################
53
  device = torch.device("cpu")
54
 
55
  resnet = models.resnet50(pretrained=False)
56
  resnet.fc = nn.Identity()
57
  feature_dim = 2048
58
+
59
+ # Build the model architecture.
60
  model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
61
  model.to(device)
62
 
63
  repo_id = "Abdu07/multitask-model"
64
+ filename = "DualSight.pth" # Ensure this checkpoint is from training with the same num_obj_classes
65
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
66
  state_dict = torch.load(weights_path, map_location="cpu")
67
  model.load_state_dict(state_dict)
68
  model.eval()
69
 
70
  ########################################
71
+ # 5. Define the Inference Function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  ########################################
73
  def predict_image(img: Image.Image) -> str:
74
  img = img.convert("RGB")
 
82
  return f"Prediction: {obj_name} ({bin_name})"
83
 
84
  ########################################
85
+ # 6. Create Gradio UI
86
  ########################################
87
  demo = gr.Interface(
88
  fn=predict_image,
89
  inputs=gr.Image(type="pil"),
90
  outputs="text",
91
+ title="Multi-Task Image Classifier Trained by [Abdellahi El Moustapha](https://abmstpha.github.io/),
92
  description="Upload an image to receive two predictions:\n1) The primary object in the image,\n2) Whether the image is AI-generated or Real."
93
  )
94