# stance-pl / app.py
# Dawid Motyka - "minor interface changes" (commit 24fa0b5)
# Hosting-page residue: raw | history | blame | contribute | delete
# No virus - 2.18 kB
import os
import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline
from inference_utils import prepare_stance_texts
from models import StanceEncoderModel
# Maps a predicted class index (position of the largest logit) to its
# human-readable stance label.
CLASS_DICT = dict(enumerate(('FAVOR', 'AGAINST', 'NEITHER')))

# Settings handed to `prepare_stance_texts` when building model inputs.
# NOTE(review): the meaning of `masked_lm_prompt: 4` is defined by
# `inference_utils` and is not visible in this file - confirm there.
params = dict(lang='pl', masked_lm_prompt=4)
class StancePipeline(Pipeline):
    """Pipeline that classifies the stance of a text towards a target.

    The input is a dict with the keys ``'text'`` and ``'target'``; the
    result is a dict with the predicted ``'stance'`` label (a value of
    ``CLASS_DICT``) and its softmax ``'score'``.
    """

    def _sanitize_parameters(self, **pipeline_parameters):
        """Route every received keyword argument to the preprocess step.

        Returns a triple of (preprocess, forward, postprocess) keyword
        dicts; the forward and postprocess dicts are always empty.
        """
        forward_kwargs = {}
        postprocess_kwargs = {}
        return pipeline_parameters, forward_kwargs, postprocess_kwargs

    def preprocess(self, input):
        """Turn one ``{'text': ..., 'target': ...}`` dict into model inputs.

        The text and the target are wrapped in one-element lists and passed
        through ``prepare_stance_texts`` together with the module-level
        ``params`` and the tokenizer; the two results are then tokenized as
        a sequence pair, truncating only the first sequence when needed.
        """
        texts = [input['text']]
        targets = [input['target']]
        prompt_text, prompt_target = prepare_stance_texts(
            texts, targets, params, self.tokenizer)

        encoded = self.tokenizer(
            prompt_text,
            prompt_target,
            return_tensors="pt",
            padding=True,
            truncation='only_first',
        )

        # 1 where the tokenizer reports sequence id 1, 0 everywhere else
        # (including positions whose id is None).
        # NOTE(review): id 1 presumably marks tokens of the second sequence
        # (the target prompt) - confirm against the tokenizer docs.
        segment_ids = np.array(encoded.sequence_ids())
        second_sequence_mask = (segment_ids == 1).astype(int)

        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            # unsqueeze(0) adds a leading dimension of size 1.
            'sequence_ids': torch.tensor(second_sequence_mask).unsqueeze(0),
        }

    def _forward(self, model_inputs):
        """Run the model on the preprocessed inputs and return its output."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert the model's logits into a stance label and its score.

        ``.item()`` is called on the reduced tensors, so the logits must
        describe a single example.
        """
        probabilities = model_outputs["logits"].softmax(-1)
        top_probability, _ = probabilities.max(-1)
        predicted_class = probabilities.argmax(-1).item()
        return {
            'stance': CLASS_DICT[predicted_class],
            'score': top_probability.item(),
        }
# Model and tokenizer are loaded from the same repository, authenticating
# with the token stored in the TOKEN environment variable (a missing
# variable raises KeyError at import time).
_MODEL_REPO = 'clarin-knext/stance-pl-1'
_auth_kwargs = {'use_auth_token': os.environ['TOKEN']}

model = StanceEncoderModel.from_pretrained(_MODEL_REPO, **_auth_kwargs)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO, **_auth_kwargs)

# Single shared pipeline instance used by `predict`.
pipeline = StancePipeline(model=model, tokenizer=tokenizer, batch_size=1)
def predict(text, target):
    """Classify the stance of ``text`` towards ``target``.

    Runs the module-level ``pipeline`` on the pair and returns a string of
    the form ``"<stance> (<score>)"`` with the score rounded to three
    decimal places.
    """
    result = pipeline({'text': text, 'target': target})
    stance = result["stance"]
    confidence = result["score"]
    return f'{stance} ({confidence:.3f})'
# Input widgets: a multi-line area for the text and a single-line box for
# the stance target. Their order matches the parameters of `predict`.
text_input = gr.TextArea(label="Text", placeholder="Enter text here...")
target_input = gr.Textbox(label="Target",
                          placeholder="Enter stance target here...")

# Output widget that displays the string returned by `predict`.
stance_output = gr.Label(label="Stance class")

gradio_app = gr.Interface(
    fn=predict,
    inputs=[text_input, target_input],
    outputs=[stance_output],
    title="Polish stance detection",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    gradio_app.launch()