udaytag committed on
Commit 00d0bb7
1 Parent(s): 4bd6e98

Update app.py

Files changed (1)
  1. app.py +54 -52
app.py CHANGED
@@ -1,52 +1,54 @@
-import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-
-# Define the path where the model and tokenizer are saved
-save_directory = "RAG_model"
-
-
-# Load the model and tokenizer from the saved directory
-@st.cache(allow_output_mutation=True)
-def load_model():
-    model = AutoModelForCausalLM.from_pretrained(save_directory)
-    tokenizer = AutoTokenizer.from_pretrained(save_directory)
-    return model, tokenizer
-
-
-model, tokenizer = load_model()
-
-# Set up the text generation pipeline
-query_pipeline = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    device=-1,  # Use CPU
-    device_map="auto",
-)
-
-st.title("Text Generation with Llama-2 Model")
-st.write("This is a simple Streamlit app to generate text using the Llama-2 model.")
-
-# Text input for the user
-user_input = st.text_area("Enter your prompt:", "")
-
-# Generate text when the user clicks the button
-if st.button("Generate"):
-    if user_input:
-        with st.spinner("Generating..."):
-            sequences = query_pipeline(
-                user_input,
-                do_sample=True,
-                top_k=10,
-                num_return_sequences=1,
-                eos_token_id=tokenizer.eos_token_id,
-                max_length=200,
-            )
-            for seq in sequences:
-                st.write("Generated text:")
-                st.write(seq['generated_text'])
-    else:
-        st.write("Please enter a prompt to generate text.")
-
-# Add an example usage
-st.write("Example usage: Enter a prompt like 'What is Artificial Intelligence?' and click 'Generate'.")
+import streamlit as st
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+# Define the path where the model and tokenizer are saved
+save_directory = "RAG_model"
+
+
+# Load the model and tokenizer from the saved directory
+@st.cache(allow_output_mutation=True)
+def load_model():
+    model = AutoModelForCausalLM.from_pretrained(
+        save_directory,
+        torch_dtype="auto",  # Ensure automatic dtype selection
+        device_map="cpu"  # Explicitly set to CPU
+    )
+    tokenizer = AutoTokenizer.from_pretrained(save_directory)
+    return model, tokenizer
+
+model, tokenizer = load_model()
+
+# Set up the text generation pipeline
+query_pipeline = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    device=-1  # Use CPU
+)
+
+st.title("Text Generation with Llama-2 Model")
+st.write("This is a simple Streamlit app to generate text using the Llama-2 model.")
+
+# Text input for the user
+user_input = st.text_area("Enter your prompt:", "")
+
+# Generate text when the user clicks the button
+if st.button("Generate"):
+    if user_input:
+        with st.spinner("Generating..."):
+            sequences = query_pipeline(
+                user_input,
+                do_sample=True,
+                top_k=10,
+                num_return_sequences=1,
+                eos_token_id=tokenizer.eos_token_id,
+                max_length=200,
+            )
+            for seq in sequences:
+                st.write("Generated text:")
+                st.write(seq['generated_text'])
+    else:
+        st.write("Please enter a prompt to generate text.")
+
+# Add an example usage
+st.write("Example usage: Enter a prompt like 'What is Artificial Intelligence?' and click 'Generate'.")