# Source captured from a Hugging Face Space page (Space status: Sleeping).
from fastapi import FastAPI

app = FastAPI()


# Register the handler on the app; without a path-operation decorator the
# function is never exposed and the FastAPI app serves no routes.
@app.get("/")
def greet_json():
    """Return a static JSON greeting for the root URL."""
    return {"Hello": "World!"}
from transformers import BertTokenizer, BertForSequenceClassification | |
import torch | |
from bertopic import BERTopic | |
# Load the trained classifier and its tokenizer.
# `predict` calls `model(**inputs)` and reads `outputs.logits`, which is the
# transformers sequence-classification interface. A BERTopic object (the
# loader previously used here) is not callable that way, so prediction failed.
# transformers' `from_pretrained` returns the model already in eval mode.
# NOTE(review): assumes the hub repo holds a BertForSequenceClassification
# checkpoint, and that its vocabulary matches bert-base-uncased. Confirm both.
model = BertForSequenceClassification.from_pretrained("sleiyer/restricted_item_detector")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Human-readable names for the classifier's output class indices.
LABEL_MAP = {0: 'Allowed Item', 1: 'Restricted Item'}


def predict(text):
    """Classify a single piece of text as an allowed or restricted item.

    Relies on the module-level ``tokenizer`` and ``model`` objects.

    Args:
        text: The item description to classify.

    Returns:
        A sentence of the form
        ``The item "<text>" is classified as: "<label>"``.

    Raises:
        KeyError: if the model predicts a class index not present in
            ``LABEL_MAP``.
    """
    # Tokenize into PyTorch tensors. Truncation keeps over-long input within
    # the tokenizer's maximum sequence length.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Pick the highest-scoring class along dim 1. This assumes logits are
    # (batch, num_labels). .item() requires a single element, which holds
    # because exactly one text is classified.
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    predicted_label = LABEL_MAP[predicted_class]
    return f'The item "{text}" is classified as: "{predicted_label}"'
# Capture the text classifier defined above before `predict` is rebound
# below. Without this alias the wrapper would look up `predict` at call time,
# find itself, and recurse until RecursionError.
_classify_text = predict


def predict(input):
    """Delegate to the text classifier and return its result.

    The parameter keeps its original name ``input`` (shadowing the builtin
    inside this function only) so existing keyword callers keep working.

    NOTE(review): this wrapper is not registered as a FastAPI route. If it is
    meant to be an HTTP endpoint, add an ``@app.get(...)`` decorator.
    """
    return _classify_text(input)