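# app.py, as shown in a Hugging Face file view: author alinikkhah,
# commit c8db798 "Update app.py" (verified), 1.95 kB, malware scan: "No virus".
# (Viewer controls raw / history / blame / contribute / delete omitted.)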
import torch
from transformers import BertForSequenceClassification
import gradio as gr
from transformers import BertTokenizer
import torch
from transformers import BertForSequenceClassification, BertTokenizer
import gradio as gr
import torch
from transformers import BertForSequenceClassification
# Build the BERT encoder with a 2-label classification head, then overwrite its
# weights with the fine-tuned checkpoint.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only host.
# Load errors (missing file, unreadable checkpoint, tensor-shape mismatches) are
# intentionally allowed to propagate: the app should fail at startup rather than
# keep serving predictions from an un-fine-tuned classification head.
state_dict = torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
# strict=False tolerates key mismatches, but load_state_dict still returns them;
# surface them instead of dropping them silently.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        "State dict key mismatch - "
        f"missing: {incompatible.missing_keys}, "
        f"unexpected: {incompatible.unexpected_keys}"
    )
model.eval()  # Set the model to evaluation mode (disables dropout)
# Load the tokenizer matching the 'bert-base-uncased' base checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def predict(text):
    """Return the index of the highest-scoring class for ``text``.

    Relies on the module-level ``tokenizer`` and ``model``. The input is
    truncated to the tokenizer's maximum length, and gradient tracking is
    turned off because this is inference only. ``argmax`` runs over the
    flattened logits; for a single input string that is the predicted class.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        scores = model(**encoded).logits
    best = scores.argmax()
    return best.item()
# Re-apply the fine-tuned checkpoint with strict key checking, so any missing or
# unexpected key aborts startup instead of leaving layers un-fine-tuned.
# The `model` and `tokenizer` created above are reused. A strict load overwrites
# every parameter and persistent buffer, so re-instantiating them from
# 'bert-base-uncased' would only waste time and memory.
# map_location mirrors the first load, so a GPU-saved checkpoint also loads on a
# CPU-only host; without it, torch.load fails there.
# NOTE(review): checkpoints from older transformers versions may contain a
# 'bert.embeddings.position_ids' key that strict loading rejects - confirm
# against the actual checkpoint if startup fails here.
# The Gradio interface is constructed once, further down the module.
model.load_state_dict(torch.load('bert_model_complete.pth', map_location=torch.device('cpu')))
model.eval()
def predict(text):
    """Classify ``text`` with the fine-tuned BERT model.

    Tokenizes the input (padded and truncated), runs the module-level
    ``model`` without gradient tracking, and returns the position of the
    largest logit as a Python ``int``.
    """
    with torch.no_grad():
        logits = model(
            **tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        ).logits
    return int(logits.argmax())
# Wrap `predict` in a Gradio UI: a text box in, the predicted class shown as a label.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)

# Only start the web server when this file is executed directly.
if __name__ == "__main__":
    interface.launch()