dvruette committed on
Commit
6ea89ff
1 Parent(s): cb15a9e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -1
main.py CHANGED
@@ -5,6 +5,7 @@ from threading import Thread
5
  import time
6
  import torch
7
  import gradio as gr
 
8
  from concept_guidance.chat_template import DEFAULT_CHAT_TEMPLATE
9
  from concept_guidance.patching import patch_model, load_weights
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer, Conversation
@@ -12,8 +13,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIter
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  # device = "cpu"
 
17
 
18
  # comment in/out the models you want to use
19
  # RAM requirements: ~16GB x #models (+ ~4GB overhead)
@@ -86,6 +88,7 @@ def add_user_prompt(user_message, history):
86
  history.append([user_message, None])
87
  return history
88
 
 
89
  @torch.no_grad()
90
  def generate_completion(
91
  history,
 
5
  import time
6
  import torch
7
  import gradio as gr
8
+ import spaces
9
  from concept_guidance.chat_template import DEFAULT_CHAT_TEMPLATE
10
  from concept_guidance.patching import patch_model, load_weights
11
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer, Conversation
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  # device = "cpu"
18
+ device = "cuda"
19
 
20
  # comment in/out the models you want to use
21
  # RAM requirements: ~16GB x #models (+ ~4GB overhead)
 
88
  history.append([user_message, None])
89
  return history
90
 
91
+ @spaces.GPU
92
  @torch.no_grad()
93
  def generate_completion(
94
  history,