LukeFP committed on
Commit
bc9ecb4
·
1 Parent(s): 261241c

added the folder

Browse files
0219_gradio/README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PhySH Taxonomy Classifier — Gradio App
2
+
3
+ Interactive web app that predicts APS PhySH **disciplines** and **research-area concepts**
4
+ for a given paper title + abstract.
5
+
6
+ ## How it works
7
+
8
+ 1. Text is embedded with `google/embeddinggemma-300m` (768-dim, L2-normalised).
9
+ 2. **Stage 1** — A multi-label MLP predicts discipline probabilities (18 classes).
10
+ 3. **Stage 2** — A discipline-conditioned MLP concatenates the embedding with discipline
11
+ probabilities and predicts research-area concepts (186 classes).
12
+
13
+ Both models are `.pt` checkpoints trained in `../0120_taxonomy_training_inference/`.
14
+
15
+ ## Setup
16
+
17
+ The app uses the project-level virtualenv (`.venv` at the repo root).
18
+
19
+ ```bash
20
+ # From the repo root
21
+ source .venv/bin/activate
22
+
23
+ # Install the one extra dependency
24
+ pip install gradio
25
+ ```
26
+
27
+ ## Run
28
+
29
+ ```bash
30
+ cd 0219_gradio
31
+ python app.py
32
+ ```
33
+
34
+ Then open `http://127.0.0.1:7860` in your browser.
35
+
36
+ ## Model files
37
+
38
+ The app expects these checkpoints in the same directory as `app.py`:
39
+
40
+ - `discipline_classifier_gemma_20260130_140842.pt`
41
+ - `concept_conditioned_gemma_20260130_140842.pt`
0219_gradio/__pycache__/app.cpython-313.pyc ADDED
Binary file (13.9 kB). View file
 
