from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# TEXT PREPROCESSING
# --------------------------------------------------------------------
import re
import string
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')  # needed by word_tokenize on newer NLTK releases
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.stem import WordNetLemmatizer

# Function to remove URLs from text
def remove_urls(text):
    return re.sub(r'http[s]?://\S+', '', text)

# Function to remove punctuation from text
def remove_punctuation(text):
    # re.escape keeps regex metacharacters in string.punctuation
    # (e.g. ']', '\', '^') from being interpreted inside the character class
    return re.sub(r'[' + re.escape(string.punctuation) + r']', '', str(text))

# Function to convert the text into lower case
def lower_case(text):
    return text.lower()

# Function to lemmatize text
def lemmatize(text):
    wordnet_lemmatizer = WordNetLemmatizer()
    # Lemmatize each token and rejoin the tokens into a single string
    tokens = nltk.word_tokenize(text)
    return ' '.join(wordnet_lemmatizer.lemmatize(w) for w in tokens)

def preprocess_text(text):
    # Preprocess the input text
    text = remove_urls(text)
    text = remove_punctuation(text)
    text = lower_case(text)
    text = lemmatize(text)
    return text
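
# Illustrative example of the full pipeline (assumed output; WordNet's
# lemmatizer defaults to noun POS, so "loving" passes through unchanged
# while the plural "features" is reduced to its singular form):
#   preprocess_text("Loving the new features!!! https://example.com")
#   -> "loving the new feature"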

# Load the model via FastAPI's lifespan event so it is loaded once at startup rather than on every request
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model from HuggingFace transformers library
    from transformers import pipeline
    global sentiment_task
    sentiment_task = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest")
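    # Illustrative (assumed) output shape: sentiment_task("i love this")
    # returns a list like [{'label': 'positive', 'score': 0.98}]; the exact
    # score varies with the transformers version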
    yield
    # Clean up the model and release the resources
    del sentiment_task

# Initialize the FastAPI app
app = FastAPI(lifespan=lifespan)

# Define the input data model
class TextInput(BaseModel):
    text: str
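
# Example request body this model accepts (illustrative):
#   {"text": "I love this product!"}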

# Define the welcome endpoint
@app.get('/')
async def welcome():
    return "Welcome to our Text Classification API"

# Maximum accepted input text length (in characters)
MAX_TEXT_LENGTH = 1000

# Define the sentiment analysis endpoint (the text is sent in the request
# body, so no path parameter is needed)
@app.post('/analyze')
async def classify_text(text_input: TextInput):
    # FastAPI has already parsed and validated the request body against
    # TextInput, so only the length constraints need to be checked here
    if len(text_input.text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
    if len(text_input.text) == 0:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        # Perform text classification
        return sentiment_task(preprocess_text(text_input.text))
    except ValueError as ve:
        # Handle value error
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        # Handle other server errors
        raise HTTPException(status_code=500, detail=str(e))
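
# Optional local entry point (illustrative sketch; assumes uvicorn is
# installed alongside fastapi). With the server running, the endpoint can
# be exercised with, e.g.:
#   curl -X POST http://127.0.0.1:8000/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"text": "I love this!"}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)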
