# app.py — Hugging Face Space by Kurkur99
# Last commit: f20c8b9 "Update app.py" (verified), 2.28 kB
import gradio as gr
import torch
from transformers import BertForSequenceClassification, BertTokenizer
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Tokenizer: loaded by model id from the Hugging Face Hub.
token_model = "indolem/indobertweet-base-uncased"
tokenizer = BertTokenizer.from_pretrained(token_model)

# Fine-tuned classifier: loaded from a local directory that must contain
# config.json and the weights (pytorch_model.bin, or whatever name
# from_pretrained accepts for this transformers version).
model_directory = "modeling"
model = BertForSequenceClassification.from_pretrained(model_directory).to(device)
# Evaluation mode: disables dropout so predictions are deterministic.
model.eval()
def classify_transaction(notes):
    """Predict the category of a transaction from its free-text notes.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        notes: The transaction notes typed by the user (presumably a str
            from the Gradio text box; ``None`` is tolerated).

    Returns:
        ``"Predicted Category: <index>"``, where ``<index>`` is the integer
        class index with the highest score. Returns ``""`` when ``notes``
        is empty or only whitespace.
    """
    # The interface is live, so this runs on every edit, including when the
    # box is cleared. Don't report a prediction for blank input.
    if notes is None or not notes.strip():
        return ""

    # Tokenize one example. A batch of one needs no padding; padding to 256
    # tokens only added masked positions the model ignores, at real cost.
    inputs = tokenizer(
        notes,
        add_special_tokens=True,
        max_length=256,
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Move tensors to the same device as the model.
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # outputs[0] is the logits tensor, which works whether the model returns
    # a ModelOutput or a plain tuple. There is one row per input, with class
    # scores along dim 1. Softmax is monotonic, so argmax over the raw logits
    # picks the same class.
    logits = outputs[0]
    # .item() gives a plain int (e.g. 3) rather than a NumPy array that
    # would render as "[3]".
    predicted_class = int(torch.argmax(logits, dim=1).item())

    return f"Predicted Category: {predicted_class}"
# Gradio UI: one multi-line text box in, one text field out.
_notes_box = gr.Textbox(
    lines=3,
    placeholder="Enter Transaction Notes Here",
    label="Transaction Notes",
)
_result_box = gr.Text(label="Classification Result")

iface = gr.Interface(
    fn=classify_transaction,
    inputs=_notes_box,
    outputs=_result_box,
    title="Transaction Category Classifier",
    description="Enter transaction notes to get the predicted category.",
    # Re-run the classifier whenever the input changes.
    live=True,
)

if __name__ == "__main__":
    # Start the web app only when executed as a script, not on import.
    iface.launch()