# MassageMateNLP / app.py
# Author: BiEchi, commit 84d3475 ("final push")
# (Copied from the hosting page: raw / history / blame links; scan result "No virus"; size 971 Bytes)
import functools

import gradio as gr
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertModel
import warnings

# Silence every Python warning for the whole process: a catch-all "ignore"
# filter is placed at the front of the warnings filter list.
warnings.simplefilter("ignore")
# Target columns; each has its own fine-tuned classifier under models/bert/<col>.
_TARGET_COLUMNS = ('position_x', 'position_y', 'force', 'velocity_xy', 'velocity_z')
# Maps the predicted class index to the printed label.
# NOTE(review): assumes every checkpoint has exactly three output classes
# in this order -- confirm against the training setup.
_CLASS_LABELS = ('-1', '0', '1')


# The keys come only from the fixed _TARGET_COLUMNS tuple, so an unbounded
# cache holds at most len(_TARGET_COLUMNS) entries.
@functools.lru_cache(maxsize=None)
def _load_classifier(col):
    """Load the tokenizer and eval-mode model for *col*, once per process.

    Caching avoids re-reading the checkpoint from disk on every ``infer`` call.

    Args:
        col: Target column name; the checkpoint is read from
            ``models/bert/<col>``.

    Returns:
        A ``(tokenizer, model)`` tuple with the model switched to eval mode.
    """
    model_path = f'models/bert/{col}'
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()
    return tokenizer, model


def infer(text):
    """Classify *text* with each per-column BERT model and format the results.

    Args:
        text: Free-form input text from the Gradio text box.

    Returns:
        A string with one ``"<col>: <label>\\n"`` line per target column, in
        ``_TARGET_COLUMNS`` order. ``<label>`` is ``'-1'``, ``'0'`` or
        ``'1'``: the class with the highest logit.
    """
    lines = []
    # Inference only: no_grad skips building the autograd graph.
    with torch.no_grad():
        for col in _TARGET_COLUMNS:
            tokenizer, model = _load_classifier(col)
            # truncation=True clips over-long inputs to the tokenizer's
            # max length instead of failing inside the model.
            encoded_input = tokenizer(text, return_tensors='pt', truncation=True)
            logits = model(**encoded_input).logits
            # Single input, so take row 0 and pick its highest-scoring class.
            answer = _CLASS_LABELS[int(logits[0].argmax())]
            lines.append(f'{col}: {answer}\n')
    return ''.join(lines)
# Text-in / text-out web UI around infer().
iface = gr.Interface(fn=infer, inputs="text", outputs="text")

# Start the server only when executed as a script (e.g. `python app.py`),
# so importing this module does not block on a running web server.
if __name__ == "__main__":
    iface.launch()