# ArlowGPT / app.py
# Author: yuchenxie
# Commit: "Update app.py" (96ad412, verified)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
# Model loading (cached via Streamlit's resource cache)
@st.cache_resource
def load_model():
    """Load the ArlowGPT-3B-Multilingual tokenizer and causal language model.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the returned objects rather than reloading the checkpoint on each script
    rerun.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer is guaranteed to have a
        ``pad_token_id``; when the checkpoint does not define one, the EOS
        token id is used instead.
    """
    # Imported for its side effect only: fail early if torch is unavailable
    # in the runtime environment.
    import torch  # noqa: F401

    checkpoint = "yuchenxie/ArlowGPT-3B-Multilingual"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)

    # Padding needs a token id; fall back to EOS when none is configured.
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id

    return tok, lm
# Obtain the tokenizer and model; load_model() is cache-decorated, so the
# checkpoint is not reloaded on every Streamlit rerun.
tokenizer, model = load_model()

# Streamlit UI
st.title("ArlowGPT")
st.write("This space uses ArlowGPT-3B Multilingual. Generations may be slow.")

# Input box
user_input = st.text_area("Enter your prompt here:", placeholder="Type something...")

# Generate button
if st.button("Generate"):
    if not user_input.strip():
        # Whitespace-only prompts are rejected before touching the model.
        st.warning("Please enter a prompt.")
    else:
        inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
        with st.spinner("Generating response..."):
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # NOTE: max_length bounds prompt tokens + generated tokens together.
                max_length=2048,
                num_return_sequences=1,
                # temperature / top_k are sampling parameters: they only take
                # effect when do_sample=True. Without it generate() decodes
                # greedily and silently ignores both settings.
                do_sample=True,
                temperature=0.9,
                top_k=50,
                pad_token_id=tokenizer.pad_token_id,
            )
        # outputs[0] is the single returned sequence (num_return_sequences=1).
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.subheader("Generated Response")
        st.write(response)