Noveramaaz committed on
Commit 0a9d367
1 Parent(s): 0683671

Update main.py

Files changed (1)
  main.py +43 -121
main.py CHANGED
@@ -1,154 +1,76 @@
- from contextlib import asynccontextmanager
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel, ValidationError
- from fastapi.encoders import jsonable_encoder
-
- # TEXT PREPROCESSING
- # --------------------------------------------------------------------
  import re
  import string
  import nltk
  nltk.download('punkt')
  nltk.download('wordnet')
- nltk.download('omw-1.4')
- from nltk.stem import WordNetLemmatizer

- # Function to remove URLs from text
  def remove_urls(text):
      return re.sub(r'http[s]?://\S+', '', text)

- # Function to remove punctuations from text
  def remove_punctuation(text):
      regular_punct = string.punctuation
-     return str(re.sub(r'['+regular_punct+']', '', str(text)))

- # Function to convert the text into lower case
  def lower_case(text):
      return text.lower()

- # Function to lemmatize text
  def lemmatize(text):
-     wordnet_lemmatizer = WordNetLemmatizer()
-
      tokens = nltk.word_tokenize(text)
-     lemma_txt = ''
-     for w in tokens:
-         lemma_txt = lemma_txt + wordnet_lemmatizer.lemmatize(w) + ' '
-
-     return lemma_txt

- def preprocess_text(text):
-     # Preprocess the input text
-     text = remove_urls(text)
-     text = remove_punctuation(text)
-     text = lower_case(text)
-     text = lemmatize(text)
-     return text
-
- # Load the model using FastAPI lifespan event so that the model is loaded at the beginning for efficiency
- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     # Load the model from HuggingFace transformers library
-     from transformers import pipeline
-     global sentiment_task
-     sentiment_task = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest")
-     yield
-     # Clean up the model and release the resources
-     del sentiment_task
-
- # Initialize the FastAPI app
- app = FastAPI(lifespan=lifespan)
-
- # Define the input data model
  class TextInput(BaseModel):
      text: str

- # Define the welcome endpoint
  @app.get('/')
  async def welcome():
-     return "Welcome to our Text Classification API"

- # Validate input text length
- MAX_TEXT_LENGTH = 1000

- # Define the sentiment analysis endpoint
- @app.post('/analyze/{text}')
- async def classify_text(text_input:TextInput):
-     try:
-         # Convert input data to JSON serializable dictionary
-         text_input_dict = jsonable_encoder(text_input)
-         # Validate input data using Pydantic model
-         text_data = TextInput(**text_input_dict) # Convert to Pydantic model
-
-         # Validate input text length
-         if len(text_input.text) > MAX_TEXT_LENGTH:
-             raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
-         elif len(text_input.text) == 0:
-             raise HTTPException(status_code=400, detail="Text cannot be empty")
-     except ValidationError as e:
-         # Handle validation error
-         raise HTTPException(status_code=422, detail=str(e))

      try:
-         # Perform text classification
-         return sentiment_task(preprocess_text(text_input.text))
-     except ValueError as ve:
-         # Handle value error
-         raise HTTPException(status_code=400, detail=str(ve))
      except Exception as e:
-         # Handle other server errors
          raise HTTPException(status_code=500, detail=str(e))

- # Load the model using FastAPI lifespan event so that the model is loaded at the beginning for efficiency
- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     # Load the model from HuggingFace transformers library
-     from transformers import pipeline
-     global sentiment_task
-     sentiment_task = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest")
-     yield
-     # Clean up the model and release the resources
-     del sentiment_task
-
- # Initialize the FastAPI app
- app = FastAPI(lifespan=lifespan)
-
- # Define the input data model
- class TextInput(BaseModel):
-     text: str
-
- # Define the welcome endpoint
- @app.get('/')
- async def welcome():
-     return "Welcome to our Text Classification API"
-
- # Validate input text length
- MAX_TEXT_LENGTH = 1000
-
- # Define the sentiment analysis endpoint
- @app.post('/analyze/{text}')
- async def classify_text(text_input:TextInput):
-     try:
-         # Convert input data to JSON serializable dictionary
-         text_input_dict = jsonable_encoder(text_input)
-         # Validate input data using Pydantic model
-         text_data = TextInput(**text_input_dict) # Convert to Pydantic model
-
-         # Validate input text length
-         if len(text_input.text) > MAX_TEXT_LENGTH:
-             raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
-         elif len(text_input.text) == 0:
-             raise HTTPException(status_code=400, detail="Text cannot be empty")
-     except ValidationError as e:
-         # Handle validation error
-         raise HTTPException(status_code=422, detail=str(e))
-
-     try:
-         # Perform text classification
-         return sentiment_task(preprocess_text(text_input.text))
-     except ValueError as ve:
-         # Handle value error
-         raise HTTPException(status_code=400, detail=str(ve))
-     except Exception as e:
-         # Handle other server errors
-         raise HTTPException(status_code=500, detail=str(e))
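One note on remove_punctuation, which survives into the new file below: both versions splice string.punctuation into a regex character class unescaped. The class happens to compile (the embedded ] is already backslash-escaped and the ,-. run forms a valid range), but re.escape makes it robust by construction. A minimal hardened sketch, hypothetical and not part of this commit:

import re
import string

# Hypothetical variant: re.escape() neutralises regex metacharacters
# (']', '\', '-', ...) before they enter the character class.
def remove_punctuation_escaped(text):
    return re.sub('[' + re.escape(string.punctuation) + ']', '', text)

print(remove_punctuation_escaped("Hello, world!"))  # -> Hello world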
  import re
  import string
  import nltk
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from typing import Optional
+ from transformers import pipeline
+ from pyngrok import ngrok
+ import nest_asyncio
+ from fastapi.responses import RedirectResponse
+
+ # Download NLTK resources
  nltk.download('punkt')
  nltk.download('wordnet')

+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Text preprocessing functions
  def remove_urls(text):
      return re.sub(r'http[s]?://\S+', '', text)

  def remove_punctuation(text):
      regular_punct = string.punctuation
+     return re.sub(r'['+regular_punct+']', '', text)

  def lower_case(text):
      return text.lower()

  def lemmatize(text):
+     wordnet_lemmatizer = nltk.WordNetLemmatizer()
      tokens = nltk.word_tokenize(text)
+     return ' '.join([wordnet_lemmatizer.lemmatize(w) for w in tokens])

+ # Model loading
+ lyx_pipe = pipeline("text-classification", model="lxyuan/distilbert-base-multilingual-cased-sentiments-student")

+ # Input data model
  class TextInput(BaseModel):
      text: str

+ # Welcome endpoint
  @app.get('/')
  async def welcome():
+     # Redirect to the Swagger UI page
+     return RedirectResponse(url="/docs")

+ # Sentiment analysis endpoint
+ @app.post('/analyze/')
+ async def Predict_Sentiment(text_input: TextInput):
+     text = text_input.text
+
+     # Text preprocessing
+     text = remove_urls(text)
+     text = remove_punctuation(text)
+     text = lower_case(text)
+     text = lemmatize(text)
+
+     # Perform sentiment analysis
      try:
+         return lyx_pipe(text)
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e))

+ # Run the FastAPI app using Uvicorn
+ if __name__ == "__main__":
+     # Create ngrok tunnel
+     ngrok_tunnel = ngrok.connect(7860)
+     print('Public URL:', ngrok_tunnel.public_url)
+
+     # Allow nested asyncio calls
+     nest_asyncio.apply()
+
+     # Run the FastAPI app with Uvicorn
+     import uvicorn
+     uvicorn.run(app, port=7860)
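Once the script is running (the __main__ block serves the app with Uvicorn on port 7860 and opens an ngrok tunnel), the new /analyze/ endpoint takes a JSON body matching TextInput. A minimal client sketch, assuming a local instance on the default port:

import requests

# Hypothetical smoke test; assumes the app is running locally on port 7860.
response = requests.post(
    "http://127.0.0.1:7860/analyze/",
    json={"text": "I love this library, it works great!"},
)
print(response.json())  # e.g. [{'label': 'positive', 'score': ...}]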