aminaj committed
Commit 7f6f85e
1 Parent(s): 7d33df4

Upload 3 files

Files changed (3)
  1. Dockerfile +63 -0
  2. main.py +100 -0
  3. test_main.py +70 -0
Dockerfile ADDED
@@ -0,0 +1,63 @@
+ # syntax=docker/dockerfile:1
+
+ # Comments are provided throughout this file to help you get started.
+ # If you need more help, visit the Dockerfile reference guide at
+ # https://docs.docker.com/go/dockerfile-reference/
+
+ # Want to help us make this template better? Share your feedback here: https://forms.gle/ybq9Krt8jtBL3iCk7
+
+ ARG PYTHON_VERSION=3.11.9
+ FROM python:${PYTHON_VERSION}-slim as base
+
+ # Prevents Python from writing pyc files.
+ ENV PYTHONDONTWRITEBYTECODE=1
+
+ # Keeps Python from buffering stdout and stderr to avoid situations where
+ # the application crashes without emitting any logs due to buffering.
+ ENV PYTHONUNBUFFERED=1
+
+ WORKDIR /app
+
+ # Create a non-privileged user that the app will run under.
+ # See https://docs.docker.com/go/dockerfile-user-best-practices/
+ ARG UID=10001
+ RUN adduser \
+     --disabled-password \
+     --gecos "" \
+     --home "/nonexistent" \
+     --shell "/sbin/nologin" \
+     --no-create-home \
+     --uid "${UID}" \
+     appuser
+
+ # Download dependencies as a separate step to take advantage of Docker's caching.
+ # Leverage a cache mount to /root/.cache/pip to speed up subsequent builds.
+ # Leverage a bind mount to requirements.txt to avoid having to copy it
+ # into this layer.
+ RUN --mount=type=cache,target=/root/.cache/pip \
+     --mount=type=bind,source=requirements.txt,target=requirements.txt \
+     python -m pip install -r requirements.txt
+
+ # Switch to the non-privileged user to run the application.
+ USER appuser
+
+ # Set the TRANSFORMERS_CACHE environment variable.
+ ENV TRANSFORMERS_CACHE=/tmp/.cache/huggingface
+
+ # Create the cache folder with appropriate permissions.
+ RUN mkdir -p $TRANSFORMERS_CACHE && chmod -R 777 $TRANSFORMERS_CACHE
+
+ # Set the NLTK data directory.
+ ENV NLTK_DATA=/tmp/nltk_data
+
+ # Create the NLTK data directory with appropriate permissions.
+ RUN mkdir -p $NLTK_DATA && chmod -R 777 $NLTK_DATA
+
+ # Copy the source code into the container.
+ COPY . .
+
+ # Expose the port that the application listens on.
+ EXPOSE 8000
+
+ # Run the application.
+ CMD uvicorn 'main:app' --host=0.0.0.0 --port=8000
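
Note: the pip install step above bind-mounts a requirements.txt that is not included in this commit. A minimal sketch of what that file would likely need to list, inferred from the imports in main.py, the uvicorn CMD, and the test suite (package names and the choice of torch as the transformers backend are assumptions; no versions are pinned here):

    fastapi
    uvicorn
    pydantic
    nltk
    transformers
    torch
    pytest
    httpx

pytest and httpx (used by FastAPI's TestClient in recent Starlette versions) are only needed to run test_main.py, not to serve the API.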
main.py ADDED
@@ -0,0 +1,100 @@
+ from contextlib import asynccontextmanager
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel, ValidationError
+ from fastapi.encoders import jsonable_encoder
+
+ # TEXT PREPROCESSING
+ # --------------------------------------------------------------------
+ import re
+ import string
+ import nltk
+ nltk.download('punkt')
+ nltk.download('wordnet')
+ nltk.download('omw-1.4')
+ from nltk.stem import WordNetLemmatizer
+
+ # Function to remove URLs from text
+ def remove_urls(text):
+     return re.sub(r'http[s]?://\S+', '', text)
+
+ # Function to remove punctuation from text
+ def remove_punctuation(text):
+     regular_punct = string.punctuation
+     return str(re.sub(r'[' + regular_punct + ']', '', str(text)))
+
+ # Function to convert the text into lower case
+ def lower_case(text):
+     return text.lower()
+
+ # Function to lemmatize text
+ def lemmatize(text):
+     wordnet_lemmatizer = WordNetLemmatizer()
+
+     tokens = nltk.word_tokenize(text)
+     lemma_txt = ''
+     for w in tokens:
+         lemma_txt = lemma_txt + wordnet_lemmatizer.lemmatize(w) + ' '
+
+     return lemma_txt
+
+ def preprocess_text(text):
+     # Preprocess the input text
+     text = remove_urls(text)
+     text = remove_punctuation(text)
+     text = lower_case(text)
+     text = lemmatize(text)
+     return text
+
+ # Load the model in a FastAPI lifespan event so that it is loaded once at startup for efficiency
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Load the model from the HuggingFace transformers library
+     from transformers import pipeline
+     global sentiment_task
+     sentiment_task = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest")
+     yield
+     # Clean up the model and release the resources
+     del sentiment_task
+
+ # Initialize the FastAPI app
+ app = FastAPI(lifespan=lifespan)
+
+ # Define the input data model
+ class TextInput(BaseModel):
+     text: str
+
+ # Define the welcome endpoint
+ @app.get('/')
+ async def welcome():
+     return "Welcome to our Text Classification API"
+
+ # Maximum allowed input text length
+ MAX_TEXT_LENGTH = 1000
+
+ # Define the sentiment analysis endpoint
+ @app.post('/analyze/{text}')
+ async def classify_text(text_input: TextInput):
+     try:
+         # Convert input data to a JSON-serializable dictionary
+         text_input_dict = jsonable_encoder(text_input)
+         # Validate input data using the Pydantic model
+         text_data = TextInput(**text_input_dict)  # Convert to Pydantic model
+
+         # Validate input text length
+         if len(text_input.text) > MAX_TEXT_LENGTH:
+             raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
+         elif len(text_input.text) == 0:
+             raise HTTPException(status_code=400, detail="Text cannot be empty")
+     except ValidationError as e:
+         # Handle validation error
+         raise HTTPException(status_code=422, detail=str(e))
+
+     try:
+         # Perform text classification
+         return sentiment_task(preprocess_text(text_input.text))
+     except ValueError as ve:
+         # Handle value error
+         raise HTTPException(status_code=400, detail=str(ve))
+     except Exception as e:
+         # Handle other server errors
+         raise HTTPException(status_code=500, detail=str(e))
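
For reference, a minimal sketch of how a client could call this endpoint once the container built from the Dockerfile above is running on port 8000. It assumes the `requests` package on the client side; the sample text is arbitrary and the response shape shown in the comment is illustrative:

    import requests

    # The route is declared as '/analyze/{text}', so the literal path below
    # matches it; the actual input is sent in the JSON body.
    payload = {"text": "I love this product! It's amazing!"}
    response = requests.post("http://localhost:8000/analyze/{text}", json=payload)

    # The sentiment pipeline returns a list of label/score dicts,
    # e.g. [{"label": "positive", "score": 0.98}] (score value illustrative).
    print(response.json())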
test_main.py ADDED
@@ -0,0 +1,70 @@
+ from fastapi.testclient import TestClient
+ from main import app
+ from main import TextInput
+ from fastapi.encoders import jsonable_encoder
+
+ client = TestClient(app)
+
+ # Test the welcome endpoint
+ def test_welcome():
+     # Send a GET request to the root endpoint
+     response = client.get("/")
+     assert response.status_code == 200
+     assert response.json() == "Welcome to our Text Classification API"
+
+ # Test the sentiment analysis endpoint for positive sentiment
+ def test_positive_sentiment():
+     with client:
+         # Define the request payload
+         # Initialize payload as a TextInput object
+         payload = TextInput(text="I love this product! It's amazing!")
+
+         # Convert the TextInput object to a JSON-serializable dictionary
+         payload_dict = jsonable_encoder(payload)
+
+         # Send a POST request to the sentiment analysis endpoint
+         response = client.post("/analyze/{text}", json=payload_dict)
+
+         # Assert that the response status code is 200 OK
+         assert response.status_code == 200
+
+         # Assert that the sentiment returned is positive
+         assert response.json()[0]['label'] == "positive"
+
+ # Test the sentiment analysis endpoint for negative sentiment
+ def test_negative_sentiment():
+     with client:
+         # Define the request payload
+         # Initialize payload as a TextInput object
+         payload = TextInput(text="I'm really disappointed with this service. It's terrible.")
+
+         # Convert the TextInput object to a JSON-serializable dictionary
+         payload_dict = jsonable_encoder(payload)
+
+         # Send a POST request to the sentiment analysis endpoint
+         response = client.post("/analyze/{text}", json=payload_dict)
+
+         # Assert that the response status code is 200 OK
+         assert response.status_code == 200
+
+         # Assert that the sentiment returned is negative
+         assert response.json()[0]['label'] == "negative"
+
+ # Test the sentiment analysis endpoint for neutral sentiment
+ def test_neutral_sentiment():
+     with client:
+         # Define the request payload
+         # Initialize payload as a TextInput object
+         payload = TextInput(text="This is a neutral statement.")
+
+         # Convert the TextInput object to a JSON-serializable dictionary
+         payload_dict = jsonable_encoder(payload)
+
+         # Send a POST request to the sentiment analysis endpoint
+         response = client.post("/analyze/{text}", json=payload_dict)
+
+         # Assert that the response status code is 200 OK
+         assert response.status_code == 200
+
+         # Assert that the sentiment returned is neutral
+         assert response.json()[0]['label'] == "neutral"
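
The three tests above only exercise the happy path. A possible follow-up, not part of this commit, is a test for the empty-input branch of classify_text in main.py; a sketch following the same pattern:

    # Hypothetical additional test: empty input should trigger the
    # len(text) == 0 check in classify_text and return a 400 error.
    def test_empty_text():
        with client:
            payload_dict = jsonable_encoder(TextInput(text=""))
            response = client.post("/analyze/{text}", json=payload_dict)
            assert response.status_code == 400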