Carlosito16 committed on
Commit
ce86648
1 Parent(s): 8b5df99

call t5 model class before assigning to LLM pipeline

Browse files
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -13,6 +13,7 @@ import json
13
  import torch
14
  from tqdm.auto import tqdm
15
  from transformers import BitsAndBytesConfig
 
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
 
18
 
@@ -97,14 +98,27 @@ def load_faiss_index():
97
 
98
  @st.cache_resource
99
  def load_llm_model():
100
- #this one is for running with GPT
101
- llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
102
- task= 'text2text-generation',
103
- model_kwargs={
104
- # "device_map": "auto",
105
- "max_length": 256, "temperature": 0,
106
- "repetition_penalty": 1.5,
107
- "quantization_config": bitsandbyte_config}) #add this quantization config
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
 
13
  import torch
14
  from tqdm.auto import tqdm
15
  from transformers import BitsAndBytesConfig
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, T5Tokenizer, AutoModel, T5ForConditionalGeneration
17
  from langchain.text_splitter import RecursiveCharacterTextSplitter
18
 
19
 
 
98
 
99
  @st.cache_resource
100
  def load_llm_model():
101
+ #this one is for running with GPU
102
+
103
+ model = T5ForConditionalGeneration.from_pretrained(model_id='lmsys/fastchat-t5-3b-v1.0',
104
+ quantization_config = bitsandbyte_config,
105
+ device_map= 'auto')
106
+ tokenizer = AutoTokenizer.from_pretrained(core_model_name)
107
+
108
+
109
+ pipe = pipeline(
110
+ task= 'text2text-generation', model=model, tokenizer=tokenizer, max_new_tokens=256, model_kwargs={"temperature":0,
111
+ "repetition_penalty": 1.5}
112
+ )
113
+ llm = HuggingFacePipeline(pipeline=pipe)
114
+
115
+ # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
116
+ # task= 'text2text-generation',
117
+ # model_kwargs={
118
+ # # "device_map": "auto",
119
+ # "max_length": 256, "temperature": 0,
120
+ # "repetition_penalty": 1.5,
121
+ # "quantization_config": bitsandbyte_config}) #add this quantization config
122
 
123
 
124
  # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',