File size: 1,945 Bytes
02d2b5c
 
46eeeb0
 
 
 
 
02d2b5c
c8db798
 
 
 
d289606
46eeeb0
c8db798
46eeeb0
 
 
 
4805566
 
02d2b5c
46eeeb0
02d2b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
56c36b4
02d2b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56c36b4
02d2b5c
56c36b4
02d2b5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from transformers import BertForSequenceClassification, BertTokenizer
import gradio as gr

# Build the classifier architecture: base BERT with a 2-way classification head.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Load the fine-tuned weights exactly once.
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; strict=False tolerates harmless key mismatches.
# On failure we report the error and continue with the (untrained) base
# weights so the app still starts — best-effort, matching the original intent.
try:
    state_dict = torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict, strict=False)
except Exception as e:
    print(f"Error loading state dict: {e}")

model.eval()  # inference mode: disables dropout

# Tokenizer matching the model checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def predict(text):
    """Classify *text* with the BERT model.

    Returns the predicted class index (int, 0 or 1 for this 2-label head).
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # no gradients needed for inference
        outputs = model(**inputs)
    # argmax over the logits picks the most likely class
    return outputs.logits.argmax().item()


# Gradio UI: free-text input, predicted label output.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)

# Launch the web interface only when run as a script.
if __name__ == "__main__":
    interface.launch()