darisdzakwanhoesien commited on
Commit
6f50095
·
verified ·
1 Parent(s): 139f241

Update for all model running

Browse files
Files changed (1) hide show
  1. app.py +61 -43
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
  import torch
 
4
 
5
- # List of models
 
 
6
  MODELS = {
7
  "econbert": "climatebert/econbert",
8
  "controversy-classification": "climatebert/ClimateControversyBERT_classification",
@@ -19,10 +21,12 @@ MODELS = {
19
  "environmental-claims": "climatebert/environmental-claims",
20
  "climate-f": "climatebert/distilroberta-base-climate-f",
21
  "climate-d-s": "climatebert/distilroberta-base-climate-d-s",
22
- "climate-d": "climatebert/distilroberta-base-climate-d"
23
  }
24
 
25
- # Human-readable label mappings
 
 
26
  LABEL_MAPS = {
27
  "climate-commitment": {
28
  "LABEL_0": "Not about climate commitments",
@@ -54,17 +58,19 @@ LABEL_MAPS = {
54
  "LABEL_0": "Not about renewables",
55
  "LABEL_1": "About renewables",
56
  },
57
- # You can expand mappings for other models after checking their model cards
58
  }
59
 
60
- # Cache for loaded pipelines
 
 
61
  pipelines = {}
62
 
63
  def load_model(model_key):
64
- """Load pipeline for the selected model with truncation enabled."""
65
  if model_key not in pipelines:
66
  repo_id = MODELS[model_key]
67
  device = 0 if torch.cuda.is_available() else -1
 
68
  pipelines[model_key] = pipeline(
69
  "text-classification",
70
  model=repo_id,
@@ -75,45 +81,57 @@ def load_model(model_key):
75
  )
76
  return pipelines[model_key]
77
 
78
- def predict(model_key, text):
79
- """Run inference on selected model with truncation and readable labels."""
 
 
 
80
  if not text.strip():
81
- return "Please enter some text."
82
-
83
- try:
84
- model = load_model(model_key)
85
- results = model(text)
86
 
87
- label_map = LABEL_MAPS.get(model_key, {})
88
- formatted = "\n".join([
89
- f"{label_map.get(r['label'], r['label'])}: {r['score']:.2f}"
90
- for r in results
91
- ])
 
92
 
93
- return f"Predictions for '{text[:50]}...':\n{formatted}"
94
- except Exception as e:
95
- return f"Error: {str(e)} (Check input length or model card for details)."
 
96
 
97
- # Gradio interface
98
- with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
99
- gr.Markdown("# 🌍 ClimateBERT Models Demo\nSelect a model and input text for climate-related analysis (e.g., sentiment, classification).")
100
-
101
- with gr.Row():
102
- model_dropdown = gr.Dropdown(
103
- choices=list(MODELS.keys()),
104
- label="Select Model",
105
- value=list(MODELS.keys())[0]
106
- )
107
- text_input = gr.Textbox(
108
- label="Input Text",
109
- placeholder="E.g., 'Companies must reduce emissions to net zero by 2050.'",
110
- lines=2
111
- )
112
-
113
- output = gr.Textbox(label="Output", lines=5)
114
-
115
- predict_btn = gr.Button("Run Inference")
116
- predict_btn.click(predict, inputs=[model_dropdown, text_input], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import pipeline
4
 
5
+ # -------------------
6
+ # 1. Model definitions
7
+ # -------------------
8
  MODELS = {
9
  "econbert": "climatebert/econbert",
10
  "controversy-classification": "climatebert/ClimateControversyBERT_classification",
 
21
  "environmental-claims": "climatebert/environmental-claims",
22
  "climate-f": "climatebert/distilroberta-base-climate-f",
23
  "climate-d-s": "climatebert/distilroberta-base-climate-d-s",
24
+ "climate-d": "climatebert/distilroberta-base-climate-d",
25
  }
26
 
27
+ # -------------------
28
+ # 2. Human-readable label maps
29
+ # -------------------
30
  LABEL_MAPS = {
31
  "climate-commitment": {
32
  "LABEL_0": "Not about climate commitments",
 
58
  "LABEL_0": "Not about renewables",
59
  "LABEL_1": "About renewables",
60
  },
 
61
  }
62
 
63
+ # -------------------
64
+ # 3. Pipeline cache
65
+ # -------------------
66
  pipelines = {}
67
 
68
  def load_model(model_key):
69
+ """Load and cache a model pipeline."""
70
  if model_key not in pipelines:
71
  repo_id = MODELS[model_key]
72
  device = 0 if torch.cuda.is_available() else -1
73
+ print(f"🔹 Loading model: {model_key} ({repo_id})")
74
  pipelines[model_key] = pipeline(
75
  "text-classification",
76
  model=repo_id,
 
81
  )
82
  return pipelines[model_key]
83
 
84
+ # -------------------
85
+ # 4. Inference across all models
86
+ # -------------------
87
+ def predict_all_models(text):
88
+ """Run inference across all ClimateBERT models and return structured output."""
89
  if not text.strip():
90
+ return "⚠️ Please enter some text."
 
 
 
 
91
 
92
+ results_summary = []
93
+ for model_key, repo in MODELS.items():
94
+ try:
95
+ model = load_model(model_key)
96
+ outputs = model(text)
97
+ label_map = LABEL_MAPS.get(model_key, {})
98
 
99
+ formatted = "\n".join([
100
+ f"• {label_map.get(r['label'], r['label'])}: {r['score']:.2f}"
101
+ for r in outputs
102
+ ])
103
 
104
+ results_summary.append(f"### {model_key}\n{formatted}")
105
+ except Exception as e:
106
+ results_summary.append(f"### {model_key}\n❌ Error: {str(e)}")
107
+
108
+ return "\n\n".join(results_summary)
109
+
110
+ # -------------------
111
+ # 5. Gradio UI
112
+ # -------------------
113
+ with gr.Blocks(title="🌍 ClimateBERT All-Models Analyzer") as demo:
114
+ gr.Markdown("""
115
+ # 🌍 ClimateBERT Multi-Model Analysis
116
+ This app runs **all ClimateBERT models** on your input text (`mergedMarkdown` style).
117
+ It detects sentiment, specificity, renewables, commitments, and more — all at once.
118
+ """)
119
+
120
+ text_input = gr.Textbox(
121
+ label="Input Text (mergedMarkdown)",
122
+ placeholder="Paste the sustainability report, ESG statement, or corporate disclosure here...",
123
+ lines=5
124
+ )
125
+
126
+ output = gr.Markdown(label="Model Outputs")
127
+
128
+ run_btn = gr.Button("🔍 Run All Models")
129
+ run_btn.click(predict_all_models, inputs=text_input, outputs=output)
130
+
131
+ gr.Markdown("""
132
+ ---
133
+ **Note:** Each model captures a different aspect of climate-related discourse (e.g., sentiment, specificity, commitments, etc.).
134
+ """)
135
 
136
  if __name__ == "__main__":
137
+ demo.launch(server_name="0.0.0.0", server_port=7860)