vilarin committed on
Commit 652ef04
1 Parent(s): d83af28

Update app.py

Files changed (1)
  1. app.py +9 -9
app.py CHANGED
@@ -36,15 +36,16 @@ h3 {
     text-align: center;
 }
 """
-
-model = AutoModelForCausalLM.from_pretrained(
-    MODELS,
-    device_map="auto",
-    quantization_config=BitsAndBytesConfig(load_in_4bit=True)
+if torch.cuda.is_available():
+    model = AutoModelForCausalLM.from_pretrained(
+        MODELS,
+        device_map="auto",
+        quantization_config=BitsAndBytesConfig(load_in_4bit=True)
 )
-tokenizer = GemmaTokenizerFast.from_pretrained(MODELS)
-model.config.sliding_window = 4096
-model.eval()
+    tokenizer = GemmaTokenizerFast.from_pretrained(MODELS)
+    model.config.sliding_window = 4096
+    model.eval()
+

 @spaces.GPU(duration=90)
 def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
@@ -75,7 +76,6 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
         do_sample=True,
         temperature=temperature,
         num_beams=1,
-        repetition_penalty=repetition_penalty,
     )

     thread = Thread(target=model.generate, kwargs=generate_kwargs)
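
The first hunk wraps model loading in a torch.cuda.is_available() check, presumably so that importing app.py on a CPU-only host (for example while the Space is still building, before the ZeroGPU hardware is attached) no longer fails inside the 4-bit bitsandbytes load. A minimal self-contained sketch of the resulting pattern, assuming MODELS holds a Hugging Face model id (the example value below is a guess, not taken from the commit):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizerFast

MODELS = "google/gemma-2-9b-it"  # hypothetical value; app.py defines the real one

if torch.cuda.is_available():
    # Load the 4-bit quantized model only when a GPU is present, so that
    # merely importing this module on a CPU-only machine does not crash.
    model = AutoModelForCausalLM.from_pretrained(
        MODELS,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    tokenizer = GemmaTokenizerFast.from_pretrained(MODELS)
    model.config.sliding_window = 4096  # cap attention to a 4096-token window
    model.eval()                        # inference mode; disables dropout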
 
36
  text-align: center;
37
  }
38
  """
39
+ if torch.cuda.is_available():
40
+ model = AutoModelForCausalLM.from_pretrained(
41
+ MODELS,
42
+ device_map="auto",
43
+ quantization_config=BitsAndBytesConfig(load_in_4bit=True)
44
  )
45
+ tokenizer = GemmaTokenizerFast.from_pretrained(MODELS)
46
+ model.config.sliding_window = 4096
47
+ model.eval()
48
+
49
 
50
  @spaces.GPU(duration=90)
51
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
 
76
  do_sample=True,
77
  temperature=temperature,
78
  num_beams=1,
 
79
  )
80
 
81
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
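
The second hunk drops repetition_penalty=repetition_penalty from generate_kwargs. No repetition_penalty name appears to be defined inside stream_chat (the signature's parameter is called penalty), so the old line would likely have raised a NameError once generation ran; deleting the kwarg sidesteps that. An alternative fix, sketched here as a hypothetical helper rather than what the commit does, would be to wire the existing penalty argument through instead (input_ids is assumed to be built earlier in stream_chat):

def build_generate_kwargs(input_ids, max_new_tokens, temperature, top_p, top_k, penalty):
    # Sketch only: mirrors the visible context of generate_kwargs, but keeps a
    # repetition penalty by using the parameter that actually exists in scope.
    return dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=1,
        repetition_penalty=penalty,  # defined parameter, not the undefined name
    )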