|
|
""" |
|
|
PhySH Taxonomy Classifier — Gradio App |
|
|
|
|
|
Two-stage hierarchical cascade: |
|
|
Stage 1 → Discipline prediction (18-class multi-label) |
|
|
Stage 2 → Concept prediction (186-class multi-label, conditioned on discipline probs) |
|
|
|
|
|
Models were trained on APS PhySH labels with google/embeddinggemma-300m embeddings. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiLabelMLP(nn.Module):
    """Feed-forward network that produces one raw logit per label.

    Every entry of ``hidden_layers`` contributes a
    ``Linear -> ReLU -> Dropout`` triple; a final ``Linear`` projects to
    ``output_dim``. No activation is applied to the output.

    Args:
        input_dim: Width of the input feature vector.
        output_dim: Number of output logits.
        hidden_layers: Width of each hidden layer, in order.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, input_dim: int, output_dim: int,
                 hidden_layers: Tuple[int, ...] = (1024, 512), dropout: float = 0.3):
        super().__init__()
        # Consecutive pairs of ``widths`` are the (in, out) sizes of each hidden Linear.
        widths = (input_dim, *hidden_layers)
        modules = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(dropout)]
        modules.append(nn.Linear(widths[-1], output_dim))
        # The attribute name ``network`` prefixes every parameter key in the
        # state dict, so it has to stay stable for saved checkpoints to load.
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return the raw logits for input ``x``."""
        return self.network(x)
|
|
|
|
|
|
|
|
class DisciplineConditionedMLP(nn.Module):
    """MLP whose input is a text embedding concatenated with discipline features.

    The discipline features are either the probabilities passed to
    ``forward`` or, when ``use_logits`` is true, their logit transform.
    They pass through a dedicated dropout layer before concatenation.

    Args:
        embedding_dim: Width of the embedding input.
        discipline_dim: Width of the discipline-probability input.
        output_dim: Number of output logits.
        hidden_layers: Width of each hidden layer, in order.
        dropout: Dropout probability applied after every hidden activation.
        discipline_dropout: Dropout probability applied to the discipline features.
        use_logits: Convert the discipline probabilities to logits before use.
    """

    def __init__(self, embedding_dim: int, discipline_dim: int, output_dim: int,
                 hidden_layers: Tuple[int, ...] = (1024, 512), dropout: float = 0.3,
                 discipline_dropout: float = 0.0, use_logits: bool = False):
        super().__init__()
        self.use_logits = use_logits
        self.discipline_dropout = nn.Dropout(discipline_dropout)

        # The first Linear consumes the concatenated [embedding, discipline] vector.
        widths = (embedding_dim + discipline_dim, *hidden_layers)
        modules = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(dropout)]
        modules.append(nn.Linear(widths[-1], output_dim))
        self.network = nn.Sequential(*modules)

    def forward(self, embedding: torch.Tensor, discipline_probs: torch.Tensor) -> torch.Tensor:
        """Return raw logits for ``embedding`` conditioned on ``discipline_probs``.

        Both inputs are concatenated along ``dim=1``.
        """
        conditioning = discipline_probs
        if self.use_logits:
            # Clamp away from 0 and 1 so the logit stays finite.
            clipped = torch.clamp(discipline_probs, 1e-7, 1 - 1e-7)
            conditioning = torch.log(clipped / (1 - clipped))
        conditioning = self.discipline_dropout(conditioning)
        joint = torch.cat([embedding, conditioning], dim=1)
        return self.network(joint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint files are resolved relative to the directory containing this file.
MODELS_DIR = Path(__file__).resolve().parent
# Stage-1 (discipline) and stage-2 (concept) classifier checkpoints.
DISCIPLINE_MODEL_PATH = MODELS_DIR / "discipline_classifier_gemma_20260130_140842.pt"
CONCEPT_MODEL_PATH = MODELS_DIR / "concept_conditioned_gemma_20260130_140842.pt"
# Model identifier handed to SentenceTransformer to embed the input text.
EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m"

# Discipline labels that predict() leaves out of its displayed ranking.
EXCLUDED_DISCIPLINES = {"Quantum Physics"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level state assigned by load_models() and read by predict().
# Until load_models() has run, the model handles are None and the label
# lists are empty.
device: str = "cpu"
embedding_model: SentenceTransformer = None
discipline_model: MultiLabelMLP = None
concept_model: DisciplineConditionedMLP = None
# Per-class metadata taken from each checkpoint's "class_labels" entry;
# predict() reads the "label" key of each item.
discipline_labels: List[Dict] = []
concept_labels: List[Dict] = []
|
|
|
|
|
|
|
|
def load_models():
    """Pick a compute device and load the encoder and both classifier checkpoints.

    Assigns the module-level ``device``, ``embedding_model``,
    ``discipline_model``, ``concept_model``, ``discipline_labels`` and
    ``concept_labels``. The ``HF_TOKEN`` environment variable, if set, is
    forwarded to ``SentenceTransformer``. Both classifiers are moved to the
    selected device and switched to eval mode.
    """
    global device, embedding_model, discipline_model, concept_model
    global discipline_labels, concept_labels

    # Preference order: CUDA, then MPS, then CPU.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

    print(f"Loading embedding model ({EMBEDDING_MODEL_NAME}) on {device} …")
    embedding_model = SentenceTransformer(
        EMBEDDING_MODEL_NAME,
        device=device,
        token=os.environ.get("HF_TOKEN"),
    )

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so these checkpoint files must come from a trusted source.
    stage1_ckpt = torch.load(DISCIPLINE_MODEL_PATH, map_location=device, weights_only=False)
    stage1_cfg = stage1_ckpt["model_config"]
    discipline_model = MultiLabelMLP(
        input_dim=stage1_cfg["input_dim"],
        output_dim=stage1_cfg["output_dim"],
        hidden_layers=tuple(stage1_cfg["hidden_layers"]),
        dropout=stage1_cfg["dropout"],
    )
    discipline_model.load_state_dict(stage1_ckpt["model_state_dict"])
    discipline_model.to(device)
    discipline_model.eval()
    discipline_labels = stage1_ckpt["class_labels"]

    # NOTE(review): same unpickling caveat as above.
    stage2_ckpt = torch.load(CONCEPT_MODEL_PATH, map_location=device, weights_only=False)
    stage2_cfg = stage2_ckpt["model_config"]
    concept_model = DisciplineConditionedMLP(
        embedding_dim=stage2_cfg["embedding_dim"],
        discipline_dim=stage2_cfg["discipline_dim"],
        output_dim=stage2_cfg["output_dim"],
        hidden_layers=tuple(stage2_cfg["hidden_layers"]),
        dropout=stage2_cfg["dropout"],
        # These two keys are optional in the stored config and fall back to defaults.
        discipline_dropout=stage2_cfg.get("discipline_dropout", 0.0),
        use_logits=stage2_cfg.get("use_logits", False),
    )
    concept_model.load_state_dict(stage2_ckpt["model_state_dict"])
    concept_model.to(device)
    concept_model.eval()
    concept_labels = stage2_ckpt["class_labels"]

    print(f"Loaded {len(discipline_labels)} disciplines, {len(concept_labels)} concepts")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_text(text: str) -> str:
    """Collapse every whitespace run in ``text`` to one space and trim the ends.

    Falsy input (``None`` or ``""``) yields an empty string.
    """
    if text:
        collapsed = re.sub(r"\s+", " ", text)
        return collapsed.strip()
    return ""
|
|
|
|
|
|
|
|
def _format_ranked(probs, labels, threshold, top_k, fallback_prefix, excluded=frozenset()):
    """Return numbered Markdown lines for the ``top_k`` most probable labels.

    Labels are visited in descending probability. A label found in
    ``excluded`` is skipped and does not consume a rank. A label whose
    probability is at least ``threshold`` is wrapped in ``**`` (bold).
    Entries of ``labels`` without a ``"label"`` key fall back to
    ``f"{fallback_prefix}_{index}"``.
    """
    lines = []
    for idx in np.argsort(probs)[::-1]:
        if len(lines) >= top_k:
            break
        label = labels[idx].get("label", f"{fallback_prefix}_{idx}")
        if label in excluded:
            continue
        prob = probs[idx]
        marker = "**" if prob >= threshold else ""
        lines.append(f"{len(lines) + 1}. {marker}{label}{marker} — {prob:.1%}")
    return lines


def predict(title: str, abstract: str, threshold: float, top_k: int):
    """Run the two-stage cascade and return formatted results.

    The cleaned title and abstract are joined with ``" [SEP] "`` (or used
    alone when only one is present), embedded, and passed to the discipline
    model. The resulting discipline probabilities are fed, together with the
    embedding, to the concept model.

    Args:
        title: Paper title; may be empty.
        abstract: Paper abstract; may be empty.
        threshold: Probability at or above which a label is rendered in bold.
        top_k: Maximum number of entries listed per section. Coerced to a
            non-negative ``int`` so that a float from the UI slider is accepted.

    Returns:
        A ``(disciplines_markdown, concepts_markdown)`` tuple. When both
        inputs are empty, the first element is a prompt message and the
        second is an empty string.
    """
    # A float such as 10.0 would otherwise be unusable as a count.
    top_k = max(int(top_k), 0)

    title_clean = clean_text(title)
    abs_clean = clean_text(abstract)
    if title_clean and abs_clean:
        combined = f"{title_clean} [SEP] {abs_clean}"
    else:
        combined = title_clean or abs_clean

    if not combined:
        return "Please enter at least a title or abstract.", ""

    embedding = embedding_model.encode(
        [combined], normalize_embeddings=True, convert_to_numpy=True,
    )
    emb_tensor = torch.as_tensor(embedding, dtype=torch.float32, device=device)

    with torch.no_grad():
        # Stage 1: discipline probabilities. They stay on-device for stage 2
        # instead of round-tripping through NumPy.
        disc_probs_tensor = torch.sigmoid(discipline_model(emb_tensor))
        # Stage 2: concept probabilities conditioned on the stage-1 output.
        conc_probs_tensor = torch.sigmoid(concept_model(emb_tensor, disc_probs_tensor))

    # The batch holds a single item; take row 0.
    disc_probs = disc_probs_tensor.cpu().numpy()[0]
    conc_probs = conc_probs_tensor.cpu().numpy()[0]

    disc_lines = _format_ranked(
        disc_probs, discipline_labels, threshold, top_k,
        "Discipline", EXCLUDED_DISCIPLINES,
    )
    conc_lines = _format_ranked(conc_probs, concept_labels, threshold, top_k, "Concept")

    disc_md = f"### Disciplines (threshold ≥ {threshold:.0%})\n\n" + "\n".join(disc_lines)
    conc_md = f"### Research-Area Concepts (threshold ≥ {threshold:.0%})\n\n" + "\n".join(conc_lines)
    return disc_md, conc_md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sample inputs offered in the UI through gr.Examples. Each entry is a
# [title, abstract] pair, matching the order of the example input components.
EXAMPLES = [
    [
        "Quantum Computing: Vision and Challenges",
        (
            "The recent development of quantum computing, which uses entanglement, superposition, and other quantum fundamental concepts, "
            "can provide substantial processing advantages over traditional computing. These quantum features help solve many complex "
            "problems that cannot be solved otherwise with conventional computing methods. These problems include modeling quantum mechanics, "
            "logistics, chemical-based advances, drug design, statistical science, sustainable energy, banking, reliable communication, and "
            "quantum chemical engineering. The last few years have witnessed remarkable progress in quantum software and algorithm creation "
            "and quantum hardware research, which has significantly advanced the prospect of realizing quantum computers. It would be helpful "
            "to have comprehensive literature research on this area to grasp the current status and find outstanding problems that require "
            "considerable attention from the research community working in the quantum computing industry. To better understand quantum computing, "
            "this paper examines the foundations and vision based on current research in this area. We discuss cutting-edge developments in quantum "
            "computer hardware advancement and subsequent advances in quantum cryptography, quantum software, and high-scalability quantum computers. "
            "Many potential challenges and exciting new trends for quantum technology research and development are highlighted in this paper for a broader debate."
        ),
    ],
    [
        "Topological Insulators and Superconductors",
        (
            "Topological insulators are electronic materials that have a bulk band gap like an ordinary insulator but have protected conducting states "
            "on their edge or surface. We review the theoretical foundation for topological insulators and superconductors and describe recent experiments."
        ),
    ],
    [
        "Floquet Topological Insulator in Semiconductor Quantum Wells",
        (
            "Topological phase transitions between a conventional insulator and a state of matter with topological properties have been proposed and observed "
            "in mercury telluride - cadmium telluride quantum wells. We show that a topological state can be induced in such a device, initially in the trivial "
            "phase, by irradiation with microwave frequencies, without closing the gap and crossing the phase transition. We show that the quasi-energy spectrum "
            "exhibits a single pair of helical edge states. The velocity of the edge states can be tuned by adjusting the intensity of the microwave radiation. "
            "We discuss the necessary experimental parameters for our proposal. This proposal provides an example and a proof of principle of a new non-equilibrium "
            "topological state, Floquet topological insulator, introduced in this paper."
        ),
    ],
]
|
|
|
|
|
|
|
|
def build_app() -> gr.Blocks:
    """Assemble the Gradio interface and return it without launching.

    Layout: an input column (title, abstract, threshold and top-k sliders,
    classify button) beside an output column with two Markdown panes. The
    button calls ``predict``; the example table fills only the title and
    abstract boxes.
    """
    intro_md = (
        "# PhySH Taxonomy Classifier\n"
        "Enter a paper **title** and **abstract** to predict APS PhySH disciplines "
        "and research-area concepts using a two-stage hierarchical cascade.\n\n"
        "Labels above the threshold are **bolded**."
    )
    soft_theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")

    with gr.Blocks(title="PhySH Taxonomy Classifier", theme=soft_theme) as ui:
        gr.Markdown(intro_md)

        with gr.Row():
            # Left: inputs and controls.
            with gr.Column(scale=2):
                title_in = gr.Textbox(label="Title", lines=2, placeholder="Paper title …")
                abstract_in = gr.Textbox(label="Abstract", lines=8, placeholder="Paper abstract …")

                with gr.Row():
                    threshold_in = gr.Slider(
                        minimum=0.05,
                        maximum=0.95,
                        value=0.35,
                        step=0.05,
                        label="Threshold",
                    )
                    top_k_in = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=10,
                        step=1,
                        label="Top-K",
                    )

                run_btn = gr.Button("Classify", variant="primary", size="lg")

            # Right: rendered predictions.
            with gr.Column(scale=3):
                disciplines_out = gr.Markdown(label="Disciplines")
                concepts_out = gr.Markdown(label="Concepts")

        # Input order matches predict's positional parameters.
        run_btn.click(
            fn=predict,
            inputs=[title_in, abstract_in, threshold_in, top_k_in],
            outputs=[disciplines_out, concepts_out],
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=[title_in, abstract_in],
            label="Example papers",
        )

    return ui
|
|
|
|
|
|
|
|
# Script entry point: load the models once, then build and serve the UI.
if __name__ == "__main__":
    load_models()
    app = build_app()
    app.launch()
|
|
|