john commited on
Commit
bd8afe1
1 Parent(s): 297e42d
Files changed (2) hide show
  1. app.py +40 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
+ from peft import PeftModel
4
+ import gradio as gr
5
+
6
+ base_model_id='mistralai/Mistral-7B-Instruct-v0.1'
7
+ PEFT_MODEL = 'johnstrenio/mistral_ski'
8
+
9
+ bnb_config = BitsAndBytesConfig(
10
+ load_in_4bit=True, # load model in 4-bit precision
11
+ bnb_4bit_quant_type="nf4", # pre-trained model should be quantized in 4-bit NF format
12
+ bnb_4bit_use_double_quant=True, # Using double quantization as mentioned in QLoRA paper
13
+ bnb_4bit_compute_dtype=torch.bfloat16, # During computation, pre-trained model should be loaded in BF16 format
14
+ )
15
+
16
+ base_model = AutoModelForCausalLM.from_pretrained(
17
+ base_model_id, # Mistral, same as before
18
+ quantization_config=bnb_config, # Same quantization config as before
19
+ device_map="auto",
20
+ trust_remote_code=True,
21
+ )
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
24
+ tokenizer.pad_token = tokenizer.eos_token
25
+ model = PeftModel.from_pretrained(base_model, PEFT_MODEL)
26
+ model.eval()
27
+
28
+ def predict(text):
29
+ prompt = "[INST] " + text + " [/INST]"
30
+ model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
31
+ resp = tokenizer.decode(model.generate(**model_input, max_new_tokens=200, pad_token_id=2, repetition_penalty=1.15)[0], skip_special_tokens=True)
32
+ return resp
33
+
34
+ demo = gr.Interface(
35
+ fn=predict,
36
+ inputs='text',
37
+ outputs='text',
38
+ )
39
+
40
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.23.0
2
+ backcall==0.2.0
3
+ bitsandbytes==0.41.1
4
+ datasets==2.14.6
5
+ gradio==4.4.1
6
+ gradio_client==0.7.0
7
+ peft @ git+https://github.com/huggingface/peft.git@56556faa17263be8ef1802c172141705b71c28dc
8
+ torch==2.1.0
9
+ transformers @ git+https://github.com/huggingface/transformers.git@f370bebdc352cd7c1bea2f88ae0c140ab694c5fd
10
+ trl==0.7.2