neuralworm committed on
Commit
35ae31f
·
1 Parent(s): 752fdc4

initial commit

Browse files
Files changed (1) hide show
  1. gen.py +21 -7
gen.py CHANGED
@@ -16,13 +16,27 @@ quantization_config = BitsAndBytesConfig(
16
  bnb_4bit_quant_type="nf4",
17
  )
18
 
19
- # Load the model with the quantization configuration
20
- model = AutoModelForCausalLM.from_pretrained(
21
- 'google/gemma-2-2b-it',
22
- device_map="auto",
23
- quantization_config=quantization_config,
24
- use_auth_token=hf_token
25
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  # Definir el prompt para generar un JSON con eventos anidados
 
16
  bnb_4bit_quant_type="nf4",
17
  )
18
 
19
+ # Check if a GPU is available
20
+ if torch.cuda.is_available():
21
+ # Load the model with 4-bit quantization (for GPU)
22
+ quantization_config = BitsAndBytesConfig(
23
+ load_in_4bit=True,
24
+ bnb_4bit_compute_dtype=torch.bfloat16,
25
+ bnb_4bit_quant_type="nf4",
26
+ )
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ 'google/gemma-2-2b-it',
29
+ device_map="auto",
30
+ quantization_config=quantization_config,
31
+ use_auth_token=hf_token
32
+ )
33
+ else:
34
+ # Load the model without quantization (for CPU)
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ 'google/gemma-2-2b-it',
37
+ device_map="auto",
38
+ use_auth_token=hf_token
39
+ )
40
 
41
 
42
  # Definir el prompt para generar un JSON con eventos anidados