Spaces:
Runtime error
Runtime error
File size: 1,945 Bytes
02d2b5c 46eeeb0 02d2b5c c8db798 d289606 46eeeb0 c8db798 46eeeb0 4805566 02d2b5c 46eeeb0 02d2b5c 56c36b4 02d2b5c 56c36b4 02d2b5c 56c36b4 02d2b5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import torch
from transformers import BertForSequenceClassification
import gradio as gr
from transformers import BertTokenizer
import torch
from transformers import BertForSequenceClassification, BertTokenizer
import gradio as gr
import torch
from transformers import BertForSequenceClassification
# Name of the pretrained backbone on the Hugging Face Hub; both the model
# architecture and the tokenizer vocabulary are taken from it.
BASE_MODEL_NAME = "bert-base-uncased"
# Fine-tuned weights shipped alongside this script (path relative to the CWD).
CHECKPOINT_PATH = "bert_model_complete.pth"
# Number of output classes of the classification head.
NUM_LABELS = 2


def _load_model(checkpoint_path=CHECKPOINT_PATH):
    """Build the BERT classifier and load the fine-tuned weights onto the CPU.

    Loading is best-effort: if the checkpoint cannot be read or applied, the
    error is printed and the model keeps the weights produced by
    ``from_pretrained`` (for ``bert-base-uncased`` that means a freshly
    initialised classification head, so predictions will not be meaningful).

    Args:
        checkpoint_path: Path of the file handed to ``torch.load``.

    Returns:
        The model, switched to evaluation mode.
    """
    model = BertForSequenceClassification.from_pretrained(
        BASE_MODEL_NAME, num_labels=NUM_LABELS
    )
    try:
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint was saved from GPU tensors.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        # NOTE(review): the file name suggests it might hold a whole pickled
        # model rather than a bare state dict; accept either form.
        if isinstance(checkpoint, torch.nn.Module):
            checkpoint = checkpoint.state_dict()
        # strict=False tolerates key mismatches (e.g. buffers that differ
        # between library versions); they are reported below instead of
        # aborting startup.
        incompatible = model.load_state_dict(checkpoint, strict=False)
    except Exception as e:
        print(f"Error loading state dict: {e}")
    else:
        if incompatible.missing_keys or incompatible.unexpected_keys:
            # Missing keys keep their from_pretrained values, so surface them
            # rather than silently serving a partly-untrained model.
            print(
                "Checkpoint keys did not match the model -- "
                f"missing: {incompatible.missing_keys}, "
                f"unexpected: {incompatible.unexpected_keys}"
            )
    model.eval()  # inference mode: disables dropout
    return model


# Loaded once at import time and shared by every request.
model = _load_model()
tokenizer = BertTokenizer.from_pretrained(BASE_MODEL_NAME)


def predict(text):
    """Classify *text* and return the index of the highest-scoring class.

    Args:
        text: A single input string.

    Returns:
        The predicted class index as a Python ``int`` (0..NUM_LABELS-1).
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**inputs).logits
    # Argmax over the class axis. .item() then requires exactly one input row,
    # so a batched call fails loudly instead of returning a flattened index.
    return logits.argmax(dim=-1).item()


# Single Gradio UI wired to ``predict``; the returned integer is displayed by
# the "label" output component.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)

if __name__ == "__main__":
    interface.launch()
|