versae commited on
Commit
fb8b575
1 Parent(s): 2e7a033

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -9,6 +9,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
9
 
10
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
11
  DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
 
 
12
  DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
13
  MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
14
  MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
@@ -132,11 +134,13 @@ class TextGeneration:
132
 
133
 
134
  #@st.cache(allow_output_mutation=True)
135
- @st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
136
  def load_text_generator():
137
- generator = TextGeneration()
138
- generator.load()
139
- return generator
 
 
140
 
141
 
142
  def main():
@@ -147,8 +151,6 @@ def main():
147
  initial_sidebar_state="expanded"
148
  )
149
  style()
150
- with st.spinner('Cargando el modelo. Por favor, espere...'):
151
- generator = load_text_generator()
152
 
153
  st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
154
 
 
9
 
10
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
11
  DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
12
+ if DEVICE != "cpu" and not torch.cuda.is_available():
13
+ DEVICE = "cpu"
14
  DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
15
  MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
16
  MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
 
134
 
135
 
136
  #@st.cache(allow_output_mutation=True)
137
+ @st.cache(allow_output_mutation=True, hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
138
  def load_text_generator():
139
+ text_generator = TextGeneration()
140
+ text_generator.load()
141
+ return text_generator
142
+
143
+ generator = load_text_generator()
144
 
145
 
146
  def main():
 
151
  initial_sidebar_state="expanded"
152
  )
153
  style()
 
 
154
 
155
  st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
156