merve HF staff committed on
Commit
e10e13d
1 Parent(s): 0f75abe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -1,12 +1,19 @@
1
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr  # presumably consumed by UI code outside this chunk — verify

# Load the seq2seq prompt-generator checkpoint once at import time.
# from_tf=True: the hub repo ships TensorFlow weights, which are
# converted to PyTorch when loaded.
tokenizer = AutoTokenizer.from_pretrained("merve/chatgpt-prompt-generator-v12")
model = AutoModelForSeq2SeqLM.from_pretrained("merve/chatgpt-prompt-generator-v12", from_tf=True)
6
 
7
def generate(prompt):
    """Run the seq2seq model on *prompt* and return the decoded text.

    Tokenizes the input, generates up to 150 new tokens, and returns
    the first (only) decoded sequence with special tokens stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(encoded["input_ids"], max_new_tokens=150)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
 
1
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr  # presumably consumed by UI code outside this chunk — verify
import spaces  # provides the spaces.GPU decorator used on generate() below
import torch

# Load the seq2seq prompt-generator checkpoint once at startup.
# from_tf=True: the hub repo ships TensorFlow weights, which are
# converted to PyTorch when loaded.
tokenizer = AutoTokenizer.from_pretrained("merve/chatgpt-prompt-generator-v12")
model = AutoModelForSeq2SeqLM.from_pretrained("merve/chatgpt-prompt-generator-v12", from_tf=True)

# Run on the first GPU when one is visible, otherwise fall back to CPU;
# move the model there once so per-request calls only move the inputs.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
12
 
13
@spaces.GPU
def generate(prompt):
    """Generate text from *prompt* with the module-level seq2seq model.

    Args:
        prompt: input text to feed to the tokenizer.

    Returns:
        The first decoded sequence (special tokens stripped), generating
        at most 150 new tokens.
    """
    # Bind the result of .to(): relying on BatchEncoding.to mutating in
    # place (as the original `batch.to(device)` did) is fragile.
    batch = tokenizer(prompt, return_tensors="pt").to(device)
    # Pass the full encoding (input_ids AND attention_mask) instead of
    # input_ids alone, so generation does not ignore the attention mask
    # and mis-handle padded input.
    generated_ids = model.generate(**batch, max_new_tokens=150)
    output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return output[0]