Zwea Htet committed
Commit a550aaa · 1 parent: e594eb9

update llama custom

Files changed (1):
  1. models/llamaCustom.py  +19 -33
models/llamaCustom.py CHANGED
@@ -21,8 +21,8 @@ from llama_index import (
     load_index_from_storage,
 )
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
-# from utils.customLLM import CustomLLM
+from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 load_dotenv()
 # openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -104,36 +104,10 @@ class OurLLM(CustomLLM):
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
         raise NotImplementedError()
 
-    # def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-    #     prompt_length = len(prompt)
-    #     response = self.pipeline(prompt, max_new_tokens=525)[0]["generated_text"]
-
-    #     # only return newly generated tokens
-    #     return response[prompt_length:]
-
-    # @property
-    # def _identifying_params(self) -> Mapping[str, Any]:
-    #     return {"name_of_model": self.model_name}
-
-    # @property
-    # def _llm_type(self) -> str:
-    #     return "custom"
-
 class LlamaCustom:
-    # define llm
-    # llm_predictor = LLMPredictor(llm=OurLLM())
-    # service_context = ServiceContext.from_defaults(
-    #     llm_predictor=llm_predictor, prompt_helper=prompt_helper
-    # )
     def __init__(self, model_name: str) -> None:
-        pipe = load_model(mode_name=model_name)
-        llm = OurLLM(model_name=model_name, model_pipeline=pipe)
-        self.service_context = ServiceContext.from_defaults(
-            llm=llm, prompt_helper=prompt_helper
-        )
         self.vector_index = self.initialize_index(model_name=model_name)
 
-    @st.cache_resource
     def initialize_index(_self, model_name: str):
         index_name = model_name.split("/")[-1]
 
@@ -151,11 +125,26 @@ class LlamaCustom:
             #     index = pickle.loads(file.readlines())
             return index
         else:
+            prompt_helper = PromptHelper(
+                context_window=CONTEXT_WINDOW,
+                num_output=NUM_OUTPUT,
+                chunk_overlap_ratio=CHUNK_OVERLAP_RATION,
+            )
+
+            # define llm
+            pipe = load_model(mode_name=model_name)
+            llm = OurLLM(model_name=model_name, model_pipeline=pipe)
+
+            llm_predictor = LLMPredictor(llm=llm)
+            service_context = ServiceContext.from_defaults(
+                llm_predictor=llm_predictor, prompt_helper=prompt_helper
+            )
+
             # documents = prepare_data(r"./assets/regItems.json")
             documents = SimpleDirectoryReader(input_dir="./assets/pdf").load_data()
 
             index = GPTVectorStoreIndex.from_documents(
-                documents, service_context=self.service_context
+                documents, service_context=service_context
             )
 
             # local write access
@@ -168,10 +157,7 @@ class LlamaCustom:
 
     def get_response(self, query_str):
         print("query_str: ", query_str)
-        # query_engine = self.vector_index.as_query_engine()
-        query_engine = self.vector_index.as_query_engine(
-            text_qa_template=text_qa_template, refine_template=refine_template
-        )
+        query_engine = self.vector_index.as_query_engine()
         response = query_engine.query(query_str)
         print("metadata: ", response.metadata)
         return str(response)
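The LangChain-style stubs deleted above (_call, _identifying_params, _llm_type) are superseded by the llama_index CustomLLM interface this commit imports. For orientation, a minimal sketch of how OurLLM plausibly fills that interface, reusing the deleted _call logic; the constant values and the metadata body are assumptions, since the real definitions sit outside this diff's context:

# Sketch only -- not part of this commit. Reconstructs the likely shape of
# OurLLM around the stream_complete stub shown in the diff above.
from typing import Any

from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata

CONTEXT_WINDOW = 2048  # assumed; the module constant is outside this diff
NUM_OUTPUT = 525       # assumed; matches max_new_tokens in the deleted _call


class OurLLM(CustomLLM):
    model_name: str
    model_pipeline: Any  # a transformers text-generation pipeline

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=CONTEXT_WINDOW,
            num_output=NUM_OUTPUT,
            model_name=self.model_name,
        )

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Same idea as the deleted _call: generate, then return only the
        # tokens produced after the prompt.
        text = self.model_pipeline(prompt, max_new_tokens=NUM_OUTPUT)[0]["generated_text"]
        return CompletionResponse(text=text[len(prompt):])

    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        raise NotImplementedError()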
 
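With the PromptHelper, LLMPredictor, and ServiceContext construction moved inside initialize_index, building the whole pipeline now takes a single constructor call. A usage sketch; the import path assumes this repo's models/llamaCustom.py layout, the model name is a stand-in, and the st.cache_resource wrapper is one way to compensate for the decorator this commit removed from initialize_index:

# Usage sketch -- assumptions flagged in the lead-in above.
import streamlit as st

from models.llamaCustom import LlamaCustom


@st.cache_resource
def load_custom_model(model_name: str) -> LlamaCustom:
    # One LlamaCustom per model name: builds the HF pipeline, the service
    # context, and the vector index (or loads the index from storage).
    return LlamaCustom(model_name=model_name)


model = load_custom_model("bigscience/bloom-560m")  # stand-in model name
print(model.get_response("What do the uploaded PDFs say about audits?"))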