Sebastian Kułaga commited on
Commit
a6c83a1
1 Parent(s): 822338a

Initial commit

Browse files
dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use the official Python 3.10.9 image
2
+ FROM python:3.10.14
3
+
4
+ # Copy the current directory contents into the container at .
5
+ COPY . .
6
+
7
+ # Set the working directory to /
8
+ WORKDIR /
9
+
10
+ # Install requirements.txt
11
+ RUN pip install --no-cache-dir --upgrade -r /requirements.txt
12
+
13
+ # Start the FastAPI app on port 7860, the default port expected by Spaces
14
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
fastapi/__init__.py ADDED
File without changes
routers/__init__.py ADDED
File without changes
routers/distilbert.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from src.distilbert import Distilbert
3
+
4
+
5
+ MODEL_PATH = "SebaK13/DistilBERT-finetuned-customer-queries-balanced"
6
+ predict = APIRouter(prefix="/predict", tags=["predict"])
7
+ distilbert_service = Distilbert(model_path=MODEL_PATH)
8
+
9
+
10
+
11
+ @predict.get("/predict")
12
+ def predict(query: str):
13
+ output = distilbert_service.predict_query_type(query)
14
+ return {"label": output}
src/__init__.py ADDED
File without changes
src/distilbert.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pandas as pd
3
+ from transformers import pipeline
4
+
5
+ class Distilbert:
6
+ """
7
+ Class for distilBERT
8
+ """
9
+ def __init__(self, model_path:str):
10
+ self.pipe = pipeline("text-classification", model=model_path)
11
+
12
+
13
+ def predict_query_type(self, custom_query):
14
+ """
15
+ Predict the type of a given custom query using a pre-trained classifier.
16
+
17
+ Parameters:
18
+ custom_query (str): The input query for which the type needs to be predicted.
19
+
20
+ Returns:
21
+ str: The predicted label for the input query.
22
+ """
23
+ # Get predictions from the classifier
24
+ preds = self.pipe(custom_query, top_k=None)
25
+
26
+ # Get the list of labels from the balanced dataset
27
+ labels_dict = self.label_dict()
28
+ labels = labels_dict.keys()
29
+
30
+ # Create a DataFrame from the predictions
31
+ preds_df = pd.DataFrame(preds)
32
+
33
+ # Process the labels to remove 'LABEL_' prefix and convert to integer
34
+ preds_df['label'] = preds_df['label'].str.replace('LABEL_', '')
35
+ preds_df['label'] = pd.to_numeric(preds_df['label'], errors='coerce').astype('Int64')
36
+ preds_df = preds_df.sort_values('label').reset_index(drop=True)
37
+
38
+ # Find the index of the maximum score
39
+ max_score_index = preds_df['score'].idxmax()
40
+
41
+ # Return the label corresponding to the maximum score
42
+ return labels[max_score_index]
43
+
44
+ def load_lables(self):
45
+ """
46
+ This function reads the 'labels.json' file, which contains a dictionary mapping integers to string labels.
47
+ The function then returns this dictionary.
48
+
49
+ Returns:
50
+ dict: label dictionary with key as label string value and dictionary value as int
51
+ """
52
+ label_dict = {}
53
+ with open('labels.json') as json_file:
54
+ label_dict = json.load(json_file)
55
+ return label_dict
56
+
src/labels.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Airport Services": 0,
3
+ "Baggage Policies": 1,
4
+ "Business Travel": 2,
5
+ "COVID-19 Policies": 3,
6
+ "Check-in Procedures": 4,
7
+ "Child and Infant Travel": 5,
8
+ "Complaints and Feedback": 6,
9
+ "Customer Account Issues": 7,
10
+ "Duty-Free Shopping": 8,
11
+ "Flight Bookings": 9,
12
+ "Flight Changes": 10,
13
+ "Flight Status": 11,
14
+ "Frequent Flyer Miles": 12,
15
+ "Group Bookings": 13,
16
+ "In-flight Services": 14,
17
+ "Lost and Found": 15,
18
+ "Loyalty Programs": 16,
19
+ "Mobile App Issues": 17,
20
+ "Partnerships and Alliances": 18,
21
+ "Payment Issues": 19,
22
+ "Pet Travel": 20,
23
+ "Promotions and Discounts": 21,
24
+ "Refunds and Compensation": 22,
25
+ "Seat Selection": 23,
26
+ "Special Assistance": 24,
27
+ "Travel Documentation": 25,
28
+ "Travel Insurance": 26,
29
+ "Travel Restrictions": 27,
30
+ "Travel Vouchers": 28,
31
+ "Weather-related Disruptions": 29
32
+ }