BramVanroy committed on
Commit
568a9e0
1 Parent(s): b902422

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -6,7 +6,7 @@ from typing import Iterator
6
  import torch
7
 
8
  import gradio as gr
9
- import spaces
10
  import torch
11
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
 
@@ -16,11 +16,12 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
16
 
17
  model_id = "BramVanroy/GEITje-7B-ultra"
18
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
 
19
  tokenizer = AutoTokenizer.from_pretrained(model_id)
20
  tokenizer.pad_token_id = tokenizer.eos_token_id
21
 
22
 
23
- @spaces.GPU
24
  def generate(
25
  message: str,
26
  chat_history: list[tuple[str, str]],
 
6
  import torch
7
 
8
  import gradio as gr
9
+ # import spaces
10
  import torch
11
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
 
 
16
 
17
  model_id = "BramVanroy/GEITje-7B-ultra"
18
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
19
+ model = model.to("cuda" if torch.cuda.is_available() else "cpu")
20
  tokenizer = AutoTokenizer.from_pretrained(model_id)
21
  tokenizer.pad_token_id = tokenizer.eos_token_id
22
 
23
 
24
+ # @spaces.GPU
25
  def generate(
26
  message: str,
27
  chat_history: list[tuple[str, str]],