KennethTM commited on
Commit
7cf951b
1 Parent(s): 65d4797

Remove torch.float16 when loading model

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -1,15 +1,15 @@
1
  #streamlit run app.py
2
 
3
  import streamlit as st
4
- import torch
5
  import galai as gal
 
6
 
7
  #https://github.com/paperswithcode/galai/blob/main/notebooks/Introduction%20to%20Galactica%20Models.ipynb
8
 
9
  #@st.cache(suppress_st_warning=True, allow_output_mutation=True)
10
  @st.cache_resource
11
  def load_model(model_name):
12
- model = gal.load_model(model_name, dtype=torch.float16)
13
  return model
14
 
15
  if 'text' not in st.session_state:
 
1
  #streamlit run app.py
2
 
3
  import streamlit as st
 
4
  import galai as gal
5
+ #import torch
6
 
7
  #https://github.com/paperswithcode/galai/blob/main/notebooks/Introduction%20to%20Galactica%20Models.ipynb
8
 
9
  #@st.cache(suppress_st_warning=True, allow_output_mutation=True)
10
  @st.cache_resource
11
  def load_model(model_name):
12
+ model = gal.load_model(model_name) #, dtype=torch.float16
13
  return model
14
 
15
  if 'text' not in st.session_state: