# Rewritter / app.py
# Author: abuhanzala — commit 192180f ("Create app.py", verified)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Load model
# Prefer a CUDA GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Checkpoint identifier handed to the transformers `from_pretrained` loaders.
model_name = "Ateeqq/Text-Rewriter-Paraphraser"
@st.cache_resource
def load_model():
    """Return the ``(tokenizer, model)`` pair for the module-level ``model_name``.

    Decorated with ``st.cache_resource`` so Streamlit hands back the same
    objects on later script reruns instead of reloading the weights. The
    model is moved to the module-level ``device`` before being returned.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    seq2seq = seq2seq.to(device)
    return tok, seq2seq


# Shared handles used by rewrite(); fetched from the Streamlit cache on reruns.
tokenizer, model = load_model()
# Rewrite function
def rewrite(text, *, num_beams=5, max_length=1024):
    """Paraphrase ``text`` with the module-level ``tokenizer`` and ``model``.

    The text is prefixed with ``"paraphraser: "``. It is tokenized with
    truncation to ``max_length`` tokens, moved to ``device``, and decoded
    with deterministic beam search.

    Args:
        text: Text to rewrite. ``None``, empty, or whitespace-only input
            returns ``""`` without running the model.
        num_beams: Beam width used by ``model.generate`` (default 5).
        max_length: Token limit applied both when truncating the input and
            as the ``max_length`` of the generated sequence (default 1024).

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # Nothing to paraphrase: skip the expensive model call instead of
    # generating text from the bare task prefix.
    if text is None or not text.strip():
        return ""

    # Keep the whole encoding (input_ids + attention_mask) so generate()
    # receives an explicit mask rather than inferring one.
    inputs = tokenizer(
        f"paraphraser: {text}",
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    ).to(device)

    # Deterministic beam search. `temperature` was removed: it only applies
    # when sampling (do_sample=True), so here it was silently ignored and
    # merely triggered a transformers warning.
    output = model.generate(
        **inputs,
        num_beams=num_beams,
        no_repeat_ngram_size=3,
        max_length=max_length,
        early_stopping=True,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# UI
st.title("📝 Text Rewriter (Paraphraser)")
text_input = st.text_area("Enter text to rewrite:", height=300)
if st.button("Rewrite"):
    # Don't run the model on an empty or whitespace-only text box.
    if not text_input.strip():
        st.warning("Please enter some text to rewrite.")
    else:
        with st.spinner("Rewriting..."):
            result = rewrite(text_input)
        st.success("Done!")
        st.markdown("### 🔁 Rewritten Text")
        st.write(result)