Zwea Htet committed
Commit 5d94ab8 · Parent: 4d4ef0e

updated code

Files changed (2):
  1. models/llamaCustom.py +3 -0
  2. requirements.txt +1 -1
models/llamaCustom.py CHANGED
@@ -20,6 +20,7 @@ from llama_index import (
     StorageContext,
     load_index_from_storage,
 )
+from llama_index.llms.base import llm_completion_callback
 from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata
 from llama_index.prompts import Prompt
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
@@ -78,6 +79,7 @@ class OurLLM(CustomLLM):
             model_name=self.model_name,
         )
 
+    @llm_completion_callback()
     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
         prompt_length = len(prompt)
         response = self.pipeline(prompt, max_new_tokens=NUM_OUTPUT)[0]["generated_text"]
@@ -86,6 +88,7 @@ class OurLLM(CustomLLM):
         text = response[prompt_length:]
         return CompletionResponse(text=text)
 
+    @llm_completion_callback()
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
         raise NotImplementedError()
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-llama_index==0.7.16
+llama_index
 torch
 transformers
 panda
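
For context on the models/llamaCustom.py change above: llm_completion_callback() is the hook that newer llama_index releases use to emit LLM start/end events to the callback manager, so wrapping complete and stream_complete keeps tracing and token-counting callbacks working once the version pin on llama_index is dropped. Below is a minimal sketch of where the decorator attaches, assuming the legacy llama_index API this repo targets (CustomLLM, CompletionResponse, and LLMMetadata under llama_index.llms); EchoLLM and its toy echo logic are illustrative stand-ins, not the repo's OurLLM.

# Minimal sketch under the assumptions stated above; not the repo's OurLLM.
from typing import Any

from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata
from llama_index.llms.base import llm_completion_callback


class EchoLLM(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        # Advertise basic limits so llama_index knows how to size prompts.
        return LLMMetadata(context_window=2048, num_output=256, model_name="echo")

    @llm_completion_callback()  # surfaces LLM start/end callback events around the call
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Toy completion: echo the prompt back unchanged.
        return CompletionResponse(text=prompt)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        raise NotImplementedError()


# Usage: EchoLLM().complete("hello").text  -> "hello"

The decorator does not alter the returned text; it only reports the call to the callback manager, which is presumably why it could be added here without touching the generation logic in complete().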