Sevixdd committed
Commit 01d8cbc · 1 parent: 2f3b676

Update app.py

Files changed (1)
  1. app.py +183 -51
app.py CHANGED
@@ -7,71 +7,188 @@ import threading
  import psutil
  import random
  from transformers import pipeline
- from time import gmtime, strftime

  # Load the model
- ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")

- # --- Prometheus Metrics Setup ---
- REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
- REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
- ERROR_COUNT = Counter('gradio_error_count', 'Total number of errors')
- RESPONSE_SIZE = Histogram('gradio_response_size_bytes', 'Size of responses in bytes')
- CPU_USAGE = Gauge('system_cpu_usage_percent', 'System CPU usage in percent')
- MEM_USAGE = Gauge('system_memory_usage_percent', 'System memory usage in percent')
- QUEUE_LENGTH = Gauge('chat_queue_length', 'Length of the chat queue')

- # --- Logging Setup ---
- logging.basicConfig(filename="chat_log.txt", level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

  # --- Queue and Metrics ---
  chat_queue = Queue() # Define chat_queue globally

- # --- Chat Function with Monitoring ---
- def chat_function(message, history):
-     logging.debug("Starting chat_function")
-     with REQUEST_LATENCY.time():
-         REQUEST_COUNT.inc()
-         try:
-             start_time = time.time()
-             chat_queue.put(message)
-             logging.info(f"Received message from user: {message}")
-             time = "\nGMT: " + time.strftime("%a, %d %b %Y %I:%M:%S %p %Z", time.gmtime())
-             ner_results = ner_pipeline(message)
-             logging.debug(f"NER results: {ner_results}")

              detailed_response = []
              for result in ner_results:
                  token = result['word']
                  score = result['score']
                  entity = result['entity']
-                 start = result['start']
-                 end = result['end']
-                 detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}, Start: {start}, End: {end}")

              response = "\n".join(detailed_response)
-             logging.info(f"Generated response: {response}")
-
              response_size = len(response.encode('utf-8'))
              RESPONSE_SIZE.observe(response_size)

              time.sleep(random.uniform(0.5, 2.5)) # Simulate processing time

              chat_queue.get()
-             logging.debug("Finished processing message")
-             return response
          except Exception as e:
              ERROR_COUNT.inc()
-             logging.error(f"Error in chat processing: {e}")
-             return "An error occurred. Please try again."

  # Function to simulate stress test
  def stress_test(num_requests, message, delay):
      def send_chat_message():
-         response = requests.post("http://127.0.0.1:7860/api/predict/", json={
-             "data": [message],
-             "fn_index": 0 # This might need to be updated based on your Gradio app's function index
-         })
-         logging.debug(response.json())

      threads = []
      for _ in range(num_requests):
@@ -86,14 +203,22 @@ def stress_test(num_requests, message, delay):
  # --- Gradio Interface with Background Image and Three Windows ---
  with gr.Blocks(css="""
  body {
-     background-image: url("stag.jpeg");
-     background-size: cover;
      background-repeat: no-repeat;
  }
  """, title="PLOD Filtered with Monitoring") as demo: # Load CSS for background image
-     with gr.Tab("Chat"):
          gr.Markdown("## Chat with the Bot")
-         chatbot = gr.ChatInterface(fn=chat_function)

      with gr.Tab("Model Parameters"):
          model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False) # Display model parameters
@@ -111,27 +236,27 @@ body {

      with gr.Tab("Stress Testing"):
          num_requests_input = gr.Number(label="Number of Requests", value=10)
-         message_input = gr.Textbox(label="Message", value="Hello bot!")
          delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
          stress_test_button = gr.Button("Start Stress Test")
          stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)

-         def run_stress_test(num_requests, message, delay):
              stress_test_status.value = "Stress test started..."
              try:
-                 stress_test(num_requests, message, delay)
                  stress_test_status.value = "Stress test completed."
              except Exception as e:
                  stress_test_status.value = f"Stress test failed: {e}"

-         stress_test_button.click(run_stress_test, [num_requests_input, message_input, delay_input], stress_test_status)

      # --- Update Functions ---
      def update_metrics(request_count_display, avg_latency_display):
          while True:
              request_count = REQUEST_COUNT._value.get()
              latency_samples = REQUEST_LATENCY.collect()[0].samples
-             avg_latency = sum(s.value for s in latency_samples) / len(latency_samples) if latency_samples else 0

              request_count_display.value = request_count
              avg_latency_display.value = round(avg_latency, 2)
@@ -148,9 +273,16 @@ body {

      def update_logs(logs_display):
          while True:
-             with open("chat_log.txt", "r") as log_file:
-                 logs = log_file.readlines()
-                 logs_display.value = "".join(logs[-10:]) # Display last 10 lines
              time.sleep(1) # Update every 1 second

      def display_model_params(model_params_display):
@@ -169,7 +301,7 @@ body {
      threading.Thread(target=start_http_server, args=(8000,), daemon=True).start()
      threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start()
      threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start()
-     threading.Thread(target=update_logs, args=(logs_display), daemon=True).start()
      threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start()
      threading.Thread(target=update_queue_length, daemon=True).start()
 
 
@@ -7,71 +7,188 @@ import threading
  import psutil
  import random
  from transformers import pipeline
+ from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+ import requests
+ from datasets import load_dataset
+ import os
+ from logging.handlers import RotatingFileHandler
+
+ # Ensure the log files exist
+ log_file_path = 'chat_log.log'
+ debug_log_file_path = 'debug.log'
+ if not os.path.exists(log_file_path):
+     with open(log_file_path, 'w') as f:
+         f.write("")
+ if not os.path.exists(debug_log_file_path):
+     with open(debug_log_file_path, 'w') as f:
+         f.write("")
+
+ # Create logger instance
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG) # Set logger level to the lowest level needed
+
+ # Create formatter
+ formatter = logging.Formatter(
+     '%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
+
+ # Create handlers
+ info_handler = RotatingFileHandler(
+     filename=log_file_path, mode='w', maxBytes=5*1024*1024, backupCount=2)
+ info_handler.setLevel(logging.INFO)
+ info_handler.setFormatter(formatter)
+
+ debug_handler = RotatingFileHandler(
+     filename=debug_log_file_path, mode='w', maxBytes=5*1024*1024, backupCount=2)
+ debug_handler.setLevel(logging.DEBUG)
+ debug_handler.setFormatter(formatter)
+
+ # Handler that captures logs for display in the Gradio UI
+ class GradioHandler(logging.Handler):
+     def __init__(self, logs_queue):
+         super().__init__()
+         self.logs_queue = logs_queue
+
+     def emit(self, record):
+         log_entry = self.format(record)
+         self.logs_queue.put(log_entry)
+
+ # Create a logs queue
+ logs_queue = Queue()
+
+ # Create and configure Gradio handler
+ gradio_handler = GradioHandler(logs_queue)
+ gradio_handler.setLevel(logging.INFO)
+ gradio_handler.setFormatter(formatter)
+
+ # Add handlers to the logger
+ logger.addHandler(info_handler)
+ logger.addHandler(debug_handler)
+ logger.addHandler(gradio_handler)
+
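+ # With the three handlers attached above, every log record fans out to
+ # chat_log.log (INFO and above), debug.log (DEBUG and above), and the in-memory
+ # logs_queue that feeds the logs display in the Gradio UI.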
  # Load the model
+ try:
+     ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
+     logger.debug("NER pipeline loaded.")
+ except Exception as e:
+     logger.debug(f"Error loading NER pipeline: {e}")

+ # Load the dataset
+ try:
+     dataset = load_dataset("surrey-nlp/PLOD-filtered")
+     logger.debug("Dataset loaded.")
+ except Exception as e:
+     logger.debug(f"Error loading dataset: {e}")

+ # --- Prometheus Metrics Setup ---
+ try:
+     REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
+     REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
+     ERROR_COUNT = Counter('gradio_error_count', 'Total number of errors')
+     RESPONSE_SIZE = Histogram('gradio_response_size_bytes', 'Size of responses in bytes')
+     CPU_USAGE = Gauge('system_cpu_usage_percent', 'System CPU usage in percent')
+     MEM_USAGE = Gauge('system_memory_usage_percent', 'System memory usage in percent')
+     QUEUE_LENGTH = Gauge('chat_queue_length', 'Length of the chat queue')
+     logger.debug("Prometheus metrics setup complete.")
+ except Exception as e:
+     logger.debug(f"Error setting up Prometheus metrics: {e}")

  # --- Queue and Metrics ---
  chat_queue = Queue() # Define chat_queue globally

+ # Maps the model's numeric LABEL_k outputs to PLOD tag names; id 2 has no entry here.
+ label_mapping = {
+     0: 'B-O',
+     1: 'B-AC',
+     3: 'B-LF',
+     4: 'I-LF'
+ }
+
+ def classification(message):
+     # Predict using the model
+     ner_results = ner_pipeline(" ".join(message))

      detailed_response = []
+     model_predicted_labels = []
      for result in ner_results:
          token = result['word']
          score = result['score']
          entity = result['entity']
+         label_id = int(entity.split('_')[-1]) # Extract the numeric label from the entity, e.g. LABEL_3 -> 3
+         model_predicted_labels.append(label_mapping[label_id])
+         detailed_response.append(f"Token: {token}, Entity: {label_mapping[label_id]}, Score: {score:.4f}")

      response = "\n".join(detailed_response)
      response_size = len(response.encode('utf-8'))
      RESPONSE_SIZE.observe(response_size)

      time.sleep(random.uniform(0.5, 2.5)) # Simulate processing time

+     return response, model_predicted_labels
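+ # Illustrative call (hypothetical tokens; actual labels and scores depend on the model):
+ # classification(["RNA", "sequencing"])
+ #   -> ("Token: RNA, Entity: B-AC, Score: 0.9876\n...", ['B-AC', ...])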
+
+ # --- Chat Function with Monitoring ---
+ def chat_function(input, datasets=None): # datasets defaults to None so the single-input tab below can call this too
+     logger.debug("Starting chat_function")
+     with REQUEST_LATENCY.time():
+         REQUEST_COUNT.inc()
+         try:
+             if input.isnumeric():
+                 chat_queue.put(input)
+                 # Get the example from the dataset
+                 if datasets:
+                     example = datasets[int(input)]
+                 else:
+                     example = dataset['train'][int(input)]
+                 tokens = example['tokens']
+                 ground_truth_labels = [label_mapping[label] for label in example['ner_tags']]
+
+                 # Call the classification function
+                 response, model_predicted_labels = classification(tokens)
+
+                 # Ensure the model and ground truth labels are the same length for comparison
+                 model_predicted_labels = model_predicted_labels[:len(ground_truth_labels)]
+
+                 precision = precision_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
+                 recall = recall_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
+                 f1 = f1_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
+                 accuracy = accuracy_score(ground_truth_labels, model_predicted_labels)
+
+                 metrics_response = (f"Precision: {precision:.4f}\n"
+                                     f"Recall: {recall:.4f}\n"
+                                     f"F1 Score: {f1:.4f}\n"
+                                     f"Accuracy: {accuracy:.4f}")
+
+                 full_response = f"**Record**:\nTokens: {tokens}\nGround Truth Labels: {ground_truth_labels}\n\n**Predictions**:\n{response}\n\n**Metrics**:\n{metrics_response}"
+                 logger.info(f"\nInput details:\nReceived index from user: {input}\nSending response to user: {full_response}")
+             else:
+                 chat_queue.put(input)
+                 response, predicted_labels = classification([input])
+                 full_response = f"Input details: \n**Input Sentence:** {input}\n\n**Predictions**:\n{response}\n\n"
+                 logger.info(full_response)
+
              chat_queue.get()
+             return full_response
          except Exception as e:
              ERROR_COUNT.inc()
+             logger.error(f"Error in chat processing: {e}", exc_info=True)
+             return f"An error occurred. Please try again. Error: {e}"

  # Function to simulate stress test
  def stress_test(num_requests, message, delay):
      def send_chat_message():
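+         # Note: /api/predict with "fn_index" is the legacy Gradio REST route; newer
+         # Gradio releases route calls differently (e.g. via named endpoints or the
+         # gradio_client package), so this may need adjusting to the installed version.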
+         try:
+             response = requests.post("http://127.0.0.1:7860/api/predict/", json={
+                 "data": [message],
+                 "fn_index": 0 # This might need to be updated based on your Gradio app's function index
+             })
+             logger.debug(f"Request payload: {message}")
+             logger.debug(f"Response: {response.json()}")
+         except Exception as e:
+             logger.debug(f"Error during stress test request: {e}", exc_info=True)

      threads = []
      for _ in range(num_requests):
 
@@ -86,14 +203,22 @@ def stress_test(num_requests, message, delay):
  # --- Gradio Interface with Background Image and Three Windows ---
  with gr.Blocks(css="""
  body {
+     background-image: url("stag.jpeg");
+     background-size: cover;
      background-repeat: no-repeat;
  }
  """, title="PLOD Filtered with Monitoring") as demo: # Load CSS for background image
+     with gr.Tab("Sentence Input"):
+         gr.Markdown("## Chat with the Bot")
+         index_input = gr.Textbox(label="Enter a sentence:", lines=1)
+         output = gr.Markdown(label="Response")
+         chat_interface = gr.Interface(fn=chat_function, inputs=[index_input], outputs=output)
+
+     with gr.Tab("Dataset and Index Input"):
          gr.Markdown("## Chat with the Bot")
+         interface = gr.Interface(fn=chat_function,
+                                  inputs=[gr.Textbox(label="Enter dataset index:", lines=1),
+                                          gr.UploadButton(label="Upload Dataset", file_types=[".csv", ".tsv"])],
+                                  outputs=gr.Markdown(label="Response"))

      with gr.Tab("Model Parameters"):
          model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False) # Display model parameters
 
@@ -111,27 +236,27 @@ body {

      with gr.Tab("Stress Testing"):
          num_requests_input = gr.Number(label="Number of Requests", value=10)
+         index_input_stress = gr.Textbox(label="Dataset Index", value="2")
          delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
          stress_test_button = gr.Button("Start Stress Test")
          stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)

+         def run_stress_test(num_requests, index, delay):
              stress_test_status.value = "Stress test started..."
              try:
+                 stress_test(num_requests, index, delay)
                  stress_test_status.value = "Stress test completed."
              except Exception as e:
                  stress_test_status.value = f"Stress test failed: {e}"

+         stress_test_button.click(run_stress_test, [num_requests_input, index_input_stress, delay_input], stress_test_status)

      # --- Update Functions ---
      def update_metrics(request_count_display, avg_latency_display):
          while True:
              request_count = REQUEST_COUNT._value.get()
              latency_samples = REQUEST_LATENCY.collect()[0].samples
+             avg_latency = (sum(s.value for s in latency_samples) / len(latency_samples)) if latency_samples else 0 # Avoid division by zero
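+             # Caveat: collect() yields _bucket/_count/_sum samples for a Histogram, so
+             # this sum is only a rough proxy; a true mean latency would be the _sum
+             # sample divided by the _count sample.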

              request_count_display.value = request_count
              avg_latency_display.value = round(avg_latency, 2)
 
@@ -148,9 +273,16 @@ body {

      def update_logs(logs_display):
          while True:
+             info_log_vector = []
+             # with open('debug.log', "r") as log_file_handler:
+             #     for line in log_file_handler: # Skip empty lines
+             #         info_log_vector.append(line)
+             # debugger.debug(info_log_vector)
+             # logs_display.value = info_log_vector # Display last 10 lines
+             logs = []
+             while not logs_queue.empty():
+                 logs.append(logs_queue.get())
+             logs_display.value = "\n".join(logs)
              time.sleep(1) # Update every 1 second

      def display_model_params(model_params_display):
 
@@ -169,7 +301,7 @@ body {
      threading.Thread(target=start_http_server, args=(8000,), daemon=True).start()
      threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start()
      threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start()
+     threading.Thread(target=update_logs, args=(logs_display,), daemon=True).start()
      threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start()
      threading.Thread(target=update_queue_length, daemon=True).start()
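A minimal smoke-test sketch for this revision (not part of the commit; it assumes the app is running locally on the ports used above, and that the /api/predict route matches the installed Gradio version):

import requests

# Classify dataset record 2 via the numeric-input branch of chat_function.
resp = requests.post(
    "http://127.0.0.1:7860/api/predict/",
    json={"data": ["2"], "fn_index": 0},
    timeout=30,
)
print(resp.json())

# Scrape the Prometheus exposition served by start_http_server(8000).
metrics = requests.get("http://127.0.0.1:8000/metrics", timeout=10).text
print("\n".join(line for line in metrics.splitlines() if line.startswith("gradio_")))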