Commit "Update app.py" — changed file: app.py
@@ -4,14 +4,15 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
4 |
# Define the path where the model and tokenizer are saved
|
5 |
save_directory = "RAG_model"
|
6 |
|
7 |
-
|
8 |
# Load the model and tokenizer from the saved directory
|
9 |
-
@st.
|
10 |
def load_model():
|
11 |
model = AutoModelForCausalLM.from_pretrained(
|
12 |
save_directory,
|
13 |
-
torch_dtype=
|
14 |
-
device_map=
|
|
|
|
|
15 |
)
|
16 |
tokenizer = AutoTokenizer.from_pretrained(save_directory)
|
17 |
return model, tokenizer
|
@@ -51,4 +52,4 @@ if st.button("Generate"):
|
|
51 |
st.write("Please enter a prompt to generate text.")
|
52 |
|
53 |
# Add an example usage
|
54 |
-
st.write("Example usage: Enter a prompt like 'What is Artificial Intelligence?' and click 'Generate'.")
|
|
|
4 |
# Define the path where the model and tokenizer are saved
|
5 |
save_directory = "RAG_model"
|
6 |
|
|
|
7 |
# Load the model and tokenizer from the saved directory
@st.cache_resource
def load_model():
    """Load the causal language model and its tokenizer from ``save_directory``.

    Decorated with ``st.cache_resource`` so the (expensive) model load runs
    once per Streamlit server process instead of on every script rerun.

    Returns:
        tuple: ``(model, tokenizer)`` — the ``AutoModelForCausalLM`` and
        ``AutoTokenizer`` loaded from ``save_directory``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        save_directory,
        torch_dtype=None,  # keep the checkpoint's own dtype; no quantization dtype
        device_map=None,   # no device_map -> plain CPU load by default
        # NOTE(review): the previous revision also passed load_in_8bit=False and
        # load_in_4bit=False. Both are deprecated as direct from_pretrained()
        # kwargs in recent transformers (quantization_config=BitsAndBytesConfig
        # is the replacement) and False is already the default, so they are
        # omitted here — same behavior, no deprecation risk.
    )
    tokenizer = AutoTokenizer.from_pretrained(save_directory)
    return model, tokenizer
|
|
|
52 |
st.write("Please enter a prompt to generate text.")
|
53 |
|
54 |
# Add an example usage
|
55 |
+
st.write("Example usage: Enter a prompt like 'What is Artificial Intelligence?' and click 'Generate'.")
|