SaniaE committed
Commit cf0f372 · verified · 1 Parent(s): 8b9f879

revamped complete API structure

Files changed (1): app.py (+102 -126)

app.py CHANGED
@@ -1,40 +1,46 @@
  import os
- import torch
- import random
- import asyncio
  import io
+ import asyncio
+ import random
  import numpy as np
+ import torch
+ import torch.nn.functional as F
  import matplotlib.pyplot as plt
  from PIL import Image, ImageFilter
  from fastapi import FastAPI, UploadFile, File, Query
  from fastapi.responses import StreamingResponse
  from huggingface_hub import snapshot_download, login
- import torch.nn.functional as F

  from transformers import (
      BlipProcessor, BlipForConditionalGeneration,
-     ViTImageProcessor, AutoProcessor, AutoModelForCausalLM
+     ViTImageProcessor, AutoProcessor, AutoModelForCausalLM,
+     CLIPModel, CLIPProcessor
  )

- app = FastAPI(title="XAI Auditor Ensemble")
+ app = FastAPI(title="XAI Auditor Ensemble with CLIP Jury")

- # --- Configuration & State ---
+ # --- Configuration & Paths ---
  REPO_ID = "SaniaE/Image_Captioning_Ensemble"
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  MODELS = {}

- MODEL_SETTINGS = {
+ # Metadata for loading
+ MODEL_CONFIGS = {
      "blip": {
          "subfolder": "blip",
-         "processor": BlipProcessor,
-         "pretrained_path": "Salesforce/blip-image-captioning-large",
-         "inference_model": BlipForConditionalGeneration
+         "proc_class": BlipProcessor,
+         "model_class": BlipForConditionalGeneration,
+         "base_path": "Salesforce/blip-image-captioning-large"
      },
      "vit": {
          "subfolder": "vit",
-         "processor": [ViTImageProcessor, AutoProcessor],
-         "pretrained_path": ["nlpconnect/vit-gpt2-image-captioning", "microsoft/git-large"],
-         "inference_model": AutoModelForCausalLM
+         "proc_classes": [ViTImageProcessor, AutoProcessor],
+         "model_class": AutoModelForCausalLM,
+         "base_paths": ["nlpconnect/vit-gpt2-image-captioning", "microsoft/git-large"]
+     },
+     "clip": {
+         "model_subfolder": "clip/clip_model",
+         "proc_subfolder": "clip/clip_processor"
      }
  }
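Note: the new MODEL_CONFIGS separates the fine-tuned checkpoint locations (subfolders inside the downloaded snapshot) from the upstream base paths, which are only used to fetch processors and tokenizers. Below is a minimal sanity-check sketch of the snapshot layout this config assumes after snapshot_download(..., local_dir="weights"); the tree is inferred from the subfolder keys above, not listed anywhere in the commit:

import os

# Hypothetical helper: verify the snapshot layout that MODEL_CONFIGS expects.
EXPECTED_SUBFOLDERS = ["blip", "vit", "clip/clip_model", "clip/clip_processor"]

def check_snapshot_layout(local_dir="weights"):
    for sub in EXPECTED_SUBFOLDERS:
        path = os.path.join(local_dir, sub)
        if not os.path.isdir(path):
            raise FileNotFoundError(f"Expected checkpoint folder missing: {path}")

# check_snapshot_layout()  # run after snapshot_download completes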
 
@@ -44,66 +50,72 @@ async def startup_event():
      token = os.getenv("HF_Token")
      if token: login(token=token)

-     print(f"Downloading models from {REPO_ID}...")
+     print(f"Syncing weights from {REPO_ID}...")
      local_dir = snapshot_download(repo_id=REPO_ID, token=token, local_dir="weights")

-     for name, cfg in MODEL_SETTINGS.items():
-         ckpt_path = os.path.join(local_dir, cfg["subfolder"])
-         print(f"Loading {name} from {ckpt_path}...")
-
-         model = cfg["inference_model"].from_pretrained(ckpt_path).to(DEVICE)
-
-         if name == "vit":
-             i_proc = cfg["processor"][0].from_pretrained(cfg["pretrained_path"][0])
-             t_proc = cfg["processor"][1].from_pretrained(cfg["pretrained_path"][1])
-             processor = (i_proc, t_proc)
-         else:
-             processor = cfg["processor"].from_pretrained(cfg["pretrained_path"])
-
-         MODELS[name] = {"model": model, "processor": processor}
-     print("Optimization Complete: Ensemble is live!")
-
- # --- Core Logic Helpers ---
-
- def _generate_sync(m_name, image, temp=0.7):
-     """Synchronous generator tailored for the specific architecture."""
-     m_data = MODELS[m_name]
-     model = m_data["model"]
+     # 1. Load BLIP
+     cfg_b = MODEL_CONFIGS["blip"]
+     MODELS["blip"] = {
+         "model": cfg_b["model_class"].from_pretrained(os.path.join(local_dir, cfg_b["subfolder"])).to(DEVICE),
+         "processor": cfg_b["proc_class"].from_pretrained(cfg_b["base_path"])
+     }
+
+     # 2. Load ViT/GIT Ensemble
+     cfg_v = MODEL_CONFIGS["vit"]
+     MODELS["vit"] = {
+         "model": cfg_v["model_class"].from_pretrained(os.path.join(local_dir, cfg_v["subfolder"])).to(DEVICE),
+         "processor": (
+             cfg_v["proc_classes"][0].from_pretrained(cfg_v["base_paths"][0]),
+             cfg_v["proc_classes"][1].from_pretrained(cfg_v["base_paths"][1])
+         )
+     }
+
+     # 3. Load Fine-Tuned CLIP (Your Jury)
+     cfg_c = MODEL_CONFIGS["clip"]
+     MODELS["clip"] = {
+         "model": CLIPModel.from_pretrained(os.path.join(local_dir, cfg_c["model_subfolder"])).to(DEVICE),
+         "processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, cfg_c["proc_subfolder"]))
+     }
+
+     print("All models synchronized. Auditor is active.")
+
+ # --- Utilities ---
+
+ def _generate_sync(m_name, image, temp, top_k, top_p):
+     m_data = MODELS[m_name]
      if m_name == "vit":
          i_proc, t_proc = m_data["processor"]
          inputs = i_proc(images=image, return_tensors="pt").to(DEVICE)
-         gen_ids = model.generate(**inputs, max_length=50, do_sample=True, temperature=temp)
-         return t_proc.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
+         ids = m_data["model"].generate(**inputs, max_length=80, do_sample=True, temperature=temp, top_k=top_k, top_p=top_p)
+         return t_proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
      else:
          proc = m_data["processor"]
          inputs = proc(images=image, return_tensors="pt").to(DEVICE)
-         gen_ids = model.generate(**inputs, max_length=50, do_sample=True, temperature=temp)
-         return proc.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
+         ids = m_data["model"].generate(**inputs, max_length=80, do_sample=True, temperature=temp, top_k=top_k, top_p=top_p)
+         return proc.batch_decode(ids, skip_special_tokens=True)[0].strip()

- # --- Endpoint 1: The Multi-Perspective Generator ---
+ # --- Endpoints ---

  @app.post("/generate")
- async def generate_endpoint(
+ async def generate_captions(
      file: UploadFile = File(...),
      temp: float = Query(0.8),
      top_k: int = Query(50),
      top_p: float = Query(0.9)
  ):
+     """Generates 5 diverse captions using the model ensemble."""
      image = Image.open(file.file).convert("RGB")
-     available = ["blip", "vit"]
+     architectures = ["blip", "vit"]
+     selection = random.choices(architectures, k=5)

-     # Generate 5 captions using a mix of models
-     model_selection = random.choices(available, k=5)
-     tasks = [asyncio.to_thread(_generate_sync, m, image, temp, top_k, top_p) for m in model_selection]
+     tasks = [asyncio.to_thread(_generate_sync, m, image, temp, top_k, top_p) for m in selection]
      captions = await asyncio.gather(*tasks)

-     return {"captions": captions, "architectures": model_selection}
-
- # --- Endpoint 2: Objective Vision Saliency (Static Image Perception) ---
+     return {"captions": captions, "metadata": {"models_used": selection, "temp": temp}}

- @app.post("/saliency-explorer/vision")
- async def get_objective_saliency(file: UploadFile = File(...)):
+ @app.post("/saliency")
+ async def get_vision_saliency(file: UploadFile = File(...)):
+     """Objective Saliency: Shows what the Vision Encoder focuses on (Self-Attention)."""
      image_bytes = await file.read()
      orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
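Note: this hunk also lands a real bug fix. The old _generate_sync(m_name, image, temp=0.7) accepted three parameters, yet /generate already called it with five (m, image, temp, top_k, top_p), so every request would have raised a TypeError; the new signature threads top_k and top_p through to generate(). A minimal client sketch for the revised endpoint (assumptions: the requests package, a server reachable at http://localhost:7860, the usual Spaces port, and a placeholder photo.jpg):

import requests

BASE = "http://localhost:7860"  # hypothetical local URL; adjust to your deployment

# POST an image; the sampling knobs are passed as query parameters.
with open("photo.jpg", "rb") as f:
    resp = requests.post(
        f"{BASE}/generate",
        files={"file": f},
        params={"temp": 0.8, "top_k": 50, "top_p": 0.9},
    )
print(resp.json())  # {"captions": [...five strings...], "metadata": {"models_used": [...], "temp": 0.8}}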
 
@@ -111,93 +123,57 @@ async def get_objective_saliency(file: UploadFile = File(...)):
      inputs = blip["processor"](images=orig_img, return_tensors="pt").to(DEVICE)

      with torch.no_grad():
-         # Capturing Self-Attention from the Vision Encoder itself
-         # This shows what the model finds interesting in the image, regardless of prompt
-         outputs = blip["model"].vision_model(
-             inputs.pixel_values,
-             output_attentions=True
-         )
-
-         # Last layer attention: (batch, heads, patches, patches)
-         attentions = outputs.attentions[-1]
-
-         # Average across heads and focus on CLS token's view of the patches
-         # Patch grid for BLIP-Large is typically 24x24 (576 patches + 1 CLS)
-         nh = attentions.shape[1]
-         attentional_map = attentions[0, :, 0, 1:].reshape(nh, -1)
-         mask_1d = attentional_map.mean(dim=0)
-
+         outputs = blip["model"].vision_model(inputs.pixel_values, output_attentions=True)
+         attentions = outputs.attentions[-1]  # Last layer
+         # Average heads, look at CLS token attention to patches
+         mask_1d = attentions[0, :, 0, 1:].mean(dim=0)
      grid_size = int(np.sqrt(mask_1d.shape[-1]))
      mask = mask_1d.view(grid_size, grid_size).cpu().numpy()

-     # Normalization and High-Contrast "Heat"
      mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
-     mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
-     mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=10))
-
-     heatmap_rgba = plt.get_cmap('magma')(np.array(mask_pill)/255.0)
-     heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
+     mask_img = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
+     mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=10))

-     # Blending at 0.6 alpha to make the "Model's Focus" pop
-     blended_img = Image.blend(orig_img, heatmap_img, alpha=0.6)
+     heatmap = plt.get_cmap('magma')(np.array(mask_img)/255.0)
+     heatmap_img = Image.fromarray((heatmap[:, :, :3] * 255).astype('uint8')).convert("RGB")
+     blended = Image.blend(orig_img, heatmap_img, alpha=0.6)

      buf = io.BytesIO()
-     blended_img.save(buf, format="PNG")
+     blended.save(buf, format="PNG")
      buf.seek(0)
      return StreamingResponse(buf, media_type="image/png")

- # --- Endpoint 3: Perspective Auditor (Internal Debate) ---
- # --- Endpoint 3: Internal Debate (Audit Mode) ---
-
- @app.post("/audit-perspective")
- async def audit_perspective(file: UploadFile = File(...), user_prompt: str = Query(...)):
-     image = Image.open(file.file).convert("RGB")
+ @app.post("/audit")
+ async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str = Query(...)):
+     """The CLIP-Powered Jury: Compares User Intent vs. Model Perception."""
+     image_bytes = await file.read()
+     image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

-     # Run both models to get the "Internal Debate"
+     # 1. Model Perception
      blip_caption = await asyncio.to_thread(_generate_sync, "blip", image, 0.7, 50, 0.9)
-     vit_caption = await asyncio.to_thread(_generate_sync, "vit", image, 0.7, 50, 0.9)
-
-     def get_metrics(target, reference):
-         # 1. Semantic Embedding (The "Vibe" check)
-         blip = MODELS["blip"]
-         t_in = blip["processor"](text=target, return_tensors="pt", padding=True).to(DEVICE)
-         r_in = blip["processor"](text=reference, return_tensors="pt", padding=True).to(DEVICE)
-
-         with torch.no_grad():
-             t_emb = F.normalize(blip["model"].text_decoder.bert(**t_in).last_hidden_state.mean(dim=1), p=2, dim=-1)
-             r_emb = F.normalize(blip["model"].text_decoder.bert(**r_in).last_hidden_state.mean(dim=1), p=2, dim=-1)
-
-         cosine_sim = torch.matmul(t_emb, r_emb.T).item()
-
-         # 2. Jaccard Calibration (The "Accuracy" check - 70% weight)
-         t_words = set(target.lower().replace(",", "").split())
-         r_words = set(reference.lower().replace(",", "").split())
-         jaccard = len(t_words & r_words) / len(t_words | r_words) if t_words | r_words else 0
-
-         return (cosine_sim * 0.3) + (jaccard * 0.7)
-
-     user_vs_blip = get_metrics(user_prompt, blip_caption)
-     user_vs_vit = get_metrics(user_prompt, vit_caption)
-     consensus = get_metrics(blip_caption, vit_caption)
-
-     # XAI Verdict Logic
-     if consensus < 0.5:
-         verdict = "Model Confusion: High Uncertainty"
-     elif user_vs_blip < 0.6:
-         verdict = "Perspective Divergence: Prompt Mismatch"
+
+     # 2. CLIP Scoring (Multimodal Alignment)
+     clip_m = MODELS["clip"]["model"]
+     clip_p = MODELS["clip"]["processor"]
+
+     inputs = clip_p(text=[user_prompt, blip_caption], images=image, return_tensors="pt", padding=True).to(DEVICE)
+
+     with torch.no_grad():
+         outputs = clip_m(**inputs)
+         probs = outputs.logits_per_image.softmax(dim=-1).cpu().numpy()[0]
+
+     u_score, m_score = float(probs[0]), float(probs[1])
+
+     # 3. Decision Logic
+     if u_score < 0.35:
+         verdict = "Perspective Divergence: Intent not grounded in image."
+     elif abs(u_score - m_score) < 0.15:
+         verdict = "Consensus: High Alignment."
      else:
-         verdict = "Verified: Strong Alignment"
+         verdict = "Model Bias Detected."

      return {
-         "perspectives": {
-             "user_intent": user_prompt,
-             "blip_view": blip_caption,
-             "vit_git_view": vit_caption
-         },
-         "audit_metrics": {
-             "user_vs_blip": round(user_vs_blip, 4),
-             "user_vs_vit": round(user_vs_vit, 4),
-             "inter_model_consensus": round(consensus, 4)
-         },
+         "perspectives": {"user": user_prompt, "ai": blip_caption},
+         "audit_scores": {"intent_grounding": round(u_score, 4), "ai_grounding": round(m_score, 4)},
          "verdict": verdict
      }
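Note: because /audit softmaxes CLIP's image-text logits over exactly two candidates, u_score and m_score always sum to 1, so the verdict is a pure function of that pair and easy to sanity-check in isolation. A small sketch (the score pairs below are made-up illustrations, not model outputs):

def clip_verdict(u_score, m_score):
    # Mirrors the /audit decision logic above.
    if u_score < 0.35:
        return "Perspective Divergence: Intent not grounded in image."
    elif abs(u_score - m_score) < 0.15:
        return "Consensus: High Alignment."
    return "Model Bias Detected."

# Illustrative pairs (u_score + m_score == 1.0 for two candidates):
print(clip_verdict(0.20, 0.80))  # user prompt grounds poorly in the image -> divergence
print(clip_verdict(0.52, 0.48))  # near-even scores -> consensus
print(clip_verdict(0.70, 0.30))  # AI caption grounds far worse than the prompt -> bias flagged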
 
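For completeness, a matching client sketch for the other two endpoints (same assumptions as the /generate example above: the requests package, a hypothetical http://localhost:7860, placeholder filenames):

import requests

BASE = "http://localhost:7860"  # hypothetical local URL; adjust to your deployment

# /saliency returns a PNG: the magma heatmap blended over the input image.
with open("photo.jpg", "rb") as f:
    png = requests.post(f"{BASE}/saliency", files={"file": f})
with open("saliency.png", "wb") as out:
    out.write(png.content)

# /audit compares the user's stated intent against the BLIP caption via CLIP.
with open("photo.jpg", "rb") as f:
    audit = requests.post(
        f"{BASE}/audit",
        files={"file": f},
        params={"user_prompt": "a dog playing in the park"},
    )
print(audit.json())  # perspectives, audit_scores, verdict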