hysts HF staff committed on
Commit
1b944a6
1 Parent(s): eb63bbc
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterator
5
  import gradio as gr
6
  import spaces
7
  import torch
8
- from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
9
 
10
  DESCRIPTION = """\
11
  # Gemma 2 2B IT
@@ -24,7 +24,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
24
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
 
26
  model_id = "google/gemma-2-2b-it"
27
- tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
28
  model = AutoModelForCausalLM.from_pretrained(
29
  model_id,
30
  device_map="auto",
@@ -34,7 +34,7 @@ model.config.sliding_window = 4096
34
  model.eval()
35
 
36
 
37
- @spaces.GPU(duration=90)
38
  def generate(
39
  message: str,
40
  chat_history: list[dict],
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  DESCRIPTION = """\
11
  # Gemma 2 2B IT
 
24
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
 
26
  model_id = "google/gemma-2-2b-it"
27
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
28
  model = AutoModelForCausalLM.from_pretrained(
29
  model_id,
30
  device_map="auto",
 
34
  model.eval()
35
 
36
 
37
+ @spaces.GPU
38
  def generate(
39
  message: str,
40
  chat_history: list[dict],