Tabish009 commited on
Commit
46525f9
1 Parent(s): 7eff3c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -1,5 +1,9 @@
1
  import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
3
 
4
  # Load the model and tokenizer
5
  @st.cache_resource
@@ -31,6 +35,7 @@ def main():
31
  with st.spinner("Generating response..."):
32
  response = generate_response(prompt)
33
  st.write(response)
 
34
 
35
  if __name__ == "__main__":
36
  main()
 
1
  import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import accelerate
4
+
5
+ accelerator = accelerate.Accelerator(device_map="auto")
6
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=accelerator.device_map)
7
 
8
  # Load the model and tokenizer
9
  @st.cache_resource
 
35
  with st.spinner("Generating response..."):
36
  response = generate_response(prompt)
37
  st.write(response)
38
+
39
 
40
  if __name__ == "__main__":
41
  main()