MilanBandara committed on
Commit
f1017a3
1 Parent(s): 93479fa

test files added

Browse files
Files changed (2) hide show
  1. Huggin_face_test/fsa.py +304 -0
  2. Huggin_face_test/helpers.py +246 -0
Huggin_face_test/fsa.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing libraries
2
+ from threading import Thread
3
+ from flask import Blueprint, jsonify, request
4
+ from flask_cors import CORS
5
+ import sys
6
+ import os
7
+
8
+
9
+ # Importing process pool executor
10
+ from concurrent.futures import ProcessPoolExecutor
11
+
12
+ # Fasttext for model handling
13
+ import fasttext
14
+
15
+
16
+ # Setting absolute path
17
+ sys.path.insert(0, os.path.abspath("."))
18
+
19
+ from app.config import Config
20
+ from app.helpers import *
21
+ from app.db.models import Tasks
22
+ from app.database import db
23
+ from app.threads.process_fsa_v2 import process_fsa_categories_v2
24
+ # from app.threads.process_fsa_v2 import test_function
25
+
26
+ # Create a Blueprint of classification
27
+ fsa = Blueprint("fsa_v2", __name__, url_prefix="/api/v2/fsa")
28
+
29
+ # Enabling CORS for the blueprint
30
+ CORS(
31
+ fsa,
32
+ supports_credentials=True
33
+ )
34
+
35
+
36
+ # Thread class to run the bacth processing in the thread
37
+ class FSAThread_V2(Thread):
38
+ def __init__(self, data={}) -> None:
39
+ Thread.__init__(self)
40
+ self.data = data
41
+ # Run function of the thread
42
+ def run(self) -> None:
43
+ process_fsa_categories_v2(self.data)
44
+
45
+ # Creating a process pool executor
46
+ # Set maximum processes
47
+ max_processes = 4
48
+ process_executor = ProcessPoolExecutor(max_workers=max_processes)
49
+
50
+ # Update the database
51
+ def update_db(table_idx, remarks=None):
52
+ from app.api import app
53
+
54
+ with app.app_context():
55
+ Tasks.update_by_id(table_idx, remarks)
56
+ db.session.close()
57
+
58
+
59
+ # Prediction for single product
60
+ @fsa.route("/single-product", methods=["POST"])
61
+ def predict_categories():
62
+
63
+ # Get the request
64
+ body = request.json
65
+
66
+ # If there is no body in the request send error message
67
+ if not body:
68
+ return jsonify({"message": "Cannot decode JSON from the body"}), 422
69
+
70
+ # Get the product name from the JSON
71
+ product_name = body.get("product_name")
72
+
73
+ # Check whether product name is missing
74
+ if not product_name:
75
+ return jsonify({"message": "Product name is missing"}), 422
76
+
77
+ # Preprocessing product names for input
78
+ product_name = preprocess(product_name)
79
+
80
+ # Prediction
81
+ # Logging processing
82
+ Logger.info(message="Processing FSA categorical data for " + product_name)
83
+
84
+ # Loading L0 model to model
85
+ try:
86
+ model = fasttext.load_model('app/models/L0/L0_model.bin')
87
+ except:
88
+ return jsonify({"message": "Can't load the L0 model"}), 500
89
+
90
+ #Getting L0 prediction and accuracy
91
+ L0_label,L0_accuracy = get_label_and_accuracy(model,product_name)
92
+ L0_return_label,L0_return_score,L0_label_status = get_return_labels(L0_label,L0_accuracy,0.95)
93
+ print("L0",L0_label,L0_accuracy)
94
+
95
+ if not L0_label:
96
+ return jsonify({"message": "Error predicting L0 Category"}), 500
97
+
98
+ #Loading L1 model to model
99
+ try:
100
+ model = fasttext.load_model('app/models/L1/L1_model.bin')
101
+ except:
102
+ return jsonify({"message": "Can't load the L1 model"}), 500
103
+
104
+ #Getting L1 prediction and accuracy
105
+ L1_label,L1_accuracy = get_label_and_accuracy(model,L0_label +" " + product_name)
106
+ L1_return_label,L1_return_score,L1_label_status = get_return_labels(L1_label,L1_accuracy,0.95)
107
+ print("L1",L1_label,L1_accuracy)
108
+
109
+ if not L1_label:
110
+ return jsonify({"message": "Error predicting L1 Category"}), 500
111
+
112
+ #Loading L2 model to model
113
+ try:
114
+ model = fasttext.load_model('app/models/L2/L2_model.bin')
115
+ except:
116
+ return jsonify({"message": "Can't load the L2 model"}), 500
117
+
118
+ #Getting L2 prediction and accuracy
119
+ L2_label,L2_accuracy = get_label_and_accuracy(model,L1_label+" "+product_name)
120
+ L2_return_label,L2_return_score,L2_label_status = get_return_labels(L2_label,L2_accuracy,0.95)
121
+ print("L2",L2_label,L2_accuracy)
122
+
123
+
124
+ if not L2_label:
125
+ return jsonify({"message": "Error predicting L2 Category"}), 500
126
+
127
+ #Loading L3 model to model
128
+ try:
129
+ model = fasttext.load_model('app/models/L3/L3_model.bin')
130
+ except:
131
+ return jsonify({"message": "Can't load the L3 model"}), 500
132
+ #Getting L3 prediction and accuracy
133
+ L3_label,L3_accuracy = get_label_and_accuracy(model,L2_label+" "+product_name)
134
+ L3_return_label,L3_return_score,L3_label_status = get_return_labels(L3_label,L3_accuracy,0.95)
135
+ print("L3",L3_label,L3_accuracy)
136
+
137
+ if not L3_label:
138
+ return jsonify({"message": "Error predicting L3 Category"}), 500
139
+
140
+ if L0_label == "administrative":
141
+ try:
142
+ model = fasttext.load_model('app/models/L4/administrative/L4_Admin_model.bin')
143
+ except:
144
+ return jsonify({"message": "Can't load the L4 (Administrative) model"}), 500
145
+ #Getting L4 prediction and accuracy
146
+ L4_label,L4_accuracy = get_label_and_accuracy(model,(L3_label+ " " +product_name))
147
+ L4_return_label,L4_return_score,L4_label_status = get_return_labels(L4_label,L4_accuracy,0.75)
148
+ print("L4",L4_label,L4_accuracy)
149
+
150
+ # L0 = Beverage
151
+ elif L0_label == "beverage":
152
+ try:
153
+ model = fasttext.load_model('app/models/L4/beverage/L4_beverage_model.bin')
154
+ except:
155
+ return jsonify({"message": "Can't load the L4 (Beverage) model"}), 500
156
+ #Getting L4 prediction and accuracy
157
+ L4_label,L4_accuracy = get_label_and_accuracy(model,(L3_label+" "+product_name))
158
+ L4_return_score = None
159
+ L4_return_label,L4_return_score,L4_label_status = get_return_labels(L4_label,L4_accuracy,0.66)
160
+ print("L4",L4_label,L4_accuracy)
161
+
162
+ # L0 = Food
163
+ elif L0_label == "food":
164
+ try:
165
+ model = fasttext.load_model('app/models/L4/food/L4_food_model.bin')
166
+ except:
167
+ return jsonify({"message": "Can't load the L4 (Food) model"}), 500
168
+ #Getting L4 prediction and accuracy
169
+ L4_label,L4_accuracy = get_label_and_accuracy(model,(L3_label+" "+product_name))
170
+ L4_return_label,L4_return_score,L4_label_status = get_return_labels(L4_label,L4_accuracy,0.85)
171
+ print("L4",L4_label,L4_accuracy)
172
+
173
+ # L0 = Operationals
174
+ elif L0_label == "operationals":
175
+ try:
176
+ model = fasttext.load_model('app/models/L4/operationals/L4_operationals_model.bin')
177
+ except:
178
+ return jsonify({"message": "Can't load the L4 (Operationals) model"}), 500
179
+ #Getting L4 prediction and accuracy
180
+ L4_label,L4_accuracy = get_label_and_accuracy(model,(L3_label+" "+product_name))
181
+ L4_return_label,L4_return_score,L4_label_status = get_return_labels(L4_label,L4_accuracy,0.8)
182
+ print("L4",L4_label,L4_accuracy)
183
+
184
+ # Error prediction on L4 Category (Can't happen)
185
+ else:
186
+ return jsonify({"message": "Error prediction of L4 Category"}), 422
187
+
188
+ if not L4_label:
189
+ return jsonify({"message": "Error predicting L4 Category"}), 422
190
+
191
+ # Logging the task
192
+ Logger.info(message="Done processing FSA categorical data for" + product_name)
193
+
194
+ # Rreturning the result as JSON
195
+
196
+ return jsonify({
197
+ "classification_results": {
198
+ "l0": L0_return_label,
199
+ "l1": L1_return_label,
200
+ "l2": L2_return_label,
201
+ "l3": L3_return_label,
202
+ "l4": L4_return_label
203
+ },
204
+ "scores": {
205
+ "l0": L0_return_score,
206
+ "l1": L1_return_score,
207
+ "l2": L2_return_score,
208
+ "l3": L3_return_score,
209
+ "l4": L4_return_score
210
+ },
211
+ "remarks":{
212
+ "l0": L0_label_status,
213
+ "l1": L1_label_status,
214
+ "l2": L2_label_status,
215
+ "l3": L3_label_status,
216
+ "l4": L4_label_status
217
+ },
218
+ "all_classification_results": {
219
+ "L0": L0_label,
220
+ "L1": L1_label,
221
+ "L2": L2_label,
222
+ "L3": L3_label,
223
+ "L4": L4_label
224
+ },
225
+ "all_scores": {
226
+ "L0": L0_accuracy,
227
+ "L1": L1_accuracy,
228
+ "L2": L2_accuracy,
229
+ "L3": L3_accuracy,
230
+ "L4": L4_accuracy
231
+ }
232
+
233
+
234
+ }), 200
235
+
236
+
237
+
238
+
239
+
240
+ # Batch processing
241
+ @fsa.route("/process-csv", methods=["POST"])
242
+ def process_csv():
243
+
244
+ # Get the body of the json
245
+ body = request.json
246
+
247
+ # Error passing for missing body
248
+ if not body:
249
+ return jsonify({"message": "Cannot decode JSON from the body"}), 422
250
+
251
+ # It is assumed that uploaded file name in the file_name JSON field
252
+ file_name = body.get("uploaded_file_name")
253
+
254
+ # Original file name
255
+ original_file_name = body.get("original_file_name") or file_name
256
+
257
+ # Missing file name
258
+ if not file_name:
259
+ return jsonify({"message": "File name is missing"}), 422
260
+
261
+ files = [{"name": f"fsa_input_{file_name}", "path": f"FSA Categorization/input/{file_name}"}]
262
+
263
+ # Download files from S3 bucket of AWS
264
+ # File is downloaded to th 'app/constants/{file}'
265
+ for file in files:
266
+ download_status = download_file_from_s3(
267
+ file_name=file["name"], file_path=file["path"]
268
+ )
269
+ if isinstance(download_status, botocore.exceptions.ClientError):
270
+ return (
271
+ jsonify({"message": f"Error downloading {file} from s3"}),
272
+ 422,
273
+ )
274
+
275
+
276
+ # Get the dataframe of the csv to check whether "ProdName" column is available
277
+ df = read_files(file_name=file_name)
278
+
279
+ # Check for product_names in columns
280
+ if "product_name" not in df.columns:
281
+ remove_files(f"fsa_input_{file_name}")
282
+ return jsonify({"message": "Product name column is missing from the CSV"}), 422
283
+
284
+
285
+ # Create a task
286
+ created_task = Tasks.create(file_name=file_name, original_file_name=original_file_name)
287
+
288
+ # Create a json object of data to pass the process
289
+ data = {
290
+ "file_name": file_name,
291
+ "table_idx": created_task.id,
292
+ "update_db": update_db
293
+ }
294
+
295
+ db.session.close()
296
+ # Add the process to process pool executor
297
+ result_future = process_executor.submit(process_fsa_categories_v2, (data))
298
+
299
+ # Creating a thread with data
300
+ # thread = FSAThread_V2(data=data)
301
+ # thread.start()
302
+
303
+ # Testing route
304
+ return jsonify({"message": f"{file_name} - File processing starting"}), 200
Huggin_face_test/helpers.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import boto3
4
+ import botocore
5
+ import re
6
+ import pandas as pd
7
+ from nltk.corpus import stopwords
8
+ import warnings
9
+
10
+ warnings.filterwarnings("ignore")
11
+
12
+ from app.logger import Logger
13
+
14
+ sys.path.insert(0, os.path.abspath("."))
15
+
16
+
17
+ def read_files(
18
+ file_name, sort_by=None, drop_duplicates=None, drop_na=None, encoding=None
19
+ ):
20
+ df = pd.read_csv(
21
+ os.path.join("app/constants", file_name), low_memory=False, encoding=encoding
22
+ )
23
+ if sort_by:
24
+ df = df.sort_values(by=[sort_by])
25
+ if drop_duplicates:
26
+ print("Removing duplicates in ProdName..")
27
+ print("df rows before removing duplicates = " + str(df.shape[0]))
28
+ df.drop_duplicates(subset=drop_duplicates, keep="first", inplace=True)
29
+ print("df rows after removing duplicates = " + str(df.shape[0]))
30
+ if drop_na:
31
+ print("Removing rows with null values..")
32
+ print("df rows before removing nan values = " + str(df.shape[0]))
33
+ df = df.dropna(subset=drop_na)
34
+ print("df rows after removing nan values = " + str(df.shape[0]))
35
+ df = df.reset_index(drop=True)
36
+ return df
37
+
38
+
39
+ def check_file_already_downloaded(file_name):
40
+ files = os.listdir("app/constants")
41
+ if file_name in files:
42
+ return True
43
+ return False
44
+
45
+
46
+ def download_file_from_s3(
47
+ file_name, bucket_name="sku-matching-ai-ml", skip_check=False, file_path=None
48
+ ):
49
+ if check_file_already_downloaded(file_name) and not skip_check:
50
+ return file_name
51
+ else:
52
+ print("STARTING DOWNLOADING: ", file_name)
53
+ if not file_path:
54
+ file_path = file_name
55
+ s3 = boto3.client("s3")
56
+ try:
57
+ s3.download_file(
58
+ Bucket=bucket_name, Key=file_path, Filename=f"app/constants/{file_name}"
59
+ )
60
+ print("DOWNLOADING FINISHED")
61
+ return file_name
62
+ # pylint: disable=invalid-name
63
+ except botocore.exceptions.ClientError as e:
64
+ Logger().exception(
65
+ message=f"Unable to download file: {file_name}",
66
+ )
67
+ return e
68
+
69
+
70
+ def upload_files_to_s3(file_path, upload_path, bucket_name="sku-matching-ai-ml"):
71
+ print("STARTING UPLOADING")
72
+ s3 = boto3.client("s3")
73
+ try:
74
+ s3.upload_file(file_path, bucket_name, upload_path)
75
+ except botocore.exceptions.ClientError as e:
76
+ Logger().exception(
77
+ message=f"Unable to uplaod file",
78
+ )
79
+ return e
80
+
81
+
82
+ def clean(string):
83
+ raw_text = re.sub("[^a-zA-Z]+", " ", string)
84
+ words = raw_text.lower().split()
85
+ stops = set(stopwords.words("english"))
86
+ meaningful_words = [
87
+ word for word in words if ((not word in stops) and (len(word) >= 3))
88
+ ]
89
+ string = " ".join(meaningful_words)
90
+ return string
91
+
92
+
93
+ def close_open_brackets(input_str):
94
+ opening_brackets = ["(", "[", "{"]
95
+ closing_brackets = [")", "]", "}"]
96
+ stack = []
97
+
98
+ for char in input_str:
99
+ if char in opening_brackets:
100
+ stack.append(char)
101
+ elif char in closing_brackets:
102
+ if len(stack) > 0:
103
+ opening_bracket = stack.pop()
104
+ if opening_brackets.index(opening_bracket) != closing_brackets.index(
105
+ char
106
+ ):
107
+ stack.append(opening_bracket)
108
+ stack.append(char)
109
+ else:
110
+ input_str = input_str.replace(char, "")
111
+
112
+ while len(stack) > 0:
113
+ opening_bracket = stack.pop()
114
+ closing_bracket = closing_brackets[opening_brackets.index(opening_bracket)]
115
+ input_str += closing_bracket
116
+
117
+ return input_str
118
+
119
+
120
+ def iterative_filtering(
121
+ df,
122
+ product,
123
+ column_name,
124
+ skip_clean=False,
125
+ consider_starts_with=True,
126
+ regex=False,
127
+ close_brackets=False,
128
+ ):
129
+ if not skip_clean:
130
+ product = clean(product)
131
+ else:
132
+ product = product.lower()
133
+ words = product.split()
134
+ new_df = df
135
+ index = 0
136
+ out_df = new_df
137
+
138
+ while new_df.shape[0] > 0 and index < len(words):
139
+ out_df = new_df
140
+ new_df = df_filtering_by_word(
141
+ new_df,
142
+ words[index],
143
+ column_name,
144
+ consider_starts_with,
145
+ regex,
146
+ close_brackets,
147
+ )
148
+ if new_df.shape[0] > 0:
149
+ out_df = new_df
150
+ new_df[column_name] = new_df[column_name].str.replace(words[index] + " ", "")
151
+ index = index + 1
152
+ out_df = out_df.reset_index(drop=True)
153
+ return out_df
154
+
155
+
156
+ def df_filtering_by_word(
157
+ df, word, column_name, consider_starts_with=True, regex=False, close_brackets=False
158
+ ):
159
+ try:
160
+ if close_brackets:
161
+ word = close_open_brackets(word)
162
+
163
+ if consider_starts_with:
164
+ filtered_df = df[df[column_name].str.startswith(word)]
165
+ if filtered_df.shape[0] == 0:
166
+ filtered_df = df[df[column_name].str.contains(word)]
167
+ else:
168
+ if regex:
169
+ filtered_df = df[
170
+ df[column_name].str.contains(rf"\b({word})\b", case=False)
171
+ ]
172
+ else:
173
+ filtered_df = df[df[column_name].str.contains(word)]
174
+ if filtered_df.shape[0] == 0:
175
+ filtered_df = df
176
+
177
+ return filtered_df
178
+ except Exception as e:
179
+ return df_filtering_by_word(df, clean(word), consider_starts_with, regex)
180
+
181
+
182
+ def remove_files(file_name):
183
+ if os.path.exists(f"app/constants/{file_name}"):
184
+ os.remove(f"app/constants/{file_name}")
185
+
186
+ def get_top_mrf_product(mrf_product_attributes_list, dp_product_attributes, sequence_scores, default_attr_key_list):
187
+ scores = []
188
+ for id, each_mrf_prod_attr in enumerate(mrf_product_attributes_list):
189
+ score = sequence_scores[id]
190
+ for key in default_attr_key_list:
191
+ if key in dp_product_attributes and key in each_mrf_prod_attr:
192
+ if pd.notna(dp_product_attributes[key]) and pd.notna(each_mrf_prod_attr[key]):
193
+ if str(dp_product_attributes[key]).lower() == str(each_mrf_prod_attr[key]).lower():
194
+ score += 5
195
+ scores.append(score)
196
+
197
+ max_index = scores.index(max(scores))
198
+ return max_index, max(scores)
199
+
200
+ # Helper files required for FSA V2
201
+ # Preprocessing Function
202
+ '''
203
+ This Function is using for preprocessing the input product names
204
+ '''
205
+ def preprocess(text):
206
+ text = re.sub(r'&', 'and', text)
207
+ text = re.sub(r'[^\w\s]',' ', text)
208
+ text = re.sub(' +', ' ', text)
209
+ return text.strip().lower()
210
+
211
+ # Function to preprocess labels from the previous prediction
212
+ def label_processing(label):
213
+ label = re.sub('__label__', '', label)
214
+ label = re.sub('_', ' ', label)
215
+ label = re.sub(' +', ' ', label)
216
+ return label.strip().lower()
217
+
218
+ def get_return_labels(label,accuracy,threshold):
219
+ if accuracy >= threshold:
220
+ return_label = label
221
+ return_score = accuracy
222
+ label_status = f"Classified - Above threshold {threshold}"
223
+ else:
224
+ return_label = None
225
+ return_score = None
226
+ label_status = f"Unclassfied - Below threshold {threshold}"
227
+ return return_label,return_score,label_status
228
+
229
+ #Function to get the product label and accuracy
230
+ def get_label_and_accuracy(model,product_name):
231
+ prediction = model.predict(product_name)
232
+ label = prediction[0][0]
233
+ label = label_processing(label)
234
+ accuracy = round(prediction[1][0],3)
235
+
236
+ return label,accuracy
237
+
238
+ # Function for remove new line in product name
239
+ '''
240
+ Some products may contain new line characters in middle of product names.
241
+ This may occur because of preprocessing. It can lead to result \n in middle of the
242
+ product names.
243
+ '''
244
+ def remove_new_lines(text):
245
+ text = re.sub('\n', ' ', text)
246
+ return text.strip().lower()