terryyz commited on
Commit
e90f65d
1 Parent(s): 5d50dcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -59,7 +59,6 @@ if not torch.cuda.is_available():
59
 
60
  if torch.cuda.is_available():
61
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
62
- print(device)
63
 
64
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
65
  CHECKPOINT_URL = "Salesforce/codegen-350M-mono"
@@ -184,7 +183,7 @@ def generate(
184
  # model.to(device)
185
  input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
186
  # generated_ids = model.generate(**input_ids
187
- generated_ids = model.generate(**input_ids, **generate_kwargs)
188
 
189
  return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
190
 
 
59
 
60
  if torch.cuda.is_available():
61
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
62
 
63
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
64
  CHECKPOINT_URL = "Salesforce/codegen-350M-mono"
 
183
  # model.to(device)
184
  input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
185
  # generated_ids = model.generate(**input_ids
186
+ generated_ids = model.generate(**input_ids)#, **generate_kwargs)
187
 
188
  return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
189