|
|
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
import mlflow
import pickle
import os
import numpy as np
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
import string
import re
import nltk

import warnings

# Suppress library warnings (e.g. version-mismatch notices from unpickled models)
# to keep server logs readable.
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore")

from dotenv import load_dotenv

# Load environment variables (e.g. the DagsHub token) from a local .env file.
load_dotenv()

# Make sure the NLTK data used by the preprocessing functions is available;
# downloads are skipped silently if the corpora are already cached.
try:
    nltk.download('stopwords', quiet=True)
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
except Exception:
    # Tolerate offline environments where the corpora may already be present.
    pass

def lemmatization(text):
    """Lemmatize the text."""
    lemmatizer = WordNetLemmatizer()
    text = text.split()
    text = [lemmatizer.lemmatize(word) for word in text]
    return " ".join(text)

def remove_stop_words(text):
    """Remove stop words from the text."""
    stop_words = set(stopwords.words("english"))
    text = [word for word in str(text).split() if word not in stop_words]
    return " ".join(text)

def removing_numbers(text):
    """Remove numbers from the text."""
    text = ''.join([char for char in text if not char.isdigit()])
    return text

def lower_case(text):
    """Convert text to lower case."""
    text = text.split()
    text = [word.lower() for word in text]
    return " ".join(text)

def removing_punctuations(text):
    """Replace punctuation with spaces and collapse extra whitespace."""
    text = re.sub('[%s]' % re.escape(string.punctuation), ' ', text)
    text = text.replace('؛', "")  # also drop the Arabic semicolon, which string.punctuation misses
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def removing_urls(text):
    """Remove URLs from the text."""
    url_pattern = re.compile(r'https?://\S+|www\.\S+')
    return url_pattern.sub(r'', text)

def remove_small_sentences(df):
    """Replace sentences with fewer than 3 words with NaN."""
    for i in range(len(df)):
        if len(df.text.iloc[i].split()) < 3:
            # Assign via .loc to avoid pandas' chained-assignment pitfall.
            df.loc[df.index[i], 'text'] = np.nan

def normalize_text(text):
    """Apply the full preprocessing pipeline to a single string."""
    text = lower_case(text)
    text = remove_stop_words(text)
    text = removing_numbers(text)
    text = removing_punctuations(text)
    text = removing_urls(text)
    text = lemmatization(text)
    return text
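
# Illustrative example (exact output depends on the installed NLTK data):
#   normalize_text("Loving these 3 new features!!!")  ->  "loving new feature"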

# Configure MLflow to talk to the DagsHub-hosted tracking server, authenticating
# with a personal access token supplied via the CAPSTONE_TEST environment variable.
dagshub_token = os.getenv("CAPSTONE_TEST")
if not dagshub_token:
    raise EnvironmentError("CAPSTONE_TEST environment variable is not set")

# The token serves as both the MLflow username and password.
os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_token
os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token

dagshub_url = "https://dagshub.com"
repo_owner = "CodeBy-HP"
repo_name = "Sentiment-Classification-Mlflow-DVC"

mlflow.set_tracking_uri(f'{dagshub_url}/{repo_owner}/{repo_name}.mlflow')
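
# The resulting tracking URI is:
#   https://dagshub.com/CodeBy-HP/Sentiment-Classification-Mlflow-DVC.mlflow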

app = FastAPI(title="Sentiment Analysis API", version="1.0.0")

# Resolve the templates directory relative to this file so rendering works
# regardless of the working directory the server is launched from.
current_file_dir = os.path.dirname(os.path.abspath(__file__))
templates_dir = os.path.join(current_file_dir, "templates")
templates = Jinja2Templates(directory=templates_dir)

model_name = "my_model"

# Locate the pickled vectorizer next to this file, falling back to a few likely
# locations (current working directory, project root, Docker image path).
current_dir = os.path.dirname(os.path.abspath(__file__))
vectorizer_path = os.path.join(current_dir, 'models', 'vectorizer.pkl')
if not os.path.exists(vectorizer_path):
    alt_paths = [
        os.path.join(os.getcwd(), 'models', 'vectorizer.pkl'),
        os.path.join(current_dir, '..', 'models', 'vectorizer.pkl'),
        '/app/models/vectorizer.pkl'
    ]
    for path in alt_paths:
        if os.path.exists(path):
            vectorizer_path = path
            break

def get_latest_model_version(model_name):
    """Return the newest registered version, preferring the Production stage."""
    client = mlflow.MlflowClient()
    latest_version = client.get_latest_versions(model_name, stages=["Production"])
    if not latest_version:
        latest_version = client.get_latest_versions(model_name, stages=["None"])
    return latest_version[0].version if latest_version else None
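
# Note: stage-based lookup (get_latest_versions with `stages=`) is deprecated in
# recent MLflow releases in favor of model version aliases. A sketch of an
# alias-based alternative, assuming a "production" alias has been assigned:
#
#   def get_version_by_alias(name, alias="production"):
#       client = mlflow.MlflowClient()
#       return client.get_model_version_by_alias(name, alias).version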

# Fetch the registered model from MLflow and unpickle the local vectorizer.
# Both are loaded once at import time, so the app fails fast at startup if the
# tracking server is unreachable or the artifacts are missing.
model_version = get_latest_model_version(model_name)
model_uri = f'models:/{model_name}/{model_version}'
print(f"Fetching model from: {model_uri}")
model = mlflow.sklearn.load_model(model_uri)
with open(vectorizer_path, 'rb') as f:
    vectorizer = pickle.load(f)
@app.get("/", response_class=HTMLResponse) |
|
|
async def home(request: Request): |
|
|
"""Render the home page.""" |
|
|
return templates.TemplateResponse( |
|
|
request=request, |
|
|
name="index.html", |
|
|
context={"result": None} |
|
|
) |
|
|
|
|
|
@app.post("/predict", response_class=HTMLResponse) |
|
|
async def predict(request: Request, text: str = Form(...)): |
|
|
"""Handle sentiment prediction.""" |
|
|
|
|
|
cleaned_text = normalize_text(text) |
|
|
|
|
|
|
|
|
features = vectorizer.transform([cleaned_text]) |
|
|
|
|
|
features_array = features.toarray() |
|
|
|
|
|
|
|
|
result = model.predict(features_array) |
|
|
prediction = int(result[0]) |
|
|
|
|
|
|
|
|
|
|
|
probabilities = model.predict_proba(features_array)[0] |
|
|
confidence = float(probabilities[prediction]) * 100 |
|
|
|
|
|
return templates.TemplateResponse( |
|
|
request=request, |
|
|
name="index.html", |
|
|
context={"result": prediction, "confidence": confidence} |
|
|
) |
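
# Hypothetical manual check once the server is running locally:
#   curl -X POST http://localhost:8000/predict -d "text=I love this product"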

@app.get("/health")
async def health_check():
    """Health check endpoint for monitoring."""
    return {"status": "healthy", "model_version": model_version}

if __name__ == "__main__":
    # Convenience entry point for local development.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)