Daniel Fried committed
Commit 51676cf
1 Parent(s): 78ec172

add cloud logging

Files changed (2)
  1. modules/app.py +34 -26
  2. modules/cloud_logging.py +21 -0
modules/app.py CHANGED
@@ -3,6 +3,11 @@ from typing import List
 import traceback
 import os
 import base64
+
+import logging
+logging.basicConfig(level=logging.INFO)
+import modules.cloud_logging
+
 # needs to be imported *before* transformers
 if os.path.exists('use_normal_tokenizers'):
     import tokenizers
@@ -51,11 +56,11 @@ app = FastAPI(docs_url=None, redoc_url=None)
 app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
-print("loading model")
+logging.info("loading model")
 model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
-print("loading tokenizer")
+logging.info("loading tokenizer")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-print("loading complete")
+logging.info("loading complete")
 
 if CUDA:
     model = model.half().cuda()
@@ -96,7 +101,7 @@ def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel
         any_truncated = False
         retries_attempted += 1
         if VERBOSE:
-            print(f"retry {retries_attempted}")
+            logging.info(f"retry {retries_attempted}")
         if len(parts) == 1:
             prompt = parts[0]
         else:
@@ -122,7 +127,7 @@ def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel
             completion = completion[len(prompt):]
             if EOM not in completion:
                 if VERBOSE:
-                    print(f"warning: {EOM} not found")
+                    logging.info(f"warning: {EOM} not found")
                 completion += EOM
                 # TODO: break inner loop here
                 done = False
@@ -135,18 +140,18 @@ def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel
         text = ''.join(complete)
 
     if VERBOSE:
-        print("generated text:")
-        print(prompt)
-        print()
-        print("parts:")
-        print(parts)
-        print()
-        print("infills:")
-        print(infills)
-        print()
-        print("restitched text:")
-        print(text)
-        print()
+        logging.info("generated text:")
+        logging.info(prompt)
+        logging.info("")
+        logging.info("parts:")
+        logging.info(parts)
+        logging.info("")
+        logging.info("infills:")
+        logging.info(infills)
+        logging.info("")
+        logging.info("restitched text:")
+        logging.info(text)
+        logging.info("")
 
     return {
         'text': text,
@@ -169,17 +174,17 @@ async def generate_maybe(info: str):
     # form = await request.json()
     # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
     # fix padding, following https://stackoverflow.com/a/9956217/1319683
-    print(info)
     info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
-    print(info)
     form = json.loads(info)
-    pprint.pprint(form)
     # print(form)
     prompt = form['prompt']
     length_limit = int(form['length'])
     temperature = float(form['temperature'])
-    if VERBOSE:
-        print(prompt)
+    logging.info(json.dumps({
+        'length': length_limit,
+        'temperature': temperature,
+        'prompt': prompt,
+    }))
     try:
         generation, truncated = generate(prompt, length_limit, temperature)
         if truncated:
@@ -189,6 +194,7 @@ async def generate_maybe(info: str):
         return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation, 'message': message}
     except Exception as e:
         traceback.print_exception(*sys.exc_info())
+        logging.error(e)
         return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
 
 @app.get('/infill')
@@ -198,15 +204,17 @@ async def infill_maybe(info: str):
     # form = await request.json()
     # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
-    print(info)
     info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
-    print(info)
     form = json.loads(info)
-    pprint.pprint(form)
     length_limit = int(form['length'])
     temperature = float(form['temperature'])
     max_retries = 1
     extra_sentinel = True
+    logging.info(json.dumps({
+        'length': length_limit,
+        'temperature': temperature,
+        'parts_joined': '<infill>'.join(form['parts']),
+    }))
     try:
         if len(form['parts']) > 4:
             return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': f"error: Can't use more than 3 <infill> tokens in this demo (for efficiency)."}
@@ -221,7 +229,7 @@ async def infill_maybe(info: str):
         # return {'result': 'success', 'prefix': prefix, 'suffix': suffix, 'text': generation['text']}
     except Exception as e:
         traceback.print_exception(*sys.exc_info())
-        print(e)
+        logging.error(e)
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
 
 
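Both endpoints read their input the same way: the front end packs a JSON form into a urlsafe-base64 string and passes it as the `info` GET parameter, and the server re-pads and decodes it before `json.loads`. A minimal client-side sketch of that contract (the base URL, port, and default values below are placeholders, not part of the commit):

import base64
import json
import urllib.parse
import urllib.request

BASE_URL = "http://localhost:7860"  # hypothetical host for a locally running copy of the demo

def call_generate(prompt, length=64, temperature=0.6):
    # Serialize the fields generate_maybe() reads: 'prompt', 'length', 'temperature'.
    payload = json.dumps({"prompt": prompt, "length": length, "temperature": temperature})
    # urlsafe base64 with the '=' padding stripped; the server's padding fix
    # (info + '=' * (4 - len(info) % 4)) restores it before decoding.
    info = base64.urlsafe_b64encode(payload.encode("utf-8")).decode("ascii").rstrip("=")
    url = f"{BASE_URL}/generate?info={urllib.parse.quote(info)}"
    with urllib.request.urlopen(url) as resp:
        return json.loads(resp.read().decode("utf-8"))

The /infill endpoint follows the same scheme, with a 'parts' list in place of 'prompt' (the handler rejects more than four parts, i.e. more than three <infill> points).
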
modules/cloud_logging.py ADDED
@@ -0,0 +1,21 @@
+import os
+def make_logging_client():
+    cred_filename = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
+    if not cred_filename:
+        return None
+    print("cred filename:", cred_filename)
+    cred_string = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS_STRING')
+    print("cred string:", bool(cred_string))
+    if not os.path.exists(cred_filename):
+        if cred_string:
+            print(f"writing cred string to {cred_filename}")
+            with open(cred_filename, 'w') as f:
+                f.write(cred_string)
+        else:
+            return None
+    from google.cloud import logging
+    logging_client = logging.Client()
+    logging_client.setup_logging()
+    return logging_client
+
+logging_client = make_logging_client()
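
`modules.cloud_logging` is imported purely for its side effect: `make_logging_client()` runs at import time and, when credentials are available, `google.cloud.logging.Client().setup_logging()` attaches a Cloud Logging handler to Python's root logger, so the `logging.info(...)` / `logging.error(...)` calls added to `app.py` above are forwarded to Google Cloud Logging. If `GOOGLE_APPLICATION_CREDENTIALS` is unset, or points to a missing file with no `GOOGLE_APPLICATION_CREDENTIALS_STRING` to write into it, the function returns `None` and records only reach the local handler installed by `logging.basicConfig`. A small self-contained sketch of that fallback path, assuming the repo's `modules` package is on the import path (the messages and the assertion are illustrative, not from the commit):

import logging
import os

# Force the fallback path: with no GOOGLE_APPLICATION_CREDENTIALS set,
# make_logging_client() returns None before `from google.cloud import logging`
# is reached, so the google-cloud-logging package is not required here.
os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None)

logging.basicConfig(level=logging.INFO)

from modules.cloud_logging import logging_client  # runs make_logging_client() at import time

assert logging_client is None
logging.info("no Cloud Logging client configured; this record stays on the local handler")

How the two environment variables are populated in deployment (for example, as Space secrets) is not specified by this commit; `GOOGLE_APPLICATION_CREDENTIALS` names the credentials file, and `GOOGLE_APPLICATION_CREDENTIALS_STRING` can carry the service-account JSON that the module writes to that path when the file does not yet exist.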