samithva's picture
Update app.py
0afe598 verified
import gradio as gr
import torch
from khmernltk import word_tokenize
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# The fine-tuned classifier weights and the tokenizer files are both read
# from the current working directory ("./").
model = AutoModelForSequenceClassification.from_pretrained("./")
# Optional from_pretrained() keyword arguments the author left disabled:
#   load_in_8bit=True          - load in 8-bit quantized format
#   torch_dtype=torch.float16  - choose a dtype appropriate for the GPU
#   device_map="cuda:0"        - map the model onto an available device
tokenizer = AutoTokenizer.from_pretrained("./")

# Inference only: switch the model to evaluation mode.
model.eval()
# Maps a predicted class index to its human-readable label.
# Extend this mapping if the model has more than two classes.
class_labels = {0: "non-accident", 1: "accident"}
# Define the inference function
def classify(text):
    """Classify a piece of text as accident-related or not.

    The text is segmented into words with ``khmernltk.word_tokenize``, the
    words are re-joined with single spaces, and the result is encoded and
    run through the module-level sequence-classification ``model``.

    Args:
        text: Raw input string to classify.

    Returns:
        The ``class_labels`` entry for the highest-scoring class.

    Raises:
        KeyError: If the predicted class index has no entry in
            ``class_labels``.
    """
    # Word-segment first, then rejoin with spaces so the model tokenizer
    # receives explicit word boundaries.
    sent = ' '.join(word_tokenize(text))
    print(f'sent : {sent}')

    # Call the tokenizer directly with explicit padding/truncation options.
    # This replaces the deprecated ``encode_plus(pad_to_max_length=True)``
    # call, which relied on implicit truncation to ``max_length``.
    encoded = tokenizer(
        sent,
        add_special_tokens=True,     # add the model's special tokens
        max_length=512,
        padding="max_length",        # pad every input to max_length
        truncation=True,             # cut inputs longer than max_length
        return_attention_mask=True,
        return_tensors="pt",         # return PyTorch tensors
    )

    # Gradients are not needed for inference.
    with torch.no_grad():
        # Keyword arguments avoid depending on the parameter order of the
        # model's forward() signature.
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    # Index of the largest logit along the class dimension.
    predicted = torch.argmax(outputs.logits, dim=-1)
    return class_labels[predicted.item()]
# Build the Gradio UI: a single text input wired to classify(), with the
# returned label shown as text.
interface = gr.Interface(
    fn=classify,
    inputs="text",
    outputs="text",
    title="Accident Classification",
    description="Enter a text to classify it.",
)

# Start serving the app.
# NOTE(review): the positional True binds to launch()'s first parameter
# (presumably `inline` in current Gradio releases) - confirm that this was
# intended rather than e.g. share=True or debug=True.
interface.launch(True)