NLP-Demo / inference.py
niks-salodkar's picture
added files and code
606291c
raw history blame
No virus
4.27 kB
import os.path as osp
import streamlit as st
from spacy import displacy
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForTokenClassification, \
AutoModelForSeq2SeqLM, pipeline
# Maps each supported operation name to the local directory its pretrained
# checkpoint is loaded from. The keys double as the operation identifiers
# used by the loader and prediction functions in this module.
model_paths = {
    'Text Summarization': './data/pretrained_models/roberta2roberta_L-24_bbc',
    'Question Answering': './data/pretrained_models/roberta-base-squad2',
    'Entity Recognition': './data/pretrained_models/xlm-roberta-large-finetuned-conll03-english',
}
@st.experimental_memo
def get_tokenizer(operation):
    """Load the tokenizer registered for ``operation``.

    The call is memoized by Streamlit, keyed on ``operation``.

    Args:
        operation: A key of ``model_paths``.

    Returns:
        The tokenizer produced by ``AutoTokenizer.from_pretrained`` for the
        checkpoint directory mapped to ``operation``.

    Raises:
        KeyError: If ``operation`` is not a key of ``model_paths``.
    """
    return AutoTokenizer.from_pretrained(model_paths[operation])
# A singleton cache is used instead of ``st.experimental_memo``: memo stores the
# return value in pickled form and hands every caller a freshly deserialized
# copy, which for a transformer model means re-materializing all weights on
# each call. The singleton cache keeps one shared instance per ``operation``.
# NOTE(review): the shared instance is reused by every session; this assumes
# the models are only used for inference and never mutated by callers.
@st.experimental_singleton
def get_model(operation):
    """Load the pretrained model registered for ``operation``.

    Args:
        operation: A key of ``model_paths``; it also selects which
            ``AutoModelFor...`` class is used to load the checkpoint.

    Returns:
        The model returned by the matching ``from_pretrained`` call.

    Raises:
        KeyError: If ``operation`` is not a key of ``model_paths``.
        ValueError: If ``operation`` has a checkpoint path but no model class
            is registered for it (previously this surfaced as an
            ``UnboundLocalError`` at the ``return`` statement).
    """
    model_classes = {
        'Text Summarization': AutoModelForSeq2SeqLM,
        'Question Answering': AutoModelForQuestionAnswering,
        'Entity Recognition': AutoModelForTokenClassification,
    }
    # Resolve the path first so unknown operations still raise KeyError,
    # exactly as before.
    path = model_paths[operation]
    model_class = model_classes.get(operation)
    if model_class is None:
        raise ValueError(f"No model class registered for operation {operation!r}")
    return model_class.from_pretrained(path)
def get_predictions(input_text, operations, input_question=None):
    """Run every requested operation on ``input_text`` and collect the results.

    Args:
        input_text: The text handed to each operation.
        operations: Iterable of operation names; each one is used to look up
            its tokenizer and model via ``get_tokenizer`` / ``get_model``.
        input_question: The question forwarded to ``get_answer`` when
            'Question Answering' is requested; unused by the other operations.

    Returns:
        A list of ``(label, result)`` tuples, in the order the operations
        were given. An operation name that matches none of the three known
        names adds no entry.
    """
    results = []
    for op_name in operations:
        # Tokenizer and model are fetched before dispatching, so a name that
        # the loaders reject fails here rather than being silently skipped.
        tok = get_tokenizer(op_name)
        net = get_model(op_name)
        if op_name == 'Text Summarization':
            results.append(('Predicted Summary', get_summary(input_text, tok, net)))
        elif op_name == 'Question Answering':
            results.append(
                ('Predicted Answer', get_answer(input_text, tok, net, input_question))
            )
        elif op_name == 'Entity Recognition':
            results.append(('Predicted Entities', get_entities(input_text, tok, net)))
    return results
