terryyz commited on
Commit
ae5bb78
1 Parent(s): 619704f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import shutil
4
  import requests
5
  import spaces
 
6
 
7
  import gradio as gr
8
  from huggingface_hub import Repository
@@ -11,6 +12,8 @@ from peft import PeftModel
11
 
12
  from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
13
 
 
 
14
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
15
  CHECKPOINT_URL = "Salesforce/codegen-350M-mono"
16
 
@@ -143,8 +146,9 @@ theme = gr.themes.Monochrome(
143
  ],
144
  )
145
 
146
-
147
  def stream(model, code, generate_kwargs):
 
148
  input_ids = tokenizer(code, return_tensors="pt").to("cuda")
149
  generated_ids = model.generate(**input_ids, **generate_kwargs)
150
  return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
 
3
  import shutil
4
  import requests
5
  import spaces
6
+ import torch
7
 
8
  import gradio as gr
9
  from huggingface_hub import Repository
 
12
 
13
  from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
14
 
15
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
+
17
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
18
  CHECKPOINT_URL = "Salesforce/codegen-350M-mono"
19
 
 
146
  ],
147
  )
148
 
149
+ @spaces.GPU(enable_queue=True)
150
  def stream(model, code, generate_kwargs):
151
+ model.to(device)
152
  input_ids = tokenizer(code, return_tensors="pt").to("cuda")
153
  generated_ids = model.generate(**input_ids, **generate_kwargs)
154
  return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()