import gradio as gr
import torch
from transformers import BertForSequenceClassification, BertTokenizer

# Load the tokenizer from the Hugging Face Hub.
# NOTE(review): this assumes the classifier in `model_directory` was fine-tuned
# with this same tokenizer/vocabulary — confirm, or load the tokenizer from
# `model_directory` instead if one was saved alongside the model.
token_model = "indolem/indobertweet-base-uncased"
tokenizer = BertTokenizer.from_pretrained(token_model)

# Local directory holding the fine-tuned classifier: config.json plus the
# weights file (e.g. pytorch_model.bin).
model_directory = "model_directory"

# Maximum number of tokens fed to the model; longer notes are truncated.
MAX_LENGTH = 256

# Pick the device first so the model and its inputs are placed consistently.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and move it to `device`. Without the .to(device) call, the
# weights stay on CPU while classify_transaction moves its inputs to `device`,
# so the forward pass fails with a device-mismatch error whenever CUDA exists.
model = BertForSequenceClassification.from_pretrained(model_directory)
model.to(device)
model.eval()  # Evaluation mode (disables dropout) for inference.


def classify_transaction(notes):
    """Predict the category index for a piece of transaction text.

    Args:
        notes: Free-form transaction notes coming from the Gradio textbox.

    Returns:
        ``"Predicted Category: <index>"`` where ``<index>`` is the arg-max
        class id as a plain integer, or ``""`` when ``notes`` is empty or
        only whitespace (the interface is ``live``, so it is also invoked
        when the user clears the box).
    """
    # Skip inference on empty input instead of classifying "" on every clear.
    if not notes or not notes.strip():
        return ""

    # Tokenize the single input; `tokenizer(...)` is the recommended entry
    # point and accepts the same options encode_plus did.
    inputs = tokenizer(
        notes,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Move tensors to the same device as the model.
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # outputs[0] is the logits tensor whether the model returns a ModelOutput
    # or a plain tuple. Softmax is monotonic, so arg-max over the raw logits
    # picks the same class; .item() turns the 1-element tensor into an int
    # (the old numpy array rendered as e.g. "[3]").
    logits = outputs[0]
    predicted_class = torch.argmax(logits, dim=-1).item()

    return f"Predicted Category: {predicted_class}"


# Creating the Gradio interface
iface = gr.Interface(
    fn=classify_transaction,
    inputs=gr.Textbox(lines=3, placeholder="Enter Transaction Notes Here", label="Transaction Notes"),
    outputs=gr.Text(label="Classification Result"),
    title="Transaction Category Classifier",
    description="Enter transaction notes to get the predicted category.",
    live=True,  # Update the output as soon as the input changes
)

if __name__ == "__main__":
    iface.launch()