gufett0 committed
Commit b277c0d · 1 Parent(s): 40986a4

changed class interface with iterator

Files changed (2):
  1. backend.py +3 -3
  2. interface.py +11 -6
backend.py CHANGED
@@ -20,7 +20,7 @@ login(huggingface_token)
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-model_id = "google/gemma-2-2b-it"
+"""model_id = "google/gemma-2-2b-it"
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
@@ -28,12 +28,12 @@ model = AutoModelForCausalLM.from_pretrained(
     token=True)
 
 model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
-model.eval()
+model.eval()"""
 
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
 
-Settings.llm = GemmaLLMInterface(model=model)
+Settings.llm = GemmaLLMInterface()
 #Settings.llm = GemmaLLMInterface(model_name=model_id)
 
 ############################---------------------------------
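With the inline model loading commented out, backend.py now defers model construction to GemmaLLMInterface and only wires up LlamaIndex's Settings. A minimal sketch of how the resulting setup might be driven end to end; the document/index calls (SimpleDirectoryReader, VectorStoreIndex, the "data" directory) are illustrative assumptions about the surrounding app, not part of this commit:

# Minimal sketch, assuming backend.py has already set Settings.embed_model
# and Settings.llm as in the diff above; the index and query engine then
# pick up those globals implicitly.
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex

documents = SimpleDirectoryReader("data").load_data()   # hypothetical corpus dir
index = VectorStoreIndex.from_documents(documents)      # embeds via Settings.embed_model

query_engine = index.as_query_engine()                  # generates via Settings.llm
print(query_engine.query("What do these documents cover?"))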
interface.py CHANGED
@@ -9,11 +9,17 @@ from pydantic import Field, field_validator
 
 # for transformers 2
 class GemmaLLMInterface(CustomLLM):
-    model: Any = None
-    tokenizer: Any = None
-    context_window: int = 8192
-    num_output: int = 2048
-    model_name: str = "gemma-2b-it"
+    def __init__(self, model_name: str = "google/gemma-2b-it", **kwargs):
+        super().__init__(**kwargs)
+        self.model_name = model_name
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map="auto",
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.context_window = 8192
+        self.num_output = 2048
 
     def _format_prompt(self, message: str) -> str:
         return (
@@ -23,7 +29,6 @@ class GemmaLLMInterface(CustomLLM):
 
     @property
     def metadata(self) -> LLMMetadata:
-        """Get LLM metadata."""
         return LLMMetadata(
             context_window=self.context_window,
             num_output=self.num_output,
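Since the model and tokenizer are now loaded inside __init__ rather than injected from backend.py, a bare GemmaLLMInterface() is self-contained. A minimal usage sketch, assuming the complete()/stream_complete() methods that CustomLLM requires are defined further down in interface.py, outside the hunks shown here:

# Minimal usage sketch; complete() is assumed to exist elsewhere in
# interface.py, since only __init__ and metadata appear in this diff.
from interface import GemmaLLMInterface

llm = GemmaLLMInterface()             # downloads/loads google/gemma-2b-it
print(llm.metadata.context_window)    # 8192, per the new defaults
print(llm.complete("What does a tokenizer do?").text)

One caveat: CustomLLM is a Pydantic model in LlamaIndex, so assigning attributes like self.model in __init__ after dropping the field declarations may raise a validation error unless the fields are kept (or moved to PrivateAttr).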