import logging
import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from ..presets import *
from .base_model import BaseLLMModel

# Module-level cache so the tokenizer and model weights are loaded only once
# per process and shared across client instances.
GEMMA_TOKENIZER = None
GEMMA_MODEL = None


class GoogleGemmaClient(BaseLLMModel):
    def __init__(self, model_name, api_key, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)

        global GEMMA_TOKENIZER, GEMMA_MODEL

        self.default_max_generation_token = self.token_upper_limit
        self.max_generation_token = self.token_upper_limit

        # Load the tokenizer and model lazily, preferring a local checkpoint
        # under "models/" and falling back to the Hugging Face Hub repo id.
        if GEMMA_TOKENIZER is None or GEMMA_MODEL is None:
            model_path = None
            if os.path.exists("models"):
                model_dirs = os.listdir("models")
                if model_name in model_dirs:
                    model_path = f"models/{model_name}"
            if model_path is not None:
                model_source = model_path
            elif os.path.exists(
                os.path.join("models", MODEL_METADATA[model_name]["model_name"])
            ):
                model_source = os.path.join(
                    "models", MODEL_METADATA[model_name]["model_name"]
                )
            else:
                try:
                    model_source = MODEL_METADATA[model_name]["repo_id"]
                except KeyError:
                    model_source = model_name
            dtype = torch.bfloat16
            # Gated Gemma checkpoints require an authenticated download, so
            # HF_AUTH_TOKEN must be set in the environment.
            GEMMA_TOKENIZER = AutoTokenizer.from_pretrained(
                model_source, use_auth_token=os.environ["HF_AUTH_TOKEN"]
            )
            GEMMA_MODEL = AutoModelForCausalLM.from_pretrained(
                model_source,
                device_map="auto",
                torch_dtype=dtype,
                trust_remote_code=True,
                resume_download=True,
                use_auth_token=os.environ["HF_AUTH_TOKEN"],
            )

    def deinitialize(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        GEMMA_TOKENIZER = None
        GEMMA_MODEL = None
        self.clear_cuda_cache()
        logging.info("GEMMA deinitialized")

    def _get_gemma_style_input(self):
        global GEMMA_TOKENIZER

        # Render the chat history with the tokenizer's bundled chat template,
        # then tokenize the resulting prompt into input ids for generation.
        messages = self.history
        prompt = GEMMA_TOKENIZER.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = GEMMA_TOKENIZER.encode(
            prompt, add_special_tokens=True, return_tensors="pt"
        )
        return inputs
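
    # For reference: the Gemma instruction-tuned chat template renders each
    # message roughly as "<start_of_turn>user\n...<end_of_turn>\n" and, with
    # add_generation_prompt=True, ends the prompt with "<start_of_turn>model\n",
    # which is the marker get_answer_at_once() splits on below. The exact
    # formatting depends on the chat template shipped with the tokenizer.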

    def get_answer_at_once(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        inputs = self._get_gemma_style_input()
        outputs = GEMMA_MODEL.generate(
            input_ids=inputs.to(GEMMA_MODEL.device),
            max_new_tokens=self.max_generation_token,
        )
        generated_token_count = outputs.shape[1] - inputs.shape[1]
        outputs = GEMMA_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        # Keep the text after the last model-turn marker, trimming the trailing
        # turn-delimiter characters.
        outputs = outputs.split("<start_of_turn>model\n")[-1][:-5]
        self.clear_cuda_cache()
        return outputs, generated_token_count

    def get_answer_stream_iter(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        inputs = self._get_gemma_style_input()
        # Run generation in a background thread and stream decoded text
        # through a TextIteratorStreamer.
        streamer = TextIteratorStreamer(
            GEMMA_TOKENIZER, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        input_kwargs = dict(
            input_ids=inputs.to(GEMMA_MODEL.device),
            max_new_tokens=self.max_generation_token,
            streamer=streamer,
        )
        t = Thread(target=GEMMA_MODEL.generate, kwargs=input_kwargs)
        t.start()

        # Yield the growing partial answer so the UI can update incrementally.
        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text
        self.clear_cuda_cache()
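
# Usage sketch (hypothetical, for illustration only): in practice this client is
# constructed by the application's model registry, but it can be exercised
# directly when a Gemma checkpoint is reachable and HF_AUTH_TOKEN is set in the
# environment. The checkpoint name below is an assumption, not a guarantee that
# it exists in MODEL_METADATA.
#
#     client = GoogleGemmaClient("gemma-2b-it", api_key="", user_name="demo")
#     client.history = [{"role": "user", "content": "Hello!"}]
#     answer, n_new_tokens = client.get_answer_at_once()
#     for partial in client.get_answer_stream_iter():
#         print(partial)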