Commit
•
f1017a3
1
Parent(s):
93479fa
test files added
Browse files- Huggin_face_test/fsa.py +304 -0
- Huggin_face_test/helpers.py +246 -0
Huggin_face_test/fsa.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Importing libraries
|
2 |
+
from threading import Thread
|
3 |
+
from flask import Blueprint, jsonify, request
|
4 |
+
from flask_cors import CORS
|
5 |
+
import sys
|
6 |
+
import os
|
7 |
+
|
8 |
+
|
9 |
+
# Importing process pool executor
|
10 |
+
from concurrent.futures import ProcessPoolExecutor
|
11 |
+
|
12 |
+
# Fasttext for model handling
|
13 |
+
import fasttext
|
14 |
+
|
15 |
+
|
16 |
+
# Setting absolute path
|
17 |
+
sys.path.insert(0, os.path.abspath("."))
|
18 |
+
|
19 |
+
from app.config import Config
|
20 |
+
from app.helpers import *
|
21 |
+
from app.db.models import Tasks
|
22 |
+
from app.database import db
|
23 |
+
from app.threads.process_fsa_v2 import process_fsa_categories_v2
|
24 |
+
# from app.threads.process_fsa_v2 import test_function
|
25 |
+
|
26 |
+
# Create a Blueprint of classification
|
27 |
+
fsa = Blueprint("fsa_v2", __name__, url_prefix="/api/v2/fsa")
|
28 |
+
|
29 |
+
# Enabling CORS for the blueprint
|
30 |
+
CORS(
|
31 |
+
fsa,
|
32 |
+
supports_credentials=True
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
# Thread wrapper that runs batch FSA categorisation in the background.
class FSAThread_V2(Thread):
    """Background thread that delegates its payload to process_fsa_categories_v2.

    Kept for reference; process_csv currently submits to the process pool
    instead of starting one of these threads.
    """

    def __init__(self, data=None) -> None:
        # BUG FIX: the original signature used a mutable default (data={}),
        # which is shared across every instance created without an argument.
        # Default to None and create a fresh dict per instance instead.
        Thread.__init__(self)
        self.data = {} if data is None else data

    # Run function of the thread
    def run(self) -> None:
        process_fsa_categories_v2(self.data)
|
44 |
+
|
45 |
+
# Shared pool of worker processes for CPU-heavy batch categorisation jobs.
# The pool is module-level so every request reuses the same workers.
max_processes = 4  # upper bound on concurrently running batch jobs
process_executor = ProcessPoolExecutor(max_workers=max_processes)
|
49 |
+
|
50 |
+
# Update the Tasks row for a finished/failed batch job.
# Passed to the worker process inside the `data` dict so the worker can
# report status back through the web app's database session.
def update_db(table_idx, remarks=None):
    # Imported lazily to avoid a circular import between this blueprint
    # module and the application factory.
    from app.api import app

    # An application context is required for SQLAlchemy access because this
    # runs outside a request (from a worker thread/process callback).
    with app.app_context():
        # NOTE(review): presumably records completion status/remarks for the
        # task row — confirm against Tasks.update_by_id.
        Tasks.update_by_id(table_idx, remarks)
        db.session.close()
|
57 |
+
|
58 |
+
|
59 |
+
# Acceptance threshold shared by hierarchy levels L0..L3.
_LEVEL_THRESHOLD = 0.95

# L0 label -> (L4 model path, display name used in error messages, threshold).
# Each top-level business area has its own L4 model and confidence cut-off.
_L4_MODELS = {
    "administrative": ("app/models/L4/administrative/L4_Admin_model.bin", "Administrative", 0.75),
    "beverage": ("app/models/L4/beverage/L4_beverage_model.bin", "Beverage", 0.66),
    "food": ("app/models/L4/food/L4_food_model.bin", "Food", 0.85),
    "operationals": ("app/models/L4/operationals/L4_operationals_model.bin", "Operationals", 0.8),
}


# Prediction for a single product
@fsa.route("/single-product", methods=["POST"])
def predict_categories():
    """Classify one product name through the L0..L4 fasttext model hierarchy.

    Expects JSON ``{"product_name": str}``. Each level's input is the previous
    level's label prepended to the product name. Returns thresholded labels,
    scores and remarks per level plus the raw predictions, or an error payload
    with status 422/500.
    """
    body = request.json
    if not body:
        return jsonify({"message": "Cannot decode JSON from the body"}), 422

    product_name = body.get("product_name")
    if not product_name:
        return jsonify({"message": "Product name is missing"}), 422

    # Normalise the product name the same way the models were trained.
    product_name = preprocess(product_name)
    Logger.info(message="Processing FSA categorical data for " + product_name)

    labels = {}         # raw predicted label per level
    accuracies = {}     # raw confidence per level
    return_labels = {}  # thresholded label (None when below threshold)
    return_scores = {}  # thresholded score (None when below threshold)
    statuses = {}       # human-readable classified/unclassified remark

    # Levels L0..L3 share the same load/predict/threshold structure; the
    # original repeated this block four times verbatim.
    prev_label = ""
    for level in ("L0", "L1", "L2", "L3"):
        try:
            model = fasttext.load_model(f"app/models/{level}/{level}_model.bin")
        except Exception:  # was a bare except: narrow to Exception
            return jsonify({"message": f"Can't load the {level} model"}), 500

        # Condition each level on the previous level's label (L0 is unconditioned).
        text = f"{prev_label} {product_name}" if prev_label else product_name
        label, accuracy = get_label_and_accuracy(model, text)
        r_label, r_score, status = get_return_labels(label, accuracy, _LEVEL_THRESHOLD)
        print(level, label, accuracy)

        if not label:
            return jsonify({"message": f"Error predicting {level} Category"}), 500

        labels[level], accuracies[level] = label, accuracy
        return_labels[level], return_scores[level], statuses[level] = r_label, r_score, status
        prev_label = label

    # L4 depends on which business area L0 predicted.
    l4_config = _L4_MODELS.get(labels["L0"])
    if l4_config is None:
        # L0 produced a label outside the four known business areas.
        return jsonify({"message": "Error prediction of L4 Category"}), 422

    model_path, display_name, threshold = l4_config
    try:
        model = fasttext.load_model(model_path)
    except Exception:
        return jsonify({"message": f"Can't load the L4 ({display_name}) model"}), 500

    label, accuracy = get_label_and_accuracy(model, labels["L3"] + " " + product_name)
    r_label, r_score, status = get_return_labels(label, accuracy, threshold)
    print("L4", label, accuracy)

    if not label:
        return jsonify({"message": "Error predicting L4 Category"}), 422

    labels["L4"], accuracies["L4"] = label, accuracy
    return_labels["L4"], return_scores["L4"], statuses["L4"] = r_label, r_score, status

    # Logging the task (message text kept byte-identical to the original).
    Logger.info(message="Done processing FSA categorical data for" + product_name)

    # Returning the result as JSON; key shape matches the original response.
    return jsonify({
        "classification_results": {k.lower(): return_labels[k] for k in labels},
        "scores": {k.lower(): return_scores[k] for k in labels},
        "remarks": {k.lower(): statuses[k] for k in labels},
        "all_classification_results": dict(labels),
        "all_scores": dict(accuracies),
    }), 200
|
235 |
+
|
236 |
+
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
# Batch processing
|
241 |
+
@fsa.route("/process-csv", methods=["POST"])
def process_csv():
    """Kick off background FSA categorisation for an uploaded CSV.

    Expects JSON with ``uploaded_file_name`` (an object under
    'FSA Categorization/input/' in S3) and optionally ``original_file_name``.
    Downloads the CSV, validates the ``product_name`` column, records a Tasks
    row, submits the job to the shared process pool, and returns immediately.
    """
    body = request.json
    if not body:
        return jsonify({"message": "Cannot decode JSON from the body"}), 422

    # The uploaded file name is expected in the "uploaded_file_name" field.
    file_name = body.get("uploaded_file_name")
    original_file_name = body.get("original_file_name") or file_name

    if not file_name:
        return jsonify({"message": "File name is missing"}), 422

    files = [{"name": f"fsa_input_{file_name}", "path": f"FSA Categorization/input/{file_name}"}]

    # Download inputs from S3; download_file_from_s3 stores them under
    # 'app/constants/<name>' and returns a ClientError instance on failure.
    for file in files:
        download_status = download_file_from_s3(
            file_name=file["name"], file_path=file["path"]
        )
        if isinstance(download_status, botocore.exceptions.ClientError):
            # BUG FIX: report the file name, not the repr of the whole dict.
            return (
                jsonify({"message": f"Error downloading {file['name']} from s3"}),
                422,
            )

    # Load the CSV to validate that the "product_name" column is present.
    # NOTE(review): the download above saves the file as 'fsa_input_<name>'
    # but this reads '<name>' — confirm which name the input actually lives
    # under before changing either side.
    df = read_files(file_name=file_name)

    if "product_name" not in df.columns:
        # Clean up the downloaded copy before rejecting the request.
        remove_files(f"fsa_input_{file_name}")
        return jsonify({"message": "Product name column is missing from the CSV"}), 422

    # Record the task so the worker can report progress via update_db.
    created_task = Tasks.create(file_name=file_name, original_file_name=original_file_name)

    data = {
        "file_name": file_name,
        "table_idx": created_task.id,
        "update_db": update_db,
    }

    # Release the request's DB session before handing off to the worker.
    db.session.close()

    # Run the batch job in a separate process so this request returns fast;
    # the future's result is intentionally not awaited.
    process_executor.submit(process_fsa_categories_v2, data)

    return jsonify({"message": f"{file_name} - File processing starting"}), 200
|
Huggin_face_test/helpers.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import boto3
|
4 |
+
import botocore
|
5 |
+
import re
|
6 |
+
import pandas as pd
|
7 |
+
from nltk.corpus import stopwords
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
warnings.filterwarnings("ignore")
|
11 |
+
|
12 |
+
from app.logger import Logger
|
13 |
+
|
14 |
+
sys.path.insert(0, os.path.abspath("."))
|
15 |
+
|
16 |
+
|
17 |
+
def read_files(
    file_name, sort_by=None, drop_duplicates=None, drop_na=None, encoding=None
):
    """Load a CSV from app/constants with optional sort/dedupe/NaN-drop steps.

    sort_by is a single column name; drop_duplicates and drop_na are column
    subsets passed straight to pandas. Returns the DataFrame with its index
    reset.
    """
    csv_path = os.path.join("app/constants", file_name)
    df = pd.read_csv(csv_path, low_memory=False, encoding=encoding)

    if sort_by:
        df = df.sort_values(by=[sort_by])

    if drop_duplicates:
        print("Removing duplicates in ProdName..")
        print("df rows before removing duplicates = " + str(df.shape[0]))
        df.drop_duplicates(subset=drop_duplicates, keep="first", inplace=True)
        print("df rows after removing duplicates = " + str(df.shape[0]))

    if drop_na:
        print("Removing rows with null values..")
        print("df rows before removing nan values = " + str(df.shape[0]))
        df = df.dropna(subset=drop_na)
        print("df rows after removing nan values = " + str(df.shape[0]))

    return df.reset_index(drop=True)
|
37 |
+
|
38 |
+
|
39 |
+
def check_file_already_downloaded(file_name):
    """Return True when file_name is already present in app/constants."""
    return file_name in os.listdir("app/constants")
|
44 |
+
|
45 |
+
|
46 |
+
def download_file_from_s3(
    file_name, bucket_name="sku-matching-ai-ml", skip_check=False, file_path=None
):
    """Fetch file_name from S3 into app/constants, reusing cached copies.

    Returns the file name on success (or when the file is already cached);
    on failure returns the botocore ClientError instance so callers can
    branch on the result type.
    """
    # Reuse a previously downloaded copy unless the caller forces a refresh.
    if check_file_already_downloaded(file_name) and not skip_check:
        return file_name

    print("STARTING DOWNLOADING: ", file_name)
    key = file_path if file_path else file_name
    client = boto3.client("s3")
    try:
        client.download_file(
            Bucket=bucket_name, Key=key, Filename=f"app/constants/{file_name}"
        )
    # pylint: disable=invalid-name
    except botocore.exceptions.ClientError as e:
        Logger().exception(
            message=f"Unable to download file: {file_name}",
        )
        return e
    print("DOWNLOADING FINISHED")
    return file_name
|
68 |
+
|
69 |
+
|
70 |
+
def upload_files_to_s3(file_path, upload_path, bucket_name="sku-matching-ai-ml"):
    """Upload a local file to S3.

    Returns None on success; on failure returns the botocore ClientError so
    callers can test the result (mirrors download_file_from_s3).
    """
    print("STARTING UPLOADING")
    s3 = boto3.client("s3")
    try:
        s3.upload_file(file_path, bucket_name, upload_path)
    except botocore.exceptions.ClientError as e:
        # BUG FIX: corrected the misspelled log message ("uplaod") and
        # included the offending file path for easier debugging.
        Logger().exception(
            message=f"Unable to upload file: {file_path}",
        )
        return e
|
80 |
+
|
81 |
+
|
82 |
+
def clean(string):
    """Normalise free text for matching.

    Strips everything but letters, lowercases, and keeps only tokens of
    length >= 3 that are not English stopwords, joined by single spaces.
    """
    letters_only = re.sub("[^a-zA-Z]+", " ", string)
    tokens = letters_only.lower().split()
    stop_set = set(stopwords.words("english"))
    kept = [tok for tok in tokens if tok not in stop_set and len(tok) >= 3]
    return " ".join(kept)
|
91 |
+
|
92 |
+
|
93 |
+
def close_open_brackets(input_str):
    """Balance brackets in input_str.

    Unmatched closing brackets are stripped (every occurrence of that
    character, preserving the original's behaviour for the empty-stack case)
    and a matching closer is appended for every bracket still open at the end.
    """
    opening_brackets = ["(", "[", "{"]
    closing_brackets = [")", "]", "}"]
    stack = []

    for char in input_str:
        if char in opening_brackets:
            stack.append(char)
        elif char in closing_brackets:
            if len(stack) > 0:
                opening_bracket = stack.pop()
                if opening_brackets.index(opening_bracket) != closing_brackets.index(
                    char
                ):
                    # BUG FIX: the original pushed the *closing* char back onto
                    # the stack here, which made the drain loop below crash
                    # with ValueError (closing chars are not in
                    # opening_brackets). Keep the opener pending and discard
                    # the mismatched closer instead.
                    stack.append(opening_bracket)
                    input_str = input_str.replace(char, "")
            else:
                # Closer with nothing open: drop it from the output.
                input_str = input_str.replace(char, "")

    # Append the matching closer for every bracket still open.
    while len(stack) > 0:
        opening_bracket = stack.pop()
        input_str += closing_brackets[opening_brackets.index(opening_bracket)]

    return input_str
|
118 |
+
|
119 |
+
|
120 |
+
def iterative_filtering(
    df,
    product,
    column_name,
    skip_clean=False,
    consider_starts_with=True,
    regex=False,
    close_brackets=False,
):
    """Narrow df word-by-word against the tokens of *product*.

    Filters repeatedly on column_name with each successive word of the
    (optionally cleaned) product string, returning the last non-empty
    filtered frame reached before the candidates ran out.
    """
    # Either stopword-clean the product or just lowercase it verbatim.
    if not skip_clean:
        product = clean(product)
    else:
        product = product.lower()
    words = product.split()
    new_df = df
    index = 0
    # out_df always holds the last non-empty result; if the very first word
    # matches nothing, the full input frame is returned.
    out_df = new_df

    # Stop as soon as a word eliminates every row or the words are exhausted.
    while new_df.shape[0] > 0 and index < len(words):
        out_df = new_df
        new_df = df_filtering_by_word(
            new_df,
            words[index],
            column_name,
            consider_starts_with,
            regex,
            close_brackets,
        )
        if new_df.shape[0] > 0:
            out_df = new_df
            # Consume the matched word (with its trailing space) from the
            # column so the next word is matched against the remainder.
            # NOTE(review): this assigns into a filtered frame — pandas may
            # emit SettingWithCopyWarning; confirm the mutation is intended.
            new_df[column_name] = new_df[column_name].str.replace(words[index] + " ", "")
        index = index + 1
    out_df = out_df.reset_index(drop=True)
    return out_df
|
154 |
+
|
155 |
+
|
156 |
+
def df_filtering_by_word(
    df, word, column_name, consider_starts_with=True, regex=False, close_brackets=False
):
    """Filter df rows whose column_name matches *word*.

    With consider_starts_with, prefix matches are preferred and substring
    matches are the fallback; otherwise a plain (or word-boundary regex)
    contains-match is used. When no row matches, the full frame is returned
    unchanged. On any error the word is cleaned and the filter retried once.
    """
    try:
        if close_brackets:
            word = close_open_brackets(word)

        if consider_starts_with:
            filtered_df = df[df[column_name].str.startswith(word)]
            if filtered_df.shape[0] == 0:
                filtered_df = df[df[column_name].str.contains(word)]
        else:
            if regex:
                filtered_df = df[
                    df[column_name].str.contains(rf"\b({word})\b", case=False)
                ]
            else:
                filtered_df = df[df[column_name].str.contains(word)]

        # An empty result means the word filtered everything out; keep the
        # candidates rather than returning an empty frame.
        if filtered_df.shape[0] == 0:
            filtered_df = df

        return filtered_df
    except Exception:
        # BUG FIX: the original retry omitted column_name, so
        # consider_starts_with was passed as the column and every flag
        # shifted one position. Retry with the arguments aligned.
        # NOTE(review): if clean(word) still raises, this recurses again —
        # consider a depth guard.
        return df_filtering_by_word(df, clean(word), column_name, consider_starts_with, regex)
|
180 |
+
|
181 |
+
|
182 |
+
def remove_files(file_name):
    """Delete a file from app/constants; a missing file is silently ignored."""
    target = f"app/constants/{file_name}"
    if os.path.exists(target):
        os.remove(target)
|
185 |
+
|
186 |
+
def get_top_mrf_product(mrf_product_attributes_list, dp_product_attributes, sequence_scores, default_attr_key_list):
    """Choose the best-matching MRF candidate for a DP product.

    Each candidate starts from its sequence score and gains 5 points for
    every key in default_attr_key_list that is present and non-NaN on both
    sides and equal case-insensitively. Returns (best_index, best_score).
    """
    scores = []
    for candidate_idx, candidate_attrs in enumerate(mrf_product_attributes_list):
        total = sequence_scores[candidate_idx]
        for attr_key in default_attr_key_list:
            if attr_key in dp_product_attributes and attr_key in candidate_attrs:
                dp_val = dp_product_attributes[attr_key]
                mrf_val = candidate_attrs[attr_key]
                # NaN values never count as a match.
                if pd.notna(dp_val) and pd.notna(mrf_val):
                    if str(dp_val).lower() == str(mrf_val).lower():
                        total += 5
        scores.append(total)

    best_score = max(scores)
    return scores.index(best_score), best_score
|
199 |
+
|
200 |
+
# Helper files required for FSA V2
|
201 |
+
# Preprocessing Function
|
202 |
+
'''
|
203 |
+
This Function is using for preprocessing the input product names
|
204 |
+
'''
|
205 |
+
def preprocess(text):
    """Normalise a product name for model input.

    '&' becomes 'and', punctuation becomes whitespace, runs of spaces are
    collapsed, and the result is stripped and lowercased.
    """
    replaced = re.sub(r'&', 'and', text)
    no_punct = re.sub(r'[^\w\s]', ' ', replaced)
    collapsed = re.sub(' +', ' ', no_punct)
    return collapsed.strip().lower()
|
210 |
+
|
211 |
+
# Function to preprocess labels from the previous prediction
|
212 |
+
def label_processing(label):
    """Turn a fasttext label ('__label__some_cat') into plain lowercase text."""
    without_prefix = re.sub('__label__', '', label)
    spaced = re.sub('_', ' ', without_prefix)
    collapsed = re.sub(' +', ' ', spaced)
    return collapsed.strip().lower()
|
217 |
+
|
218 |
+
def get_return_labels(label, accuracy, threshold):
    """Apply a confidence threshold to a prediction.

    Returns (label, accuracy, status) when accuracy >= threshold; otherwise
    (None, None, status) with an 'Unclassified' remark.
    """
    if accuracy >= threshold:
        return label, accuracy, f"Classified - Above threshold {threshold}"
    # BUG FIX: corrected the 'Unclassfied' typo in the returned status text.
    return None, None, f"Unclassified - Below threshold {threshold}"
|
228 |
+
|
229 |
+
#Function to get the product label and accuracy
|
230 |
+
def get_label_and_accuracy(model, product_name):
    """Predict with *model* on *product_name*.

    Returns (cleaned_label, confidence rounded to 3 decimals) from the
    model's top prediction.
    """
    raw_labels, raw_probs = model.predict(product_name)
    label = label_processing(raw_labels[0])
    accuracy = round(raw_probs[0], 3)
    return label, accuracy
|
237 |
+
|
238 |
+
# Function for remove new line in product name
|
239 |
+
'''
|
240 |
+
Some products may contain new line characters in middle of product names.
|
241 |
+
This may occur because of preprocessing. It can lead to result \n in middle of the
|
242 |
+
product names.
|
243 |
+
'''
|
244 |
+
def remove_new_lines(text):
    """Flatten embedded newlines to spaces; return the text stripped and lowercased."""
    flattened = re.sub('\n', ' ', text)
    return flattened.strip().lower()
|