|
|
|
|
|
import ctypes |
|
|
|
|
|
class TokenWord(ctypes.Structure):
    """C-compatible (token id, word text) pair exchanged with libtpuchat.

    The field layout must match the native struct byte-for-byte:
    a 32-bit token id followed by a fixed 2048-byte character buffer
    holding the token's text.
    """

    _fields_ = [
        ("token", ctypes.c_int),         # token id from the tokenizer
        ("word", ctypes.c_char * 2048)   # fixed-size text buffer for the token
    ]
|
|
|
|
|
class TPUChatglm:
    """ctypes wrapper around the Baichuan2 TPU runtime (libtpuchat.so).

    Loads the shared library, declares the C signatures of every entry
    point used, creates a native model handle, and exposes high-level
    `predict` / `stream_predict` helpers that accumulate the model's
    UTF-8 token strings.
    """

    def __init__(self, device_id=3,
                 bmodel_path="../model/baichuan2-7b-test_int8.bmodel",
                 token_path="../model/tokenizer.model"):
        """Load the shared library and build the native model object.

        Args:
            device_id: TPU device index to bind the model to
                (default 3, matching the original hard-coded value).
            bmodel_path: path to the compiled .bmodel file.
            token_path: path to the tokenizer model file.
        """
        # NOTE(review): the library path is relative to the current working
        # directory — confirm callers always launch from the expected dir.
        self.lib = ctypes.cdll.LoadLibrary('./build/libtpuchat.so')
        self.device_id = device_id
        self.bmodel_path = bmodel_path
        self.token_path = token_path
        self.libset()
        self.init()

    def libset(self):
        """Declare argtypes/restype for each C entry point we call.

        Must run before init(): without an explicit c_void_p restype,
        ctypes would treat the returned handle as a C int and could
        truncate a 64-bit pointer.
        """
        self.lib.Baichuan2_with_devid_and_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
        self.lib.Baichuan2_with_devid_and_model.restype = ctypes.c_void_p

        # Teardown entry points (declared but not yet called from Python).
        self.lib.Baichuan2_delete.argtypes = [ctypes.c_void_p]
        self.lib.Baichuan2_deinit.argtypes = [ctypes.c_void_p]

        self.lib.Baichuan2_predict_first_token.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        self.lib.Baichuan2_predict_first_token.restype = ctypes.c_char_p

        self.lib.Baichuan2_predict_next_token.argtypes = [ctypes.c_void_p]
        self.lib.Baichuan2_predict_next_token.restype = ctypes.c_char_p

        self.lib.get_eos.argtypes = [ctypes.c_void_p]
        self.lib.get_eos.restype = ctypes.c_int

        self.lib.get_history.argtypes = [ctypes.c_void_p]
        self.lib.get_history.restype = ctypes.c_char_p

        self.lib.set_history.argtypes = [ctypes.c_void_p, ctypes.c_char_p]

    def init(self):
        """Create the native model; stores the opaque handle in self.obj."""
        self.obj = self.lib.Baichuan2_with_devid_and_model(
            self.device_id,
            self.bmodel_path.encode('utf-8'),
            self.token_path.encode('utf-8'))

    def predict_first_token(self, context):
        """Feed *context* to the model and return the first generated token text."""
        return self.lib.Baichuan2_predict_first_token(self.obj, context.encode('utf-8')).decode('utf-8')

    def predict_next_token(self):
        """Return the next generated token text ('_GETMAX_'/'_GETEOS_' on stop)."""
        return self.lib.Baichuan2_predict_next_token(self.obj).decode('utf-8')

    def predict(self, context):
        """Run a full generation for *context* and return the complete text.

        Bug fix: the first token used to be computed and then discarded
        (res started at ''); it is now included, consistent with what
        stream_predict accumulates.
        """
        res = self.predict_first_token(context)
        while True:
            next_token = self.predict_next_token()
            # Sentinel strings emitted by the native side: max length
            # reached or end-of-sequence token produced.
            if next_token == '_GETMAX_' or next_token == '_GETEOS_':
                break
            res += next_token
        return res

    def stream_predict(self, query, history):
        """Generate a reply to *query*, yielding (partial_text, history) pairs.

        *history* is mutated in place: a (query, partial_answer) entry is
        appended up front and rewritten as each token arrives.
        """
        history.append((query, ''))

        # Baichuan2 chat-format markers: <reserved_106> precedes the user
        # turn, <reserved_107> starts the assistant turn.
        prompt = "<reserved_106>" + query + "<reserved_107>"

        res = ''
        res += self.predict_first_token(prompt)

        while True:
            next_token = self.predict_next_token()
            if next_token == '_GETMAX_' or next_token == '_GETEOS_':
                break
            res += next_token
            history[-1] = (query, res)
            yield res, history

    def get_config(self):
        """Placeholder; no configuration is exposed yet."""
        pass