changed class interface with iterator
Files changed:
- backend.py (+3 -3)
- interface.py (+11 -6)
backend.py CHANGED

@@ -20,7 +20,7 @@ login(huggingface_token)
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-model_id = "google/gemma-2-2b-it"
+"""model_id = "google/gemma-2-2b-it"
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
@@ -28,12 +28,12 @@ model = AutoModelForCausalLM.from_pretrained(
     token=True)
 
 model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
-model.eval()
+model.eval()"""
 
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
 
-Settings.llm = GemmaLLMInterface(
+Settings.llm = GemmaLLMInterface()
 #Settings.llm = GemmaLLMInterface(model_name=model_id)
 
 ############################---------------------------------
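The net effect in backend.py: the module-level Gemma load is commented out and Settings.llm is now constructed with no arguments, so the model is loaded inside GemmaLLMInterface.__init__ instead. For context, a minimal sketch of how LlamaIndex picks these globals up downstream; the data path and query string are hypothetical, not taken from this repo:

from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex

# Any index or query engine built without an explicit llm/embed_model
# falls back to the globals set in backend.py above.
documents = SimpleDirectoryReader("./data").load_data()  # "./data" is an assumed path
index = VectorStoreIndex.from_documents(documents)       # embeds via Settings.embed_model
query_engine = index.as_query_engine()                   # generates via Settings.llm
print(query_engine.query("What is this document about?"))  # hypothetical query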
interface.py CHANGED

@@ -9,11 +9,17 @@ from pydantic import Field, field_validator
 
 # for transformers 2
 class GemmaLLMInterface(CustomLLM):
-
-
-
-
-
+    def __init__(self, model_name: str = "google/gemma-2b-it", **kwargs):
+        super().__init__(**kwargs)
+        self.model_name = model_name
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map="auto",
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.context_window = 8192
+        self.num_output = 2048
 
     def _format_prompt(self, message: str) -> str:
         return (
@@ -23,7 +29,6 @@ class GemmaLLMInterface(CustomLLM):
 
     @property
     def metadata(self) -> LLMMetadata:
-        """Get LLM metadata."""
         return LLMMetadata(
             context_window=self.context_window,
             num_output=self.num_output,
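Two things are worth noting about the new __init__. First, CustomLLM is a Pydantic model, so attributes assigned in __init__ (model, tokenizer, context_window, num_output) generally need matching field declarations to be accepted. Second, the "iterator" in the commit title refers to the streaming half of the CustomLLM contract: stream_complete must return a generator of CompletionResponse chunks. A hedged sketch of how the rest of the class plausibly fits together; the field declarations, generation parameters, and chunking strategy below are assumptions, not taken from this diff:

from typing import Any
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback

class GemmaLLMInterface(CustomLLM):
    # Assumed field declarations: CustomLLM is a Pydantic model, so the
    # attributes assigned in the __init__ from the diff must be declared.
    # model/tokenizer are populated by that __init__ at construction time.
    model: Any = None
    tokenizer: Any = None
    model_name: str = "google/gemma-2b-it"
    context_window: int = 8192
    num_output: int = 2048

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Generation parameters here are illustrative, not from the diff.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output = self.model.generate(**inputs, max_new_tokens=self.num_output)
        new_tokens = output[0][inputs["input_ids"].shape[1]:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # The iterator half of the contract: yield incremental chunks.
        # Naive word-level chunking of the full completion, for illustration only.
        text = self.complete(prompt, **kwargs).text
        running = ""
        for word in text.split(" "):
            running += word + " "
            yield CompletionResponse(text=running, delta=word + " ")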