File size: 1,789 Bytes
e5ffa90
481c6b3
 
 
 
 
db2fdad
481c6b3
db2fdad
5f726f0
 
e5ffa90
db2fdad
 
e5ffa90
64909f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5ffa90
 
 
b3b3581
481c6b3
6c86820
e5ffa90
6c86820
 
 
 
e5ffa90
6c86820
 
 
245cab5
6c86820
e5ffa90
6c86820
e5ffa90
155ed73
481c6b3
 
155ed73
e5ffa90
6c86820
db2fdad
ba0c1c3
db2fdad
e5ffa90
db2fdad
e5ffa90
d481ecd
e42c394
 
 
b3b3581
e42c394
b3b3581
e42c394
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import functools
import sys

import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModel

import config
import dataset
import engine
from model import BERTBaseUncased
# from tokenizer import tokenizer
from utils import label_full_decoder


# T = tokenizer.TweetTokenizer(
#     preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)

# def preprocess(text):
#     tokens = T.tokenize(text)
#     print(tokens, file=sys.stderr)
#     ptokens = []
#     for index, token in enumerate(tokens):
#         if "@" in token:
#             if index > 0:
#                 # check if previous token was mention
#                 if "@" in tokens[index-1]:
#                     pass
#                 else:
#                     ptokens.append("mention_0")
#             else:
#                 ptokens.append("mention_0")
#         else:
#             ptokens.append(token)

#     print(ptokens, file=sys.stderr)
#     return " ".join(ptokens)


@functools.lru_cache(maxsize=1)
def _load_model(model_path, device):
    """Build BERTBaseUncased, load its weights from *model_path*, and cache it.

    The cache is keyed on (model_path, device), so repeated UI requests reuse
    the already-loaded model instead of re-reading the checkpoint each call.
    A different path or device triggers a fresh load.
    """
    model = BERTBaseUncased()
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device))
    )
    model.to(device)
    # Inference only. This is harmless if engine.predict_fn also sets eval mode.
    model.eval()
    return model


def sentence_prediction(sentence):
    """Classify a single sentence with the fine-tuned BERT model.

    Args:
        sentence: Raw text entered in the Gradio text box.

    Returns:
        The first value returned by ``engine.predict_fn`` for the one-item
        batch built from ``sentence``.
    """
    # sentence = preprocess(sentence)

    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0],  # placeholder: the true label is unknown at inference time
    )

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        # DataLoader rejects negative worker counts. 0 loads in the main
        # process, which is the sensible choice for a single sentence.
        num_workers=0,
    )

    device = config.device
    model = _load_model(config.MODEL_PATH, device)

    # Only the first return value is used. The previous `outputs, [] = ...`
    # pattern raised ValueError whenever the second value was non-empty.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)

    print(outputs)
    # NOTE(review): gr.Label expects a str/number or a {label: confidence}
    # dict. Confirm what predict_fn returns; label_full_decoder (imported
    # from utils) may be needed to map class ids to label names.
    return outputs




# Text box in, label widget out; each submission is routed to sentence_prediction.
demo = gr.Interface(fn=sentence_prediction, inputs='text', outputs='label')

# Start the web UI. This happens at import time (there is no __main__ guard).
demo.launch()