n0wm3 committed (verified)
Commit f4d2612 · 1 parent: 74dff36

Update app.py

Files changed (1):
  1. app.py +114 -68
app.py CHANGED
@@ -4,103 +4,105 @@ import numpy as np
 import torch
 from transformers import AutoModel
 
-# Choose device: CPU is fine, ignore CUDA warnings if any
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Load the model from Hugging Face
 model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
 model = model.eval().to(device)
 
 
 def calculate_ctr(mask):
-    """
-    Calculate cardiothoracic ratio (CTR) from a 2D mask
-    where:
-    1 = right lung
-    2 = left lung
-    3 = heart
-    """
-    # combine lungs
     lungs = np.zeros_like(mask, dtype=np.uint8)
     lungs[(mask == 1) | (mask == 2)] = 1
-
     heart = (mask == 3).astype("uint8")
 
-    # lung coordinates
     lung_y, lung_x = np.where(lungs == 1)
     heart_y, heart_x = np.where(heart == 1)
 
-    # safety checks in case segmentation fails
     if lung_x.size == 0 or heart_x.size == 0:
-        return None
+        return None, None, None, None, None
 
-    lung_min = lung_x.min()
-    lung_max = lung_x.max()
-    heart_min = heart_x.min()
-    heart_max = heart_x.max()
-
-    lung_range = lung_max - lung_min
-    heart_range = heart_max - heart_min
+    thorax_left = int(lung_x.min())
+    thorax_right = int(lung_x.max())
+    heart_left = int(heart_x.min())
+    heart_right = int(heart_x.max())
 
+    lung_range = thorax_right - thorax_left
+    heart_range = heart_right - heart_left
     if lung_range == 0:
-        return None
-
-    return float(heart_range / lung_range)
+        ctr = None
+    else:
+        ctr = float(heart_range / lung_range)
 
+    return ctr, thorax_left, thorax_right, heart_left, heart_right
 
-def analyze(image):
-    # image is a PIL image from Gradio
-    if image is None:
-        return None, "No image uploaded."
 
-    # Convert PIL image to grayscale numpy array
-    img = np.array(image.convert("L"))  # shape: (H, W)
+def _run_model(image):
+    """Shared logic: from PIL image -> (img_gray, mask, view_idx, age, female_prob, coords...)"""
+    img = np.array(image.convert("L"))
     h, w = img.shape[:2]
 
-    # Preprocess according to model card
     x = model.preprocess(img)
     x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0).float()
 
     with torch.inference_mode():
         out = model(x.to(device))
 
-    # Raw mask from model (usually 320x320)
     mask_small = out["mask"].argmax(1)[0].cpu().numpy()
+    mask = cv2.resize(mask_small.astype("uint8"), (w, h), interpolation=cv2.INTER_NEAREST)
 
-    # IMPORTANT: resize mask to original image size (H, W)
-    mask = cv2.resize(
-        mask_small.astype("uint8"),
-        (w, h),
-        interpolation=cv2.INTER_NEAREST
-    )
-
-    # view / age / sex prediction
     view_idx = out["view"].argmax(1).item()
     age_pred = float(out["age"].item())
    female_prob = float(out["female"].item())
-    female = female_prob >= 0.5
 
-    # CTR calculation
-    ctr = calculate_ctr(mask)
+    ctr, thorax_left, thorax_right, heart_left, heart_right = calculate_ctr(mask)
+
+    return (
+        img,
+        mask,
+        h,
+        w,
+        ctr,
+        thorax_left,
+        thorax_right,
+        heart_left,
+        heart_right,
+        view_idx,
+        age_pred,
+        female_prob,
+    )
 
-    # Base colored image
-    color = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
-    overlay = color.copy()
 
-    # Color-code segments on resized mask
-    overlay[mask == 1] = [0, 255, 0]    # right lung - green
-    overlay[mask == 2] = [0, 128, 255]  # left lung - teal/orange-ish
-    overlay[mask == 3] = [255, 0, 0]    # heart - red
+# ---------- 1) Visual demo (what you already have) ----------
 
+def analyze(image):
+    if image is None:
+        return None, "No image uploaded."
+
+    (
+        img,
+        mask,
+        h,
+        w,
+        ctr,
+        thorax_left,
+        thorax_right,
+        heart_left,
+        heart_right,
+        view_idx,
+        age_pred,
+        female_prob,
+    ) = _run_model(image)
+
+    color = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+    overlay = color.copy()
+    overlay[mask == 1] = [0, 255, 0]
+    overlay[mask == 2] = [0, 128, 255]
+    overlay[mask == 3] = [255, 0, 0]
     blended = cv2.addWeighted(color, 0.7, overlay, 0.3, 0)
 
-    # Map view label
     view_map = {0: "AP", 1: "PA", 2: "lateral"}
     view = view_map.get(view_idx, "unknown")
 
-    # Build result text
     lines = []
-
     if ctr is not None:
         lines.append(f"CTR: {ctr:.2f}")
     else:
@@ -109,35 +111,79 @@ def analyze(image):
     lines.extend([
         f"View (model): {view}",
         f"Predicted age: {age_pred:.0f} years",
-        f"Predicted sex: {'Female' if female else 'Male'} (prob={female_prob:.2f})",
+        f"Predicted sex: {'Female' if female_prob >= 0.5 else 'Male'} (prob={female_prob:.2f})",
         "",
-        "⚠️ For research/educational use only, not for clinical decision-making.",
+        "⚠️ Research/educational use only, NOT for clinical decision-making.",
     ])
 
     if view != "PA":
         lines.append("⚠️ CTR is normally interpreted on PA view. Interpret with caution.")
 
-    result_text = "\n".join(lines)
-
-    return blended, result_text
+    return blended, "\n".join(lines)
 
 
-demo = gr.Interface(
+visual_demo = gr.Interface(
     fn=analyze,
     inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
     outputs=[
         gr.Image(label="Segmentation overlay"),
-        gr.Textbox(label="AI output")
+        gr.Textbox(label="AI output"),
     ],
     title="AI CTR helper (research only)",
     description=(
-        "Uploads a frontal chest radiograph, segments heart and lungs, "
-        "and estimates cardiothoracic ratio (CTR) using the model "
-        "'ianpan/chest-x-ray-basic'.\n\n"
-        "This tool is for research and education only and is NOT approved "
-        "for clinical use."
-    )
+        "Segments heart and lungs and estimates CTR using 'ianpan/chest-x-ray-basic'. "
+        "Research use only."
+    ),
+)
+
+
+# ---------- 2) JSON points API (for your Lovable app) ----------
+
+def get_points(image):
+    if image is None:
+        return {"error": "No image uploaded"}
+
+    (
+        img,
+        mask,
+        h,
+        w,
+        ctr,
+        thorax_left,
+        thorax_right,
+        heart_left,
+        heart_right,
+        view_idx,
+        age_pred,
+        female_prob,
+    ) = _run_model(image)
+
+    result = {
+        "image_width": w,
+        "image_height": h,
+        "ctr": ctr,
+        "thorax_left_px": thorax_left,
+        "thorax_right_px": thorax_right,
+        "heart_left_px": heart_left,
+        "heart_right_px": heart_right,
+        "view_idx": int(view_idx),
+    }
+    return result
+
+
+points_api = gr.Interface(
+    fn=get_points,
+    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
+    outputs=gr.JSON(label="CTR points JSON"),
+    title="CTR points API",
+    description="Returns thorax/heart x-coordinates and CTR as JSON.",
+    api_name="ctr_points",  # important for programmatic calls
+)
+
+demo = gr.TabbedInterface(
+    [visual_demo, points_api],
+    ["Viewer", "JSON API"],
 )
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     demo.launch()
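
With the new return signature, the CTR is simply the heart width divided by the thoracic width, both taken as horizontal pixel extents of the segmentation mask. As a purely illustrative example (made-up pixel values): with thorax_left_px = 110, thorax_right_px = 910, heart_left_px = 380 and heart_right_px = 790, the endpoint would report ctr = (790 - 380) / (910 - 110) = 410 / 800 ≈ 0.51.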
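
Because the JSON interface sets api_name="ctr_points", the endpoint can be called programmatically once the Space is running. A minimal sketch with gradio_client; the Space id "your-username/ctr-helper" and the file name are placeholders, and a recent gradio_client with handle_file is assumed:

from gradio_client import Client, handle_file

client = Client("your-username/ctr-helper")  # placeholder Space id

# Returns the dict built by get_points(): image size, thorax/heart x-coordinates, ctr, view_idx.
points = client.predict(
    handle_file("chest_xray.png"),  # local path or URL to a frontal chest X-ray
    api_name="/ctr_points",
)
print(points["ctr"], points["thorax_left_px"], points["heart_right_px"])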
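
The returned pixel coordinates are enough for a client to draw the two measurement lines itself. A rough OpenCV sketch, not part of the committed app; the row used for drawing is arbitrary, since the JSON carries no y-coordinates:

import cv2

def draw_ctr_lines(image_bgr, points, y_frac=0.55):
    """Overlay thorax-width (blue) and heart-width (red) lines from the /ctr_points JSON."""
    if points.get("thorax_left_px") is None:
        return image_bgr  # segmentation failed upstream; nothing to draw
    y = int(points["image_height"] * y_frac)  # arbitrary row, for display only
    cv2.line(image_bgr, (points["thorax_left_px"], y), (points["thorax_right_px"], y), (255, 0, 0), 2)
    cv2.line(image_bgr, (points["heart_left_px"], y + 20), (points["heart_right_px"], y + 20), (0, 0, 255), 2)
    return image_bgr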