Carlosito16 committed
Commit e106a6d
1 Parent(s): 00662e9

add quantization config

Files changed (1): app.py +19 -11
app.py CHANGED
@@ -12,6 +12,7 @@ import csv
 import json
 import torch
 from tqdm.auto import tqdm
+from transformers import BitsAndBytesConfig
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
@@ -33,6 +34,8 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
 
 
 
+
+
 prompt_template = """
 
 You are the chatbot and the face of Asian Institute of Technology (AIT). Your job is to give answers to prospective and current students about the school.
@@ -59,7 +62,10 @@ st.set_page_config(
     page_title = 'aitGPT',
     page_icon = '✅')
 
-
+bitsandbyte_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16)
 
 
 @st.cache_data
@@ -91,19 +97,21 @@ def load_faiss_index():
 
 @st.cache_resource
 def load_llm_model():
-    # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
-    #                                         task= 'text2text-generation',
-    #                                         model_kwargs={ "device_map": "auto",
-    #                                                       "load_in_8bit": True, "max_length": 256, "temperature": 0,
-    #                                                       "repetition_penalty": 1.5})
+    # this one is for running with GPU
+    llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
+                                            task= 'text2text-generation',
+                                            model_kwargs={ "device_map": "auto",
+                                                          "max_length": 256, "temperature": 0,
+                                                          "repetition_penalty": 1.5,
+                                                          "quantization_config": bitsandbyte_config})  # add this quantization config
 
 
-    llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
-                                            task= 'text2text-generation',
+    # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
+    #                                         task= 'text2text-generation',
 
-                                            model_kwargs={ "max_length": 256, "temperature": 0,
-                                                          "torch_dtype":torch.float32,
-                                                          "repetition_penalty": 1.3})
+    #                                         model_kwargs={ "max_length": 256, "temperature": 0,
+    #                                                       "torch_dtype":torch.float32,
+    #                                                       "repetition_penalty": 1.3})
 
     return llm
 
 
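In effect, the commit switches load_llm_model() from a full-precision fp32 load to a 4-bit NF4 quantized load via bitsandbytes, which roughly quarters the memory needed for the 3B-parameter FastChat-T5 weights. Since HuggingFacePipeline.from_model_id forwards model_kwargs to the underlying from_pretrained call, a minimal standalone sketch of the same quantized load might look like the following. This is an illustration, not part of the commit: it assumes bitsandbytes and accelerate are installed and a CUDA GPU is available, and the sample prompt is invented.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig

# Same settings as the commit: weights stored in 4-bit NF4, compute in fp16.
bitsandbyte_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16)

# FastChat-T5 is a T5-based seq2seq model, hence AutoModelForSeq2SeqLM.
tokenizer = AutoTokenizer.from_pretrained('lmsys/fastchat-t5-3b-v1.0')
model = AutoModelForSeq2SeqLM.from_pretrained(
    'lmsys/fastchat-t5-3b-v1.0',
    device_map="auto",                         # place layers on the GPU
    quantization_config=bitsandbyte_config)    # quantize at load time

# Invented sample prompt, just to exercise the quantized model.
inputs = tokenizer("What master's programs does AIT offer?",
                   return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_length=256, repetition_penalty=1.5)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Worth noting: the earlier commented-out attempt passed a bare "load_in_8bit": True in model_kwargs; packaging the settings in a BitsAndBytesConfig passed as quantization_config is the approach recent transformers releases document for bitsandbytes quantization.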