versae commited on
Commit
dfc8ea3
1 Parent(s): 21b5dfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -5,12 +5,15 @@ import streamlit as st
5
  import torch
6
  from transformers import pipeline, set_seed
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
-
 
 
9
 
10
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
11
  DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
12
  if DEVICE != "cpu" and not torch.cuda.is_available():
13
  DEVICE = "cpu"
 
14
  DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
15
  MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
16
  MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
@@ -133,15 +136,13 @@ class TextGeneration:
133
  )[0]["generated_text"]
134
 
135
 
136
- #@st.cache(allow_output_mutation=True)
137
- @st.cache(allow_output_mutation=True, hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
138
  def load_text_generator():
139
  text_generator = TextGeneration()
140
  text_generator.load()
141
  return text_generator
142
 
143
- generator = load_text_generator()
144
-
145
 
146
  def main():
147
  st.set_page_config(
@@ -151,7 +152,7 @@ def main():
151
  initial_sidebar_state="expanded"
152
  )
153
  style()
154
-
155
  st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
156
 
157
  max_length = st.sidebar.slider(
 
5
  import torch
6
  from transformers import pipeline, set_seed
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ import logging
9
+ logger = logging.getLogger()
10
+ logger.addHandler(logging.StreamHandler())
11
 
12
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
13
  DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
14
  if DEVICE != "cpu" and not torch.cuda.is_available():
15
  DEVICE = "cpu"
16
+ logger.info(f"DEVICE {DEVICE}")
17
  DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
18
  MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
19
  MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
 
136
  )[0]["generated_text"]
137
 
138
 
139
+ #@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
140
+ @st.cache(allow_output_mutation=True)
141
  def load_text_generator():
142
  text_generator = TextGeneration()
143
  text_generator.load()
144
  return text_generator
145
 
 
 
146
 
147
  def main():
148
  st.set_page_config(
 
152
  initial_sidebar_state="expanded"
153
  )
154
  style()
155
+ generator = load_text_generator()
156
  st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
157
 
158
  max_length = st.sidebar.slider(