Lamp Socrates commited on
Commit
4efeb3b
1 Parent(s): b022555
Files changed (1) hide show
  1. app.py +95 -36
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import uvicorn
2
  import threading
 
3
  from typing import Optional
4
  from transformers import pipeline
5
  from transformers import AutoTokenizer, AutoModelForTokenClassification
@@ -13,11 +14,14 @@ from fastapi import FastAPI
13
  from pydantic import BaseModel
14
  from typing import List, Dict
15
 
 
16
  # Define the FastAPI app
17
  app = FastAPI()
18
  model_cache: Optional[object] = None
 
19
 
20
  def load_model():
 
21
 
22
  tokenizer = AutoTokenizer.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
23
  model = AutoModelForTokenClassification.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
@@ -36,6 +40,12 @@ def load_plod_cw_dataset():
36
  dataset = load_dataset("surrey-nlp/PLOD-CW")
37
  return dataset
38
 
 
 
 
 
 
 
39
  def get_cached_model():
40
  global model_cache
41
  if model_cache is None:
@@ -44,8 +54,7 @@ def get_cached_model():
44
 
45
  # Cache the model when the server starts
46
  model = get_cached_model()
47
-
48
-
49
 
50
  class Entity(BaseModel):
51
  entity: str
@@ -62,15 +71,20 @@ class NERRequest(BaseModel):
62
 
63
  @app.get("/hello")
64
  def read_root():
 
65
  return {"message": "Hello, World!"}
66
 
67
 
68
  @app.post("/ner", response_model=NERResponse)
69
  def get_entities(request: NERRequest):
 
70
  print(request)
 
71
  model = get_cached_model()
 
72
  # Use the NER model to detect entities
73
  entities = model(request.text)
 
74
  print(entities[0].keys())
75
  # Convert entities to the response model
76
  response_entities = [Entity(**entity) for entity in entities]
@@ -81,8 +95,9 @@ def get_color_for_label(label: str) -> str:
81
  # Define a mapping of labels to colors
82
  color_mapping = {
83
  "I-LF": "red",
 
84
  "B-AC": "blue",
85
- "LOC": "green",
86
  # Add more labels and colors as needed
87
  }
88
  return color_mapping.get(label, "black") # Default to black if label not found
@@ -90,30 +105,73 @@ def get_color_for_label(label: str) -> str:
90
 
91
  # Define the Gradio interface function
92
  def ner_demo(text):
 
93
  model = get_cached_model()
94
  entities = model(text)
95
- #return {"entities": entities}
96
 
97
- # Color code the entities
98
- color_coded_text = text
 
 
 
99
  for entity in entities:
100
- #print(entity)
101
  start, end, label = entity["start"], entity["end"], entity["entity"]
102
- color = get_color_for_label(label) # You need to define this function
103
  entity_text = text[start:end]
104
- colored_entity = f'<span style="color: {color}; font-weight: bold;">{entity_text}</span>'
105
- color_coded_text = color_coded_text[:start] + colored_entity + color_coded_text[end:]
106
-
107
- return color_coded_text
108
-
109
- PROJECT_INTRO = "This is a HF Spaces hosted Gradio App built by NLP Group 27 . The model has been trained on surrey-nlp/PLOD-CW dataset"
110
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def echo(text, request: gr.Request):
 
112
  if request:
113
- print("Request headers dictionary:", request.headers)
114
- print("IP address:", request.client.host)
115
- print("Query parameters:", dict(request.query_params))
116
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  # Create the Gradio interface
119
  demo = gr.Interface(
@@ -124,26 +182,27 @@ demo = gr.Interface(
124
  title="Named Entity Recognition on PLOD-CW ",
125
  description=f"{PROJECT_INTRO}\n\nEnter text to extract named entities using a NER model."
126
  )
127
- '''
128
- with gr.Blocks() as demo:
129
- gr.Markdown("# Page Title")
130
- gr.Markdown("## Subtitle with h2 Font")
131
- inputs=gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
132
-
133
- with gr.Column():
134
- echo_output = gr.Textbox(label="Echo Output")
135
- html_output = ner_demo
136
 
137
- with gr.Column():
138
- button1 = gr.Button("Submit")
139
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- #CUSTOM_PATH = "/gradio"
142
- #app = gr.mount_gradio_app(app, demo, path=CUSTOM_PATH)
143
 
144
- # Function to run FastAPI
145
- def run_fastapi():
146
- uvicorn.run(app, host="0.0.0.0", port=8000)
147
 
148
  # Function to run Gradio
149
 
 
1
  import uvicorn
2
  import threading
3
+ from collections import Counter
4
  from typing import Optional
5
  from transformers import pipeline
6
  from transformers import AutoTokenizer, AutoModelForTokenClassification
 
14
  from pydantic import BaseModel
15
  from typing import List, Dict
16
 
17
+
18
  # Define the FastAPI app
19
  app = FastAPI()
20
  model_cache: Optional[object] = None
21
+ dataset_cache : Optional[object] = None
22
 
23
  def load_model():
24
+ """ We load the model at startup"""
25
 
26
  tokenizer = AutoTokenizer.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
27
  model = AutoModelForTokenClassification.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
 
40
  dataset = load_dataset("surrey-nlp/PLOD-CW")
41
  return dataset
42
 
43
+ def get_cached_data():
44
+ global dataset_cache
45
+ if dataset_cache is None:
46
+ dataset_cache = load_plod_cw_dataset()
47
+ return dataset_cache
48
+
49
  def get_cached_model():
50
  global model_cache
51
  if model_cache is None:
 
54
 
55
  # Cache the model when the server starts
56
  model = get_cached_model()
57
+ #plod_cw = get_cached_data()
 
58
 
59
  class Entity(BaseModel):
60
  entity: str
 
71
 
72
  @app.get("/hello")
73
  def read_root():
74
+ """useful for testing connections"""
75
  return {"message": "Hello, World!"}
76
 
77
 
78
  @app.post("/ner", response_model=NERResponse)
79
  def get_entities(request: NERRequest):
80
+ """ This is invoked while API Testing """
81
  print(request)
82
+
83
  model = get_cached_model()
84
+
85
  # Use the NER model to detect entities
86
  entities = model(request.text)
87
+
88
  print(entities[0].keys())
89
  # Convert entities to the response model
90
  response_entities = [Entity(**entity) for entity in entities]
 
95
  # Define a mapping of labels to colors
96
  color_mapping = {
97
  "I-LF": "red",
98
+ "B-LF": "pink",
99
  "B-AC": "blue",
100
+ "B-O": "green",
101
  # Add more labels and colors as needed
102
  }
103
  return color_mapping.get(label, "black") # Default to black if label not found
 
105
 
106
  # Define the Gradio interface function
107
  def ner_demo(text):
108
+ """ This is invoked while rendering the page"""
109
  model = get_cached_model()
110
  entities = model(text)
 
111
 
112
+ print("Entities detected {}".format(Counter( [ entity['entity'] for entity in entities])))
113
+
114
+ all_html = ""
115
+ last_index = 0
116
+
117
  for entity in entities:
 
118
  start, end, label = entity["start"], entity["end"], entity["entity"]
119
+ color = get_color_for_label(label)
120
  entity_text = text[start:end]
121
+ #colored_entity = f'<span style="color: {color}; font-weight: bold;">{entity_text}</span>'
122
+ colored_entity = f'<sup style="color: {color}; font-weight: bold;">{entity_text}</sup>'
123
+
124
+
125
+ # Append text before the entity
126
+ all_html += text[last_index:start]
127
+ # Append the colored entity
128
+ all_html += colored_entity
129
+ # Update the last_index
130
+ last_index = end
131
+
132
+ # Append the remaining text after the last entity
133
+ all_html += text[last_index:]
134
+ return all_html
135
+
136
+ bo_color = get_color_for_label("B-O")
137
+ bac_color = get_color_for_label("B-AC")
138
+ ilf_color = get_color_for_label("I-LF")
139
+ blf_color = get_color_for_label("B-LF")
140
+
141
+ PROJECT_INTRO = f"""This is a HF Spaces hosted Gradio App built by NLP Group 27. \n\n
142
+ The model has been trained on surrey-nlp/PLOD-CW dataset.
143
+ The following Entities are recognized:
144
+ <sup style="color: {bo_color}; font-weight: bold;">B-O</sup>
145
+ <sup style="color: {bac_color}; font-weight: bold;">B-AC</sup>
146
+ <sup style="color: {ilf_color}; font-weight: bold;">I-LF</sup>
147
+ <sup style="color: {blf_color}; font-weight: bold;">B-LF</sup>
148
+ <sup style="color: black; font-weight: bold;">Rest</sup>
149
+ """
150
  def echo(text, request: gr.Request):
151
+ res = '<div>'
152
  if request:
153
+ res += f"Request headers dictionary: {request.headers} <p>"
154
+ res += f"IP address: {request.client.host} <p>"
155
+ res += f"Query parameters: {dict(request.query_params)} <p>"
156
+ res += "</div>"
157
+
158
+ return res
159
+
160
+ def sample_data(text):
161
+ text = "The red dots represents LCI , the bright yellow rectangle represents RV , and the black triangle represents the /TLCnLCI"
162
+
163
+ #dat = get_cached_data()
164
+
165
+ #df = dat['test']['tokens'].sample(5)
166
+
167
+ data = {
168
+ "Text": [text],
169
+ "Length": [len(text)]
170
+ }
171
+ df = pd.DataFrame(data)
172
+ return df
173
+
174
+
175
 
176
  # Create the Gradio interface
177
  demo = gr.Interface(
 
182
  title="Named Entity Recognition on PLOD-CW ",
183
  description=f"{PROJECT_INTRO}\n\nEnter text to extract named entities using a NER model."
184
  )
 
 
 
 
 
 
 
 
 
185
 
186
+ with gr.Blocks() as demo:
187
+ gr.Markdown("# Named Entity Recognition on PLOD-CW")
188
+ gr.Markdown(PROJECT_INTRO)
189
+ gr.Markdown("### Enter text to extract named entities using a NER model.")
190
+ text_input = gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
191
+ html_output = gr.HTML(label="HTML Output")
192
+
193
+ with gr.Row():
194
+ submit_button = gr.Button("Submit")
195
+ echo_button = gr.Button("Echo Client")
196
+ sample_button = gr.Button("Sample PLOD_CW")
197
+
198
+ sample_output = gr.Dataframe(label="Sample Table")
199
+ echo_output = gr.HTML(label="HTML Output")
200
+
201
+ submit_button.click(ner_demo, inputs=text_input, outputs=html_output)
202
 
203
+ echo_button.click(echo, inputs=text_input, outputs=echo_output)
204
+ sample_button.click(sample_data, inputs=text_input, outputs=sample_output)
205
 
 
 
 
206
 
207
  # Function to run Gradio
208