# Gradio demo: Turkish offensive-language classification served via ONNX Runtime.
import gradio as gra
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModel
import onnxruntime as rt
# ONNX Runtime inference session for the exported classifier.
# NOTE(review): hard-coded absolute path — confirm it exists on the deploy host.
ort_session = rt.InferenceSession("/sin-kaf/onnx_model/model.onnx")
# NOTE(review): return value is discarded — presumably meant to log/inspect
# which execution providers (CPU/CUDA/...) the session selected.
ort_session.get_providers()
# Earlier loading strategies, kept for reference:
# model = ORTModel.load_model("/DATA/sin-kaf/onnx_model/model.onnx")
# model = AutoModelForSequenceClassification.from_pretrained('/DATA/sin-kaf/test_trainer/checkpoint-18500')
# Tokenizer matching the model the ONNX graph was exported from.
tokenizer = AutoTokenizer.from_pretrained("Overfit-GM/distilbert-base-turkish-cased-offensive")
def user_greeting(sent):
    """Classify one sentence with the ONNX model and return the argmax label index.

    Parameters
    ----------
    sent : str
        Raw input sentence from the Gradio text box.

    Returns
    -------
    int
        Index of the highest-scoring class in the model's logits
        (rendered as text by the Gradio "text" output component).
    """
    # padding="max_length" replaces the deprecated pad_to_max_length=True;
    # truncation=True keeps over-length inputs from raising at 64 tokens.
    # return_tensors="np" yields numpy arrays directly, which is what
    # ONNX Runtime's run() requires (Python lists are rejected).
    encoded = tokenizer(
        sent,
        add_special_tokens=True,
        max_length=64,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="np",
    )
    input_feed = {
        # int64 matches the graph's expected input dtype for token ids/masks.
        "input_ids": encoded["input_ids"].astype(np.int64),
        "attention_mask": encoded["attention_mask"].astype(np.int64),
    }
    # run(None, ...) fetches all outputs; output[0] are the logits,
    # output[0][0] is the score vector for the single batched sentence.
    output = ort_session.run(None, input_feed)
    return np.argmax(output[0][0])
# Expose the classifier through a minimal web UI:
# one text box in, the predicted class index out.
app = gra.Interface(
    fn=user_greeting,
    inputs="text",
    outputs="text",
)
app.launch()
# app.launch(server_name="0.0.0.0")