hdnh2006
committed on
Commit
•
539ede7
1
Parent(s):
4f267c9
handler uses LlamaForCausalLM
Browse files- handler.py +5 -2
handler.py
CHANGED
@@ -13,7 +13,7 @@ for text generation, leveraging the capabilities of the Llama 2 model.
|
|
13 |
"""
|
14 |
|
15 |
import torch
|
16 |
-
from transformers import pipeline, BitsAndBytesConfig
|
17 |
from typing import Dict, List, Any
|
18 |
import logging
|
19 |
import sys
|
@@ -51,7 +51,10 @@ class EndpointHandler:
|
|
51 |
bnb_4bit_compute_dtype=torch.bfloat16
|
52 |
)
|
53 |
|
54 |
-
|
|
|
|
|
|
|
55 |
|
56 |
|
57 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
13 |
"""
|
14 |
|
15 |
import torch
|
16 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline, BitsAndBytesConfig
|
17 |
from typing import Dict, List, Any
|
18 |
import logging
|
19 |
import sys
|
|
|
51 |
bnb_4bit_compute_dtype=torch.bfloat16
|
52 |
)
|
53 |
|
54 |
+
tokenizer = LlamaTokenizer.from_pretrained(path)
|
55 |
+
model = LlamaForCausalLM.from_pretrained(path, device_map=0, quantization_config=self.bnb_config)
|
56 |
+
|
57 |
+
self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
58 |
|
59 |
|
60 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|