0219_gradio/app.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PhySH Taxonomy Classifier — Gradio App
3
+
4
+ Two-stage hierarchical cascade:
5
+ Stage 1 → Discipline prediction (18-class multi-label)
6
+ Stage 2 → Concept prediction (186-class multi-label, conditioned on discipline probs)
7
+
8
+ Models were trained on APS PhySH labels with google/embeddinggemma-300m embeddings.
9
+ """
10
+
11
+ import re
12
+ from pathlib import Path
13
+ from typing import Dict, List, Tuple
14
+
15
+ import gradio as gr
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ from sentence_transformers import SentenceTransformer
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Model definitions (mirror the training code exactly)
23
+ # ---------------------------------------------------------------------------
24
+
25
class MultiLabelMLP(nn.Module):
    """Feed-forward multi-label classifier.

    A stack of Linear → ReLU → Dropout blocks followed by a final Linear
    head that emits raw logits (sigmoid is applied by the caller).
    """

    def __init__(self, input_dim: int, output_dim: int,
                 hidden_layers: Tuple[int, ...] = (1024, 512), dropout: float = 0.3):
        super().__init__()
        # Pair consecutive widths to build each hidden block.
        widths = [input_dim, *hidden_layers]
        stack = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(d_in, d_out), nn.ReLU(), nn.Dropout(dropout)]
        stack.append(nn.Linear(widths[-1], output_dim))
        # Named "network" so checkpoint state-dict keys line up.
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        # Returns logits, not probabilities.
        return self.network(x)
39
+
40
+
41
class DisciplineConditionedMLP(nn.Module):
    """Concept classifier conditioned on Stage-1 discipline predictions.

    The text embedding is concatenated with the discipline vector
    (optionally mapped back to logit space) and passed through an MLP
    that emits raw concept logits.
    """

    def __init__(self, embedding_dim: int, discipline_dim: int, output_dim: int,
                 hidden_layers: Tuple[int, ...] = (1024, 512), dropout: float = 0.3,
                 discipline_dropout: float = 0.0, use_logits: bool = False):
        super().__init__()
        self.use_logits = use_logits
        # Dropout applied only to the discipline features, not the embedding.
        self.discipline_dropout = nn.Dropout(discipline_dropout)
        widths = [embedding_dim + discipline_dim, *hidden_layers]
        stack = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(d_in, d_out), nn.ReLU(), nn.Dropout(dropout)]
        stack.append(nn.Linear(widths[-1], output_dim))
        # Named "network" so checkpoint state-dict keys line up.
        self.network = nn.Sequential(*stack)

    def forward(self, embedding: torch.Tensor, discipline_probs: torch.Tensor) -> torch.Tensor:
        if self.use_logits:
            # Invert the sigmoid (clamped for numerical safety) so the
            # network sees unbounded logits instead of [0, 1] values.
            p = torch.clamp(discipline_probs, 1e-7, 1 - 1e-7)
            disc = torch.log(p / (1 - p))
        else:
            disc = discipline_probs
        disc = self.discipline_dropout(disc)
        return self.network(torch.cat([embedding, disc], dim=1))
64
+
65
+
66
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Checkpoints are expected to sit next to this file (see the README).
MODELS_DIR = Path(__file__).resolve().parent
DISCIPLINE_MODEL_PATH = MODELS_DIR / "discipline_classifier_gemma_20260130_140842.pt"
CONCEPT_MODEL_PATH = MODELS_DIR / "concept_conditioned_gemma_20260130_140842.pt"
# SentenceTransformer model id used to embed title+abstract.
EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m"
73
+
74
# ---------------------------------------------------------------------------
# Globals (loaded once at startup)
# ---------------------------------------------------------------------------
# All of these are populated by load_models(); None/empty until then.
device: str = "cpu"
embedding_model: SentenceTransformer | None = None
discipline_model: MultiLabelMLP | None = None          # Stage 1 model
concept_model: DisciplineConditionedMLP | None = None  # Stage 2 model
# Label metadata from the checkpoints; each entry is a dict that predict()
# reads a "label" key from.
discipline_labels: List[Dict] = []
concept_labels: List[Dict] = []
83
+
84
+
85
def load_models():
    """Load the embedder and both cascade checkpoints into module globals.

    Populates: device, embedding_model, discipline_model, concept_model,
    discipline_labels, concept_labels.
    """
    global device, embedding_model, discipline_model, concept_model
    global discipline_labels, concept_labels

    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    print(f"Loading embedding model ({EMBEDDING_MODEL_NAME}) on {device} …")
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=device)

    # --- Stage 1: discipline classifier ---
    ckpt = torch.load(DISCIPLINE_MODEL_PATH, map_location=device, weights_only=False)
    cfg = ckpt["model_config"]
    model = MultiLabelMLP(
        cfg["input_dim"], cfg["output_dim"],
        tuple(cfg["hidden_layers"]), cfg["dropout"],
    )
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device).eval()
    discipline_model = model
    discipline_labels = ckpt["class_labels"]

    # --- Stage 2: discipline-conditioned concept classifier ---
    ckpt = torch.load(CONCEPT_MODEL_PATH, map_location=device, weights_only=False)
    cfg = ckpt["model_config"]
    model = DisciplineConditionedMLP(
        cfg["embedding_dim"], cfg["discipline_dim"], cfg["output_dim"],
        tuple(cfg["hidden_layers"]), cfg["dropout"],
        cfg.get("discipline_dropout", 0.0), cfg.get("use_logits", False),
    )
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device).eval()
    concept_model = model
    concept_labels = ckpt["class_labels"]

    print(f"Loaded {len(discipline_labels)} disciplines, {len(concept_labels)} concepts")
123
+
124
+
125
+ # ---------------------------------------------------------------------------
126
+ # Prediction
127
+ # ---------------------------------------------------------------------------
128
+
129
def clean_text(text: str) -> str:
    """Collapse all whitespace runs to single spaces and trim the ends.

    Falsy input (None or "") yields an empty string.
    """
    if not text:
        return ""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
133
+
134
+
135
+ def predict(title: str, abstract: str, threshold: float, top_k: int):
136
+ """Run the two-stage cascade and return formatted results."""
137
+ combined = clean_text(title)
138
+ abs_clean = clean_text(abstract)
139
+ if combined and abs_clean:
140
+ combined = f"{combined} [SEP] {abs_clean}"
141
+ elif abs_clean:
142
+ combined = abs_clean
143
+
144
+ if not combined.strip():
145
+ return "Please enter at least a title or abstract.", ""
146
+
147
+ # Embed
148
+ embedding = embedding_model.encode(
149
+ [combined], normalize_embeddings=True, convert_to_numpy=True,
150
+ )
151
+ emb_tensor = torch.FloatTensor(embedding).to(device)
152
+
153
+ with torch.no_grad():
154
+ # Stage 1
155
+ disc_logits = discipline_model(emb_tensor)
156
+ disc_probs = torch.sigmoid(disc_logits).cpu().numpy()[0]
157
+
158
+ # Stage 2
159
+ disc_probs_tensor = torch.FloatTensor(disc_probs).unsqueeze(0).to(device)
160
+ conc_logits = concept_model(emb_tensor, disc_probs_tensor)
161
+ conc_probs = torch.sigmoid(conc_logits).cpu().numpy()[0]
162
+
163
+ # Format discipline results
164
+ disc_order = np.argsort(disc_probs)[::-1]
165
+ disc_lines = []
166
+ for rank, idx in enumerate(disc_order[:top_k], 1):
167
+ prob = disc_probs[idx]
168
+ label = discipline_labels[idx].get("label", f"Discipline_{idx}")
169
+ marker = "**" if prob >= threshold else ""
170
+ disc_lines.append(f"{rank}. {marker}{label}{marker} — {prob:.1%}")
171
+
172
+ # Format concept results
173
+ conc_order = np.argsort(conc_probs)[::-1]
174
+ conc_lines = []
175
+ for rank, idx in enumerate(conc_order[:top_k], 1):
176
+ prob = conc_probs[idx]
177
+ label = concept_labels[idx].get("label", f"Concept_{idx}")
178
+ marker = "**" if prob >= threshold else ""
179
+ conc_lines.append(f"{rank}. {marker}{label}{marker} — {prob:.1%}")
180
+
181
+ disc_md = f"### Disciplines (threshold ≥ {threshold:.0%})\n\n" + "\n".join(disc_lines)
182
+ conc_md = f"### Research-Area Concepts (threshold ≥ {threshold:.0%})\n\n" + "\n".join(conc_lines)
183
+ return disc_md, conc_md
184
+
185
+
186
+ # ---------------------------------------------------------------------------
187
+ # Gradio UI
188
+ # ---------------------------------------------------------------------------
189
+
190
# Pre-filled (title, abstract) demo papers surfaced via gr.Examples below.
EXAMPLES = [
    [
        "Observation of Gravitational Waves from a Binary Black Hole Merger",
        "On September 14, 2015 at 09:50:45 UTC the two detectors of the Laser "
        "Interferometer Gravitational-Wave Observatory simultaneously observed a "
        "transient gravitational-wave signal. The signal sweeps upwards in frequency "
        "from 35 to 250 Hz with a peak gravitational-wave strain of 1.0×10⁻²¹.",
    ],
    [
        "Topological Insulators and Superconductors",
        "Topological insulators are electronic materials that have a bulk band gap "
        "like an ordinary insulator but have protected conducting states on their "
        "edge or surface. We review the theoretical foundation for topological "
        "insulators and superconductors and describe recent experiments.",
    ],
    [
        "Deep Learning for Particle Physics",
        "We review the application of modern machine learning techniques to the "
        "analysis of data from high-energy particle physics experiments. Neural "
        "networks are used for jet tagging, event classification, anomaly detection, "
        "and fast simulation of detector response.",
    ],
]
213
+
214
+
215
def build_app() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI.

    Pure UI wiring — no models are loaded here (see load_models()).
    """
    with gr.Blocks(
        title="PhySH Taxonomy Classifier",
        theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    ) as ui:
        gr.Markdown(
            "# PhySH Taxonomy Classifier\n"
            "Enter a paper **title** and **abstract** to predict APS PhySH disciplines "
            "and research-area concepts using a two-stage hierarchical cascade.\n\n"
            "Labels above the threshold are **bolded**."
        )

        with gr.Row():
            # Left column: text inputs plus the two tuning controls.
            with gr.Column(scale=2):
                title_input = gr.Textbox(label="Title", lines=2, placeholder="Paper title …")
                abstract_input = gr.Textbox(label="Abstract", lines=8, placeholder="Paper abstract …")

                with gr.Row():
                    threshold_ctl = gr.Slider(
                        minimum=0.05, maximum=0.95, value=0.35, step=0.05,
                        label="Threshold",
                    )
                    topk_ctl = gr.Slider(
                        minimum=1, maximum=20, value=10, step=1, label="Top-K",
                    )

                classify_btn = gr.Button("Classify", variant="primary", size="lg")

            # Right column: one markdown pane per prediction stage.
            with gr.Column(scale=3):
                disciplines_out = gr.Markdown(label="Disciplines")
                concepts_out = gr.Markdown(label="Concepts")

        classify_btn.click(
            fn=predict,
            inputs=[title_input, abstract_input, threshold_ctl, topk_ctl],
            outputs=[disciplines_out, concepts_out],
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=[title_input, abstract_input],
            label="Example papers",
        )

    return ui
260
+
261
+
262
if __name__ == "__main__":
    # Load the embedder and both checkpoints once, then serve the UI
    # (Gradio default address: http://127.0.0.1:7860, per the README).
    load_models()
    app = build_app()
    app.launch()
0219_gradio/concept_conditioned_gemma_20260130_140842.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77da740b38d773acad76a8b1f9d8b4a37a28bcefd3ef1d869564fbbcda7e18d7
3
+ size 5733613
0219_gradio/discipline_classifier_gemma_20260130_140842.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30d46f03c0a5c10d747525096b46c63909a86c40c7a4adc2c5989846c8e4ae61
3
+ size 5291653
0219_gradio/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=5.0,<6.0
2
+ torch>=2.0
3
+ sentence-transformers>=3.0
4
+ numpy