@st.experimental_memo
def get_summary(input_text, _tokenizer, _model):
    """Generate a summary of ``input_text`` with a seq2seq model.

    The leading underscore on ``_tokenizer`` and ``_model`` tells Streamlit
    not to hash those arguments, so the memoized result is keyed on
    ``input_text`` alone.

    Args:
        input_text: Text to summarize; tokenized with truncation at
            ``max_length=512``.
        _tokenizer: Tokenizer used to encode the input and decode the output.
        _model: Model whose ``generate`` method produces the summary ids.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    encoded = _tokenizer(input_text, max_length=512, return_tensors="pt", truncation=True)
    generated = _model.generate(
        encoded["input_ids"], num_beams=3, max_length=50, early_stopping=True
    )
    first_sequence = generated[0]
    return _tokenizer.decode(
        first_sequence, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
@st.experimental_memo
def get_answer(input_text, _tokenizer, _model, input_question):
    """Extract the answer to ``input_question`` from ``input_text``.

    The question and the text are encoded as a pair, the model is run once,
    and the tokens between the highest-scoring start and end positions are
    decoded.

    The leading underscore on ``_tokenizer`` and ``_model`` tells Streamlit
    not to hash those arguments, so the memoized result is keyed on
    ``input_text`` and ``input_question``.

    Args:
        input_text: Context passage to search for the answer.
        _tokenizer: Tokenizer used to encode the pair and decode the span.
        _model: Question-answering model exposing ``start_logits`` and
            ``end_logits`` on its output.
        input_question: The question to answer.

    Returns:
        The decoded answer text. The string is empty when the predicted end
        position precedes the start position or when the span contains only
        special tokens.
    """
    inputs = _tokenizer(
        input_question, input_text, max_length=512, return_tensors="pt", truncation=True
    )
    with torch.inference_mode():
        outputs = _model(**inputs)
    start_index = int(outputs.start_logits.argmax())
    end_index = int(outputs.end_logits.argmax())
    answer_tokens = inputs.input_ids[0, start_index:end_index + 1]
    # Drop special tokens so markers inside the predicted span (for example
    # when the model points at the sequence-start token) never reach the user.
    # Strip because some tokenizers (e.g. byte-level BPE) decode a
    # mid-sentence span with a leading space.
    return _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()
@st.experimental_memo
def get_entities(input_text, _tokenizer, _model):
    """Tag named entities in ``input_text`` and render them with displaCy.

    The leading underscore on ``_tokenizer`` and ``_model`` tells Streamlit
    not to hash those arguments, so the memoized result is keyed on
    ``input_text`` alone.

    Args:
        input_text: Text to run entity recognition on.
        _tokenizer: Tokenizer passed to the "ner" pipeline.
        _model: Token-classification model passed to the "ner" pipeline.

    Returns:
        Whatever ``displacy.render`` returns for the manual 'ent' payload
        built from the pipeline output.
    """
    ner = pipeline("ner", model=_model, tokenizer=_tokenizer, aggregation_strategy='simple')
    # Convert each aggregated pipeline hit into the span format displaCy
    # expects in manual mode.
    spans = [
        {'start': hit['start'], 'end': hit['end'], 'label': hit['entity_group']}
        for hit in ner(input_text)
    ]
    payload = {'text': input_text, 'ents': spans}
    return displacy.render(payload, style='ent', manual=True)
if __name__ == '__main__':
    # Manual smoke test: run entity recognition on a fixed sample outside of
    # Streamlit and print the displaCy rendering.
    print("hello")
    input_text = '''Lord of the Rings book was written by JRR Tolkein. It is one of the most popular fictional book ever written.
A very popular movie series is also based on the same book. Now, people are starting a television series based on the book as well.'''
    path = model_paths['Entity Recognition']
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForTokenClassification.from_pretrained(path)
    classifier = pipeline(
        "ner", model=model, tokenizer=tokenizer, aggregation_strategy='simple'
    )
    output = classifier(input_text)
    # Same span conversion as get_entities, done inline so the script does
    # not go through the Streamlit cache decorators.
    entities = [
        {'start': hit['start'], 'end': hit['end'], 'label': hit['entity_group']}
        for hit in output
    ]
    output_dict = {'text': input_text, 'ents': entities}
    display_output = displacy.render(output_dict, style='ent', manual=True)
    print(display_output)