MissingBreath's picture
Update api.py
65d6b85 verified
import io
import os

import numpy as np
import tensorflow as tf
import torch
from fastapi import FastAPI, File, UploadFile
from PIL import Image
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# # os.environ['HF_TOKEN']=''
# from huggingface_hub import login
# hf_token = os.getenv("HF_TOKEN")
# login(token=hf_token)
# Read token from environment
# hf_token = os.getenv("HF_TOKEN")
# print("HF_TOKEN:", hf_token)
# Load tokenizer directly with the token (no login)
# tokenizer = AutoTokenizer.from_pretrained(
# "chillies/distilbert-course-review-classification",
# token=hf_token # Pass it directly
# )
# tokenizer = AutoTokenizer.from_pretrained("chillies/distilbert-course-review-classification")
# from transformers import DistilBertTokenizer
# tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# model = AutoModelForSequenceClassification.from_pretrained("chillies/distilbert-course-review-classification")
# from transformers import DistilBertTokenizerFast
# tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
# from transformers import pipeline
# model = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
from transformers import AutoModelForSequenceClassification, AutoTokenizer
MODEL_DIR = "./my_model"
TOKENIZER_DIR = "./my_tokenizer"

# Load the classifier and its tokenizer once, at import time.
# Both names are bound to None up front so that a failed load still leaves
# them defined. Code can then test `model is None` / `tokenizer is None`
# instead of hitting a NameError on the first request.
model = None
tokenizer = None
try:
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    # The error is reported rather than re-raised, so the module still imports
    # and the app starts. Classification cannot work until the load succeeds.
    print(f"Error loading model or tokenizer: {e}")
# Output labels indexed by the model's predicted class id.
# NOTE(review): the order must match the label ids used at fine-tuning time;
# confirm against model.config.id2label.
CLASS_LABELS = (
    'Improvement Suggestions', 'Questions', 'Confusion', 'Support Request',
    'Discussion', 'Course Comparison', 'Related Course Suggestions',
    'Negative', 'Positive',
)


def inference(review):
    """Classify a single review and return its label.

    Args:
        review: The review text to classify.

    Returns:
        The entry of ``CLASS_LABELS`` at the index of the highest logit.

    Raises:
        IndexError: if the model predicts a class id outside ``CLASS_LABELS``.
    """
    inputs = tokenizer(review, return_tensors="pt", padding=True, truncation=True)
    # Pure inference: disable autograd so no gradient graph is built
    # (saves memory and time on every request).
    with torch.no_grad():
        outputs = model(**inputs)
    # `.item()` requires a single-element result, i.e. one input sequence;
    # argmax over the last axis picks that sequence's highest-scoring class.
    predicted_class = outputs.logits.argmax(dim=-1).item()
    return CLASS_LABELS[predicted_class]
from pydantic import BaseModel
from typing import List
class ReviewRequest(BaseModel):
    """JSON request body accepted by ``POST /classify``."""

    # Review texts to label; each one receives its own prediction.
    reviews: List[str]
app = FastAPI()


@app.post("/classify")
def classify(request: ReviewRequest):
    """Classify every review in the request body.

    Declared with plain ``def`` rather than ``async def``, so FastAPI runs it
    in its worker threadpool. Model inference is blocking, CPU-bound work;
    inside a coroutine it would stall the event loop for every other request.

    Args:
        request: Parsed body containing the list of review texts.

    Returns:
        ``{"predictions": [...]}`` with one label per input review, in the
        same order as ``request.reviews`` (empty list for empty input).
    """
    return {"predictions": [inference(review) for review in request.reviews]}