D3V1L1810 committed on
Commit
8428754
·
verified ·
1 Parent(s): 09aaef7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -28
app.py CHANGED
@@ -6,13 +6,14 @@ import requests
6
  import gradio as gr
7
  import logging
8
 
 
9
  bert_tokenizer = BertTokenizer.from_pretrained('MultiTokenizer_ep10')
10
  bert_model = TFBertForSequenceClassification.from_pretrained('MultiModel_ep10')
11
 
 
12
  # def send_results_to_api(data, result_url):
13
  # headers = {'Content-Type':'application/json'}
14
  # response = requests.post(result_url, json = data, headers=headers)
15
-
16
  # if response.status_code == 200:
17
  # return response.json
18
  # else:
@@ -21,56 +22,59 @@ bert_model = TFBertForSequenceClassification.from_pretrained('MultiModel_ep10')
21
def predict_text(params):
    """Classify each input text with the fine-tuned BERT model.

    params: JSON string with keys
        "urls"         -- list of texts to classify (required)
        "normalfileID" -- optional list of file ids, parallel to "urls"
    Returns a JSON string {"solutions": [...]} on success, or a dict
    {"error": ...} on bad input.
    """
    try:
        params = json.loads(params)
    # json.loads raises json.JSONDecodeError; the original bare `JSONDecodeError`
    # name was undefined and would itself raise NameError on bad input.
    except json.JSONDecodeError as e:
        logging.error(f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}")
        return {"error": f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"}

    # NOTE(review): texts are read from the 'urls' key although the error
    # message below says 'texts' -- key names kept as-is for caller compatibility.
    texts = params.get("urls", [])
    file_ids = params.get("normalfileID", []) or [None] * len(texts)

    if not texts:
        return {"error": "Missing required parameters: 'texts'"}

    # Class index -> label name; built once, outside the loop.
    label = {0: 'BUSINESS', 1: 'COMEDY', 2: 'CRIME', 3: 'FOOD & DRINK',
             4: 'POLITICS', 5: 'SPORTS', 6: 'TRAVEL', 7: 'None'}
    solutions = []

    for text, file_id in zip(texts, file_ids):
        encoding = bert_tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=128,
            return_token_type_ids=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='tf'
        )
        pred = bert_model.predict([encoding['input_ids'],
                                   encoding['token_type_ids'],
                                   encoding['attention_mask']])
        logits = pred.logits
        # int() so the index is a plain Python int rather than numpy int64.
        pred_label = int(tf.argmax(logits, axis=1).numpy()[0])
        # Debug print replaced with lazy logging.
        logging.debug("logits=%s pred_label=%s", logits, pred_label)
        # The original `if not pred_label: predict_label = 7` assigned to a
        # misspelled variable and therefore had no effect; removed as dead code.

        solutions.append({'text': text, 'answer': [label[pred_label]],
                          "qcUser": None, "normalfileID": file_id})

    # result_url = f"{api}/{job_id}"
    # send_results_to_api(solutions, result_url)
    return json.dumps({"solutions": solutions})
71
 
72
# Gradio UI: JSON parameters typed into a textbox, JSON classification result out.
# Fixed the example in the label: it was missing its closing '}'.
inputt = gr.Textbox(label="Parameters in Json Format... Eg. {'texts':['text1', 'text2']}")
outputt = gr.JSON()

application = gr.Interface(fn=predict_text, inputs=inputt, outputs=outputt,
                           title='Multi Text Classification with API Integration..')
application.launch()
 
6
  import gradio as gr
7
  import logging
8
 
9
+ # Initialize the tokenizer and model
10
  bert_tokenizer = BertTokenizer.from_pretrained('MultiTokenizer_ep10')
11
  bert_model = TFBertForSequenceClassification.from_pretrained('MultiModel_ep10')
12
 
13
+ # Function to send results to API
14
  # def send_results_to_api(data, result_url):
15
  # headers = {'Content-Type':'application/json'}
16
  # response = requests.post(result_url, json = data, headers=headers)
 
17
  # if response.status_code == 200:
18
  # return response.json
19
  # else:
 
22
def predict_text(params):
    """Classify a batch of texts with the fine-tuned BERT model.

    params: JSON string with keys
        "urls"         -- list of texts to classify (required)
        "normalfileID" -- optional list of file ids, parallel to "urls"
    Returns a JSON string {"solutions": [...]} on success, or a dict
    {"error": ...} on bad input.
    """
    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        logging.error(f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}")
        return {"error": f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"}

    # NOTE(review): texts come from the 'urls' key although the error message
    # below says 'texts' -- key names kept as-is for caller compatibility.
    texts = params.get("urls", [])
    file_ids = params.get("normalfileID", []) or [None] * len(texts)
    # zip() silently truncates the longer sequence: pad file_ids so no text is
    # dropped when the caller sends fewer ids than texts.
    if len(file_ids) < len(texts):
        file_ids = file_ids + [None] * (len(texts) - len(file_ids))

    if not texts:
        return {"error": "Missing required parameters: 'texts'"}

    # Class index -> label name; index 7 is the low-confidence fallback.
    # Built once, outside the loop.
    label = {0: 'BUSINESS', 1: 'COMEDY', 2: 'CRIME', 3: 'FOOD & DRINK',
             4: 'POLITICS', 5: 'SPORTS', 6: 'TRAVEL', 7: 'None'}
    confidence_threshold = 0.5  # below this, the prediction is reported as 'None'
    solutions = []

    for text, file_id in zip(texts, file_ids):
        encoding = bert_tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=128,
            return_token_type_ids=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='tf'
        )
        pred = bert_model.predict([encoding['input_ids'],
                                   encoding['token_type_ids'],
                                   encoding['attention_mask']])
        logits = pred.logits
        softmax_scores = tf.nn.softmax(logits, axis=1).numpy()[0]
        # int()/float() so plain Python scalars are used, not numpy types.
        pred_label = int(tf.argmax(logits, axis=1).numpy()[0])
        confidence = float(softmax_scores[pred_label])

        # Low-confidence predictions fall back to the 'None' class.
        if confidence < confidence_threshold:
            pred_label = 7

        solutions.append({'text': text, 'answer': [label[pred_label]],
                          "qcUser": None, "normalfileID": file_id})

    # result_url = f"{api}/{job_id}"
    # send_results_to_api(solutions, result_url)
    return json.dumps({"solutions": solutions})
 
75
 
76
# Gradio front end: one textbox taking the JSON parameters, JSON result back.
inputt = gr.Textbox(
    label="Parameters in Json Format... Eg. {'texts':['text1', 'text2']}"
)
outputt = gr.JSON()

application = gr.Interface(
    fn=predict_text,
    inputs=inputt,
    outputs=outputt,
    title='Multi Text Classification with API Integration..',
)
application.launch()