# Spaces: Running
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Model configuration.
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint identifier handed to `from_pretrained` for both tokenizer and model.
model_name = "Ateeqq/Text-Rewriter-Paraphraser"
@st.cache_resource
def load_model():
    """Load the paraphrasing tokenizer and model.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps the loaded objects alive across those reruns
    so the checkpoint is not loaded again each time a button is clicked.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model moved to ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    return tokenizer, model
# Module-level handles used by the rewrite function below.
tokenizer, model = load_model()
# Rewrite function
def rewrite(text):
    """Return a paraphrase of ``text`` generated with beam search.

    Args:
        text: The passage to rewrite. It is prefixed with ``"paraphraser: "``
            and truncated to 1024 tokens before being fed to the model.

    Returns:
        The decoded paraphrase with special tokens removed, or an empty
        string when ``text`` is empty or contains only whitespace.
    """
    # Nothing to paraphrase: skip the (expensive) generation call entirely.
    if not text or not text.strip():
        return ""

    encoded = tokenizer(
        f"paraphraser: {text}",
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(device)

    # No `temperature` here: it only affects sampling (`do_sample=True`);
    # with pure beam search it is ignored and merely triggers a warning.
    output = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        num_beams=5,
        no_repeat_ngram_size=3,
        max_length=1024,
        early_stopping=True,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# UI
st.title("📝 Text Rewriter (Paraphraser)")
text_input = st.text_area("Enter text to rewrite:", height=300)
if st.button("Rewrite"):
    if not text_input.strip():
        # Do not send an empty prompt to the model or report it as a success.
        st.warning("Please enter some text to rewrite.")
    else:
        # The spinner is shown only while the model is generating.
        with st.spinner("Rewriting..."):
            result = rewrite(text_input)
        st.success("Done!")
        st.markdown("### 🔁 Rewritten Text")
        st.write(result)