"""Interactive command-line classifier for restricted items.

Loads a fine-tuned BERT sequence-classification model and labels each line
the user types as an allowed or a restricted item.
"""

from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Load the trained model and tokenizer once, at import time.
model = BertForSequenceClassification.from_pretrained("sleiyer/restricted_item_detector")
# NOTE(review): the tokenizer is loaded from the base checkpoint, not from the
# fine-tuned repo; this assumes fine-tuning kept the bert-base-uncased
# vocabulary unchanged — confirm.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Inference only: make sure dropout etc. are disabled. from_pretrained already
# returns the model in eval mode; this call just makes the intent explicit.
model.eval()


def predict(text: str) -> int:
    """Return the predicted class index for a single input text.

    Args:
        text: The raw item description to classify.

    Returns:
        The index of the highest-scoring class (a key of ``label_map``).
    """
    # Tokenize into PyTorch tensors; truncation keeps over-long inputs within
    # the tokenizer's maximum sequence length.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)

    # No gradients are needed for prediction.
    with torch.no_grad():
        outputs = model(**inputs)

    # argmax over the class dimension of the (single-row) logits batch.
    logits = outputs.logits
    return torch.argmax(logits, dim=1).item()


# Human-readable names for the model's output class indices.
label_map = {0: 'Allowed Item', 1: 'Restricted Item'}


def main() -> None:
    """Read items from stdin in a loop and print their predicted labels.

    The loop ends cleanly on end-of-input (Ctrl-D / piped input exhausted)
    or Ctrl-C at the prompt, instead of exiting with a traceback.
    """
    while True:
        try:
            user_input = input("Enter something: ")
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line so the shell prompt starts on a new line.
            print()
            break

        predicted_class = predict(user_input)
        # Map the predicted class to a human-readable label.
        predicted_label = label_map[predicted_class]
        print(f'The item "{user_input}" is classified as: "{predicted_label}"')


if __name__ == "__main__":
    main()