# NOTE(review): the three lines below ("Spaces:", "Sleeping", "Sleeping") look
# like a pasted Hugging Face Spaces status banner, not Python code; commented
# out so the module parses.
# Spaces:
# Sleeping
# Sleeping
import os.path as osp | |
import streamlit as st | |
from spacy import displacy | |
import torch | |
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForTokenClassification, \ | |
AutoModelForSeq2SeqLM, pipeline | |
# Local directories holding the pretrained checkpoints, keyed by the task
# name shown in the UI.
_PRETRAINED_ROOT = './data/pretrained_models'
model_paths = {
    task: f'{_PRETRAINED_ROOT}/{folder}'
    for task, folder in [
        ('Text Summarization', 'roberta2roberta_L-24_bbc'),
        ('Question Answering', 'roberta-base-squad2'),
        ('Entity Recognition', 'xlm-roberta-large-finetuned-conll03-english'),
    ]
}
def get_tokenizer(operation):
    """Return the tokenizer paired with the pretrained model for *operation*.

    *operation* must be a key of ``model_paths``.
    """
    return AutoTokenizer.from_pretrained(model_paths[operation])
def get_model(operation):
    """Load and return the pretrained model for *operation*.

    Args:
        operation: One of 'Text Summarization', 'Question Answering' or
            'Entity Recognition' (the keys of ``model_paths``).

    Returns:
        The matching ``transformers`` model instance.

    Raises:
        ValueError: If *operation* is not a supported task name. (The
            original if/elif chain had no else branch, so an unknown
            operation left ``model`` unbound.)
    """
    # Map each task to the AutoModel class that knows how to load it.
    task_to_model_cls = {
        'Text Summarization': AutoModelForSeq2SeqLM,
        'Question Answering': AutoModelForQuestionAnswering,
        'Entity Recognition': AutoModelForTokenClassification,
    }
    if operation not in task_to_model_cls:
        raise ValueError(f"Unsupported operation: {operation!r}")
    return task_to_model_cls[operation].from_pretrained(model_paths[operation])
def get_predictions(input_text, operations, input_question=None):
    """Run each requested NLP task on *input_text*.

    Args:
        input_text: The passage to process.
        operations: Iterable of task names (keys of ``model_paths``).
        input_question: Question string, used only by 'Question Answering'.

    Returns:
        A list of ``(label, result)`` tuples, one per requested task.
    """
    results = []
    for task in operations:
        tok = get_tokenizer(task)
        mdl = get_model(task)
        # The three task names are mutually exclusive, so an elif chain
        # is equivalent to the original back-to-back ifs.
        if task == 'Text Summarization':
            results.append(('Predicted Summary', get_summary(input_text, tok, mdl)))
        elif task == 'Question Answering':
            results.append(('Predicted Answer', get_answer(input_text, tok, mdl, input_question)))
        elif task == 'Entity Recognition':
            results.append(('Predicted Entities', get_entities(input_text, tok, mdl)))
    return results
def get_summary(input_text, _tokenizer, _model):
    """Generate an abstractive summary of *input_text*.

    Uses beam search (3 beams) capped at 50 output tokens; the input is
    truncated to 512 tokens before encoding.
    """
    encoded = _tokenizer(input_text, max_length=512, return_tensors="pt", truncation=True)
    generated = _model.generate(
        encoded["input_ids"], num_beams=3, max_length=50, early_stopping=True
    )
    # generate() returns a batch of sequences; decode the single first one.
    return _tokenizer.decode(
        generated[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
def get_answer(input_text, _tokenizer, _model, input_question):
    """Extract the answer span for *input_question* from *input_text*.

    Encodes question+context (truncated to 512 tokens), takes the argmax of
    the start and end logits, and decodes the token span between them.

    NOTE(review): if the predicted end index precedes the start index the
    decoded span is empty — confirm callers tolerate an empty answer.
    """
    encoded = _tokenizer(
        input_question, input_text, max_length=512, return_tensors="pt", truncation=True
    )
    with torch.inference_mode():
        outputs = _model(**encoded)
    span_start = outputs.start_logits.argmax()
    span_end = outputs.end_logits.argmax()
    span_token_ids = encoded.input_ids[0, span_start:span_end + 1]
    return _tokenizer.decode(span_token_ids)
def get_entities(input_text, _tokenizer, _model):
    """Run named-entity recognition over *input_text*.

    Returns the displaCy-rendered markup (manual mode) highlighting each
    detected entity span.
    """
    ner = pipeline(
        "ner", model=_model, tokenizer=_tokenizer, aggregation_strategy='simple'
    )
    spans = [
        {'start': ent['start'], 'end': ent['end'], 'label': ent['entity_group']}
        for ent in ner(input_text)
    ]
    doc = {'text': input_text, 'ents': spans}
    return displacy.render(doc, style='ent', manual=True)
if __name__ == '__main__':
    # Manual smoke test: run the entity-recognition path end to end.
    # (Previously this duplicated the body of get_entities inline and printed
    # a stray debug "hello"; it now reuses the helpers above.)
    input_text = '''Lord of the Rings book was written by JRR Tolkein. It is one of the most popular fictional book ever written.
    A very popular movie series is also based on the same book. Now, people are starting a television series based on the book as well.'''
    operation = 'Entity Recognition'
    tokenizer = get_tokenizer(operation)
    model = get_model(operation)
    print(get_entities(input_text, tokenizer, model))