File size: 5,297 Bytes
7be7efa 33761d2 7be7efa 33761d2 7be7efa 33761d2 8d4a0f7 33761d2 7be7efa 33761d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ValidationError
from fastapi.encoders import jsonable_encoder
# TEXT PREPROCESSING
# --------------------------------------------------------------------
import re
import string
import nltk
# Fetch the NLTK data packages the preprocessing helpers in this file rely on.
# This runs at import time, before any request is served.
#   punkt / punkt_tab : tokenizer data used by nltk.word_tokenize
#   wordnet, omw-1.4  : data used by WordNetLemmatizer
# NLTK 3.9 and later look up 'punkt_tab' instead of the pickled 'punkt'
# models, so both are requested to keep word_tokenize working on old and new
# NLTK releases.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.stem import WordNetLemmatizer
# Function to remove URLs from text
def remove_urls(text):
    """Return ``text`` with every ``http://`` or ``https://`` URL deleted.

    A URL is taken to run from the scheme up to the next whitespace
    character; the surrounding whitespace itself is left untouched.
    """
    url_pattern = re.compile(r'http[s]?://\S+')
    return url_pattern.sub('', text)
# Function to remove punctuations from text
def remove_punctuation(text):
    """Return ``str(text)`` with every ASCII punctuation character removed.

    The characters removed are exactly those in ``string.punctuation``.
    A translation table is used instead of a regex character class: when
    ``string.punctuation`` is pasted unescaped between ``[`` and ``]``, its
    backslash escapes the following ``]``, so the backslash itself is never
    matched and survives in the output.

    Args:
        text: Value to clean; it is converted with ``str()`` first, so
            non-string input is accepted.

    Returns:
        The cleaned string.
    """
    # Mapping every punctuation character to None deletes it in one pass.
    punctuation_table = str.maketrans('', '', string.punctuation)
    return str(text).translate(punctuation_table)
# Function to convert the text into lower case
def lower_case(text):
    """Return the result of calling ``lower()`` on ``text``."""
    lowered = text.lower()
    return lowered
# Function to lemmatize text
def lemmatize(text):
    """Tokenize ``text`` and return its lemmatized tokens as one string.

    Tokens come from ``nltk.word_tokenize`` and each is passed through
    ``WordNetLemmatizer.lemmatize``. Every lemma is followed by a single
    space, so a non-empty result always ends with a trailing space.
    """
    lemmatizer = WordNetLemmatizer()
    words = nltk.word_tokenize(text)
    # Each piece carries its own trailing space, matching token-by-token
    # concatenation.
    pieces = [lemmatizer.lemmatize(word) + ' ' for word in words]
    return ''.join(pieces)
def preprocess_text(text):
    """Run the full cleaning pipeline on ``text`` and return the result.

    The steps are applied in this order: URL removal, punctuation removal,
    lower-casing, lemmatization.
    """
    # Preprocess the input text by feeding it through each stage in turn.
    stages = (remove_urls, remove_punctuation, lower_case, lemmatize)
    for stage in stages:
        text = stage(text)
    return text
# Load the model using FastAPI lifespan event so that the model is loaded at the beginning for efficiency
# NOTE(review): lifespan is defined a second time further down this file; the
# later definition rebinds the name.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the sentiment pipeline on startup and drop it on shutdown.

    Before the ``yield`` the HuggingFace pipeline is built and stored in the
    module-level global ``sentiment_task``, which the analysis endpoint
    calls. After the ``yield`` the global is deleted.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Local import: transformers is imported when the application starts
    # rather than when this module is imported.
    from transformers import pipeline
    global sentiment_task
    # The same checkpoint provides both the model weights and the tokenizer.
    model_id = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    sentiment_task = pipeline("sentiment-analysis", model=model_id, tokenizer=model_id)
    try:
        yield
    finally:
        # Clean up the model and release the resources. The finally clause
        # makes this run even when an exception is thrown into the generator
        # at the yield point.
        del sentiment_task
# Initialize the FastAPI app. The lifespan handler passed here builds the
# sentiment pipeline before its yield and deletes it afterwards.
# NOTE(review): `app` is assigned again further down this file; routes
# registered on this first instance are not carried over to the later one.
app = FastAPI(lifespan=lifespan)
# Define the input data model
# Pydantic schema for the body of the sentiment analysis request. It only
# checks that `text` is present and is a string; the emptiness and maximum
# length checks are done inside classify_text.
class TextInput(BaseModel):
    # Raw text whose sentiment is to be classified.
    text: str
# Define the welcome endpoint
@app.get('/')
async def welcome():
    # Root route: reply with a fixed greeting string.
    greeting = "Welcome to our Text Classification API"
    return greeting
# Upper bound on len() of the request's `text` field. classify_text rejects
# longer input with HTTP 400.
MAX_TEXT_LENGTH = 1000
# Define the sentiment analysis endpoint
#
# Takes the request body (a TextInput), checks its length, preprocesses the
# text and returns whatever the sentiment pipeline produces for it.
#
# Responses raised here:
#   400 - text is longer than MAX_TEXT_LENGTH, text is empty, or the
#         preprocessing/pipeline call raised ValueError
#   500 - any other exception from preprocessing or the pipeline
#
# Documentation is kept in comments rather than a docstring so the generated
# OpenAPI description of this route stays unchanged.
# NOTE(review): the `{text}` path segment is not bound to any parameter of
# this function; the text that gets analysed comes from the request body.
# Confirm whether the path parameter is intended.
# NOTE(review): this endpoint is defined again further down this file.
@app.post('/analyze/{text}')
async def classify_text(text_input:TextInput):
    # The body has already been validated against TextInput by the time this
    # function runs, so re-encoding it and building a second TextInput added
    # nothing; that round trip and its unreachable ValidationError handler
    # were removed.
    text = text_input.text

    # Validate input text length
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        # Perform text classification
        return sentiment_task(preprocess_text(text))
    except ValueError as ve:
        # Handle value error; chain the cause so it shows up in tracebacks
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # Handle other server errors
        raise HTTPException(status_code=500, detail=str(e)) from e
# Load the model using FastAPI lifespan event so that the model is loaded at the beginning for efficiency
# NOTE(review): lifespan was already defined earlier in this file; this
# definition replaces it.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the sentiment pipeline on startup and drop it on shutdown.

    Before the ``yield`` the HuggingFace pipeline is built and stored in the
    module-level global ``sentiment_task``, which the analysis endpoint
    calls. After the ``yield`` the global is deleted.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Local import: transformers is imported when the application starts
    # rather than when this module is imported.
    from transformers import pipeline
    global sentiment_task
    # The same checkpoint provides both the model weights and the tokenizer.
    model_id = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    sentiment_task = pipeline("sentiment-analysis", model=model_id, tokenizer=model_id)
    try:
        yield
    finally:
        # Clean up the model and release the resources. The finally clause
        # makes this run even when an exception is thrown into the generator
        # at the yield point.
        del sentiment_task
# Initialize the FastAPI app. The lifespan handler passed here builds the
# sentiment pipeline before its yield and deletes it afterwards.
# NOTE(review): `app` was already created earlier in this file; this
# assignment replaces it with a new instance, so only the routes registered
# below this line belong to the final application.
app = FastAPI(lifespan=lifespan)
# Define the input data model
# Pydantic schema for the body of the sentiment analysis request. It only
# checks that `text` is present and is a string; the emptiness and maximum
# length checks are done inside classify_text.
class TextInput(BaseModel):
    # Raw text whose sentiment is to be classified.
    text: str
# Define the welcome endpoint
@app.get('/')
async def welcome():
    # Root route: reply with a fixed greeting string.
    greeting = "Welcome to our Text Classification API"
    return greeting
# Upper bound on len() of the request's `text` field. classify_text rejects
# longer input with HTTP 400.
MAX_TEXT_LENGTH = 1000
# Define the sentiment analysis endpoint
#
# Takes the request body (a TextInput), checks its length, preprocesses the
# text and returns whatever the sentiment pipeline produces for it.
#
# Responses raised here:
#   400 - text is longer than MAX_TEXT_LENGTH, text is empty, or the
#         preprocessing/pipeline call raised ValueError
#   500 - any other exception from preprocessing or the pipeline
#
# Documentation is kept in comments rather than a docstring so the generated
# OpenAPI description of this route stays unchanged.
# NOTE(review): the `{text}` path segment is not bound to any parameter of
# this function; the text that gets analysed comes from the request body.
# Confirm whether the path parameter is intended.
# NOTE(review): this endpoint was already defined earlier in this file; this
# definition is the one registered on the final `app`.
@app.post('/analyze/{text}')
async def classify_text(text_input:TextInput):
    # The body has already been validated against TextInput by the time this
    # function runs, so re-encoding it and building a second TextInput added
    # nothing; that round trip and its unreachable ValidationError handler
    # were removed.
    text = text_input.text

    # Validate input text length
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        # Perform text classification
        return sentiment_task(preprocess_text(text))
    except ValueError as ve:
        # Handle value error; chain the cause so it shows up in tracebacks
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # Handle other server errors
        raise HTTPException(status_code=500, detail=str(e)) from e