ksh-nyp committed
Commit 9e4e68e
1 parent: 9005c68

Update app.py

Files changed (1)
app.py +35 -1
app.py CHANGED
@@ -42,11 +42,45 @@ packing = False
 # Load the entire model on the GPU 0
 device_map = {"": 0}
 
-from transformers import pipeline
+# Load tokenizer and model with QLoRA configuration
+compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=use_4bit,
+    bnb_4bit_quant_type=bnb_4bit_quant_type,
+    bnb_4bit_compute_dtype=compute_dtype,
+    bnb_4bit_use_double_quant=use_nested_quant,
+)
+
+# Check GPU compatibility with bfloat16
+if compute_dtype == torch.float16 and use_4bit:
+    major, _ = torch.cuda.get_device_capability()
+    if major >= 8:
+        print("=" * 80)
+        print("Your GPU supports bfloat16: accelerate training with bf16=True")
+        print("=" * 80)
 
 # Initialize the pipeline with the LLaMA model
 model_name = "ksh-nyp/llama-2-7b-chat-TCMKB2"
 pipe = pipeline("text-generation", model=model_name)
+
+# Load base model
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    quantization_config=bnb_config,
+    device_map=device_map
+)
+model.config.use_cache = False
+model.config.pretraining_tp = 1
+
+# Load LLaMA tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training
+
+from transformers import pipeline
+
+
 
 def generate_text(prompt):
     # Generate text based on the input prompt
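
Note that the hunk references names defined earlier in app.py that never appear in the diff context: torch, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, and the QLoRA settings use_4bit, bnb_4bit_compute_dtype, bnb_4bit_quant_type, and use_nested_quant. Note also that in the new file the "from transformers import pipeline" line lands after pipe is already created. A minimal sketch of what that assumed preamble presumably looks like, with all imports first and placeholder values typical for 4-bit NF4 loading (the actual values in app.py may differ):

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# Assumed QLoRA settings (placeholders; the real values live earlier in app.py)
use_4bit = True                     # load the base model in 4-bit precision
bnb_4bit_compute_dtype = "float16"  # compute dtype used for 4-bit matmuls
bnb_4bit_quant_type = "nf4"         # quantization type ("fp4" or "nf4")
use_nested_quant = False            # whether to apply double quantization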
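
The diff truncates generate_text after its opening comment, so the function body is not shown here. For reference, a hypothetical call against the pipe object created above would look like the following; the prompt text and max_new_tokens value are illustrative, not part of the commit:

# Hypothetical usage of the text-generation pipeline (not part of the commit)
prompt = "What is Traditional Chinese Medicine?"
outputs = pipe(prompt, max_new_tokens=200)
print(outputs[0]["generated_text"])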