Dimitre committed
Commit d3cd834 (1 parent: ac51a1c)

Adding 4-bit support

Files changed (1): app.py (+4, -1)
app.py CHANGED
@@ -4,7 +4,7 @@ import os
 import streamlit as st
 import torch
 from dotenv import load_dotenv
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

 from hangman import guess_letter
 from hf_utils import query_hint, query_word
@@ -34,6 +34,8 @@ def setup(model_id: str, device: str) -> None:
         model_id (str): Model ID used to load the tokenizer and model.
     """
     logger.info(f"Loading model and tokenizer from model: '{model_id}'")
+    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_id,
         token=os.environ["HF_ACCESS_TOKEN"],
@@ -42,6 +44,7 @@ def setup(model_id: str, device: str) -> None:
         model_id,
         torch_dtype=torch.float16,
         token=os.environ["HF_ACCESS_TOKEN"],
+        quantization_config=quantization_config,
     ).to(device)
     logger.info("Setup finished")
     return {"tokenizer": tokenizer, "model": model}