|
import streamlit as st |
|
from transformers import ( |
|
AutoModelForSeq2SeqLM, |
|
AutoModelForTokenClassification, |
|
AutoTokenizer) |
|
|
|
|
|
|
|
|
|
# Shared tokenizer used by both summarize() and tag() below.
# NOTE(review): "bert-base-cased" matches NEITHER checkpoint loaded below —
# BART ("facebook/bart-large-cnn") and the CoNLL-03 BERT-large NER model each
# ship their own tokenizer/vocabulary. Each model should load its tokenizer
# from its own checkpoint name; fixing this requires a coordinated change in
# summarize() and tag(), so it is flagged here rather than changed in place.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Seq2seq model used for summarization.
# NOTE(review): despite the variable name, BART-large-CNN is an *abstractive*
# summarizer, not an extractive one.
extractive_summary = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# Token-classification (NER) model fine-tuned on CoNLL-03 English.
tag_model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
|
|
|
|
|
|
|
|
|
|
|
|
|
def summarize(text):
    """Generate a summary of *text* with the seq2seq model.

    Args:
        text: Raw input string to summarize.

    Returns:
        The decoded summary string, with special tokens removed.
    """
    # Truncate to the model's maximum input length so over-long inputs
    # don't fail at generation time.
    input_ids = tokenizer.encode(text, return_tensors="pt", truncation=True)
    output = extractive_summary.generate(input_ids, max_new_tokens=4096)
    # Fix: skip special tokens (e.g. </s>, <pad>) so they don't leak into
    # the user-visible summary text.
    return tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
def tag(text):
    """Run token-level tagging (NER) over *text*.

    Args:
        text: Raw input string.

    Returns:
        A list of ``(token, label)`` pairs, one per wordpiece token as
        produced by the tokenizer (special tokens included).
    """
    input_ids = tokenizer.encode(text, return_tensors="pt", truncation=True)
    # Fix: a token-classification model returns a ModelOutput whose first
    # element is a logits tensor — it has no .words / .labels attributes,
    # so the original indexing raised AttributeError on every call.
    logits = tag_model(input_ids).logits
    predicted_ids = logits.argmax(dim=-1)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # Map predicted class indices to human-readable label strings.
    id2label = tag_model.config.id2label
    return [(token, id2label[label_id]) for token, label_id in zip(tokens, predicted_ids)]
|
|
|
|
|
|
|
def optimize(text):
    """Placeholder text-optimization step.

    Args:
        text: Raw input string.

    Returns:
        The input text, unchanged.

    The original body returned the undefined name ``optimized_text``,
    which raised ``NameError`` on every call. Until a real optimization
    is implemented, pass the text through so the UI button works.
    """
    # TODO: implement an actual text-optimization model or heuristic.
    optimized_text = text
    return optimized_text
|
|
|
# --- Streamlit UI ------------------------------------------------------
st.title("NLP Demo")

text = st.text_area("Input text:", "Enter text here")

# One button per NLP action; clicking a button runs its function on the
# current text-area contents and renders the result.
actions = (
    ("Summarize", summarize),
    ("Tag", tag),
    ("Optimize", optimize),
)
for label, action in actions:
    if st.button(label):
        st.write(action(text))