Noveramaaz committed on
Commit ef66f82
1 Parent(s): 0a9d367

Update main.py

Files changed (1)
  main.py +75 -44
main.py CHANGED
@@ -1,76 +1,107 @@
 import re
 import string
 import nltk
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
-from typing import Optional
-from transformers import pipeline
-from pyngrok import ngrok
-import nest_asyncio
-from fastapi.responses import RedirectResponse
-
-# Download NLTK resources
 nltk.download('punkt')
 nltk.download('wordnet')
 
-# Initialize FastAPI app
-app = FastAPI()
-
-# Text preprocessing functions
 def remove_urls(text):
     return re.sub(r'http[s]?://\S+', '', text)
 
 def remove_punctuation(text):
     regular_punct = string.punctuation
-    return re.sub(r'['+regular_punct+']', '', text)
 
 def lower_case(text):
     return text.lower()
 
 def lemmatize(text):
-    wordnet_lemmatizer = nltk.WordNetLemmatizer()
     tokens = nltk.word_tokenize(text)
-    return ' '.join([wordnet_lemmatizer.lemmatize(w) for w in tokens])
 
-# Model loading
-lyx_pipe = pipeline("text-classification", model="lxyuan/distilbert-base-multilingual-cased-sentiments-student")
 
-# Input data model
 class TextInput(BaseModel):
     text: str
 
-# Welcome endpoint
 @app.get('/')
 async def welcome():
     # Redirect to the Swagger UI page
     return RedirectResponse(url="/docs")
 
-# Sentiment analysis endpoint
-@app.post('/analyze/')
-async def Predict_Sentiment(text_input: TextInput):
-    text = text_input.text
 
-    # Text preprocessing
-    text = remove_urls(text)
-    text = remove_punctuation(text)
-    text = lower_case(text)
-    text = lemmatize(text)
 
-    # Perform sentiment analysis
     try:
-        return lyx_pipe(text)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
-
-# Run the FastAPI app using Uvicorn
-if __name__ == "__main__":
-    # Create ngrok tunnel
-    ngrok_tunnel = ngrok.connect(7860)
-    print('Public URL:', ngrok_tunnel.public_url)
-
-    # Allow nested asyncio calls
-    nest_asyncio.apply()
-
-    # Run the FastAPI app with Uvicorn
-    import uvicorn
-    uvicorn.run(app, port=7860)
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel, ValidationError
+from fastapi.encoders import jsonable_encoder
+from fastapi.responses import RedirectResponse
+
+# TEXT PREPROCESSING
+# --------------------------------------------------------------------
 import re
 import string
 import nltk
 nltk.download('punkt')
 nltk.download('wordnet')
+nltk.download('omw-1.4')
+from nltk.stem import WordNetLemmatizer
 
+# Function to remove URLs from text
 def remove_urls(text):
     return re.sub(r'http[s]?://\S+', '', text)
 
+# Function to remove punctuation from text
 def remove_punctuation(text):
     regular_punct = string.punctuation
+    return re.sub(r'['+regular_punct+']', '', str(text))
 
+# Function to convert the text into lower case
 def lower_case(text):
     return text.lower()
 
+# Function to lemmatize text
 def lemmatize(text):
+    wordnet_lemmatizer = WordNetLemmatizer()
+
     tokens = nltk.word_tokenize(text)
+    # Join the lemmatized tokens back into a single string
+    return ' '.join(wordnet_lemmatizer.lemmatize(w) for w in tokens)
+
+def preprocess_text(text):
+    # Preprocess the input text
+    text = remove_urls(text)
+    text = remove_punctuation(text)
+    text = lower_case(text)
+    text = lemmatize(text)
+    return text
 
+# Load the model in a FastAPI lifespan event so that the model is loaded once at startup for efficiency
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Load the model from the HuggingFace transformers library
+    from transformers import pipeline
+    global sentiment_task
+    sentiment_task = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest")
+    yield
+    # Clean up the model and release the resources
+    del sentiment_task
 
+description = """
+## Text Classification API
+This app shows the sentiment of the text (positive, negative, or neutral).
+Check out the docs for the `/analyze/` endpoint below to try it out!
+"""
+
+# Initialize the FastAPI app
+app = FastAPI(lifespan=lifespan, description=description)
+
+# Define the input data model
 class TextInput(BaseModel):
     text: str
 
+# Define the welcome endpoint
 @app.get('/')
 async def welcome():
     # Redirect to the Swagger UI page
     return RedirectResponse(url="/docs")
+
+# Maximum allowed input text length
+MAX_TEXT_LENGTH = 1000
 
+# Define the sentiment analysis endpoint
+@app.post('/analyze/')
+async def classify_text(text_input: TextInput):
+    try:
+        # Convert input data to a JSON serializable dictionary
+        text_input_dict = jsonable_encoder(text_input)
+        # Validate input data using the Pydantic model
+        text_data = TextInput(**text_input_dict)  # Convert to Pydantic model
 
+        # Validate input text length
+        if len(text_input.text) > MAX_TEXT_LENGTH:
+            raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
+        elif len(text_input.text) == 0:
+            raise HTTPException(status_code=400, detail="Text cannot be empty")
+    except ValidationError as e:
+        # Handle validation errors
+        raise HTTPException(status_code=422, detail=str(e))
 
     try:
+        # Perform text classification
+        return sentiment_task(preprocess_text(text_input.text))
+    except ValueError as ve:
+        # Handle value errors
+        raise HTTPException(status_code=400, detail=str(ve))
     except Exception as e:
+        # Handle other server errors
         raise HTTPException(status_code=500, detail=str(e))
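
For a quick local check of the preprocessing chain, a hypothetical sketch (not part of the commit), assuming the new main.py is importable and the NLTK downloads succeed; exact lemmas depend on WordNet:

# Hypothetical sanity check of the preprocessing chain in main.py
from main import preprocess_text  # importing main also runs the nltk.download calls

sample = "Loving the new features!! https://example.com"
print(preprocess_text(sample))
# URL stripped, punctuation removed, lower-cased, tokens lemmatized:
# roughly "loving the new feature"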
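
The new file drops the old ngrok/uvicorn runner, so the app would be served externally, for example with `uvicorn main:app --port 7860` (the port carried over from the old setup). A minimal client sketch using the requests library, assuming a local server on that port:

# Hypothetical client for the POST /analyze/ endpoint
import requests

resp = requests.post(
    "http://localhost:7860/analyze/",
    json={"text": "I really enjoyed this movie, would watch again!"},
)
resp.raise_for_status()
# A transformers sentiment pipeline typically returns a list of label/score dicts, e.g.:
# [{'label': 'positive', 'score': 0.98}]
print(resp.json())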