import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Directory holding the model and tokenizer files saved earlier with
# save_pretrained(); adjust the path if the checkpoint lives elsewhere.
save_directory = "RAG_model"


# Cache the model and tokenizer so they are loaded once per session.
# st.cache_resource is the current Streamlit API for caching resources;
# st.cache(allow_output_mutation=True) is deprecated and has been removed
# in recent Streamlit releases.
@st.cache_resource
def load_model():
    # device_map="auto" (requires the accelerate package) places the model
    # on a GPU when one is available and falls back to CPU otherwise.
    model = AutoModelForCausalLM.from_pretrained(save_directory, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(save_directory)
    return model, tokenizer


model, tokenizer = load_model()

# Build the text-generation pipeline. The original passed both device=-1 and
# device_map="auto", which conflict (recent transformers versions reject the
# combination); the pipeline now simply reuses the device the model was
# loaded on above.
query_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
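
# A quick way to smoke-test the pipeline outside Streamlit (hypothetical
# prompt; run in a plain Python session):
#   out = query_pipeline("What is Artificial Intelligence?", max_new_tokens=50)
#   print(out[0]["generated_text"])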

st.title("Text Generation with Llama-2 Model")
st.write("This is a simple Streamlit app to generate text using the Llama-2 model.")

user_input = st.text_area("Enter your prompt:", "")

if st.button("Generate"):
    if user_input:
        with st.spinner("Generating..."):
            sequences = query_pipeline(
                user_input,
                do_sample=True,
                top_k=10,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                max_length=200,
            )
        for seq in sequences:
            st.write("Generated text:")
            st.write(seq['generated_text'])
    else:
        st.write("Please enter a prompt to generate text.")

st.write("Example usage: Enter a prompt like 'What is Artificial Intelligence?' and click 'Generate'.")
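
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py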