|
import streamlit as st |
|
from transformers import BartForConditionalGeneration, BartTokenizer |
|
|
|
|
|
# Directory or hub identifier of the fine-tuned summarization checkpoint.
model_path = "disilbart-med-summary"

# Streamlit re-executes this script on every widget interaction, so loading the
# weights at import time would repeat the expensive load on each button click.
# `st.cache_resource` keeps one shared instance alive across reruns; releases
# that predate it fall back to plain, uncached loading.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_summarizer(path):
    """Load and return the ``(tokenizer, model)`` pair stored at ``path``."""
    loaded_tokenizer = BartTokenizer.from_pretrained(path)
    loaded_model = BartForConditionalGeneration.from_pretrained(path)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_summarizer(model_path)
|
|
|
|
|
def generate_summary(input_text):
    """Summarize ``input_text`` with the module-level BART tokenizer and model.

    Args:
        input_text: Raw text to summarize.

    Returns:
        The decoded summary with special tokens stripped, or an empty string
        when ``input_text`` is empty or contains only whitespace.
    """
    # Nothing to summarize: skip the model instead of generating from an
    # input that holds only special tokens.
    if not input_text or not input_text.strip():
        return ""

    # The model can only address this many positions; longer sequences must be
    # truncated on the way in and capped on the way out.
    max_positions = model.config.max_position_embeddings

    input_ids = tokenizer.encode(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_positions,
    )

    # Beam search (4 beams) with repeated bigrams forbidden. The requested
    # 4000-token ceiling is kept, but never above what the model supports.
    summary_ids = model.generate(
        input_ids,
        max_length=min(4000, max_positions),
        num_beams=4,
        no_repeat_ngram_size=2,
    )

    # Only the first (and only) returned sequence is decoded.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
|
|
|
def main():
    """Render the Streamlit page and summarize the entered text on demand.

    Shows a centered heading and an input text area. When the
    "Generate Summary" button is pressed, the text is passed to
    ``generate_summary`` and the result is displayed in a second text area.
    Empty or whitespace-only input produces a warning instead of a summary.
    """
    st.markdown(
        "<h3 style='text-align: center; color: #333;'>Medical Summary - Text Generation</h3>",
        unsafe_allow_html=True,
    )

    user_input = st.text_area("Enter Text:", "")

    # Nothing else to draw until the user asks for a summary.
    if not st.button("Generate Summary"):
        return

    # Give explicit feedback rather than silently ignoring a click on empty input.
    if not user_input or not user_input.strip():
        st.warning("Please enter some text to summarize.")
        return

    # Generation can take a while; show progress while it runs.
    with st.spinner("Generating summary..."):
        result = generate_summary(user_input)

    st.text_area("Generated Summary:", result, key="generated_summary")
|
|
|
|
|
# Entry point: build the page only when this file is run as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|