# chuangpt/modules/models/GoogleGemma.py
import logging
import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from ..presets import *
from .base_model import BaseLLMModel

# Module-level cache so the weights are loaded once and shared across clients.
GEMMA_TOKENIZER = None
GEMMA_MODEL = None

class GoogleGemmaClient(BaseLLMModel):
    def __init__(self, model_name, api_key, user_name="") -> None:
        # api_key is unused (Gemma runs locally) but kept for interface parity.
        super().__init__(model_name=model_name, user=user_name)

        global GEMMA_TOKENIZER, GEMMA_MODEL
        # self.deinitialize()
        self.default_max_generation_token = self.token_upper_limit
        self.max_generation_token = self.token_upper_limit

        if GEMMA_TOKENIZER is None or GEMMA_MODEL is None:
            # Resolve the model source: prefer a local copy under models/,
            # then the repo id from MODEL_METADATA, then the raw model name.
            model_path = None
            if os.path.exists("models"):
                model_dirs = os.listdir("models")
                if model_name in model_dirs:
                    model_path = f"models/{model_name}"
            if model_path is not None:
                model_source = model_path
            else:
                if os.path.exists(
                    os.path.join("models", MODEL_METADATA[model_name]["model_name"])
                ):
                    model_source = os.path.join(
                        "models", MODEL_METADATA[model_name]["model_name"]
                    )
                else:
                    try:
                        model_source = MODEL_METADATA[model_name]["repo_id"]
                    except KeyError:
                        model_source = model_name
            dtype = torch.bfloat16
            GEMMA_TOKENIZER = AutoTokenizer.from_pretrained(
                model_source, use_auth_token=os.environ["HF_AUTH_TOKEN"]
            )
            GEMMA_MODEL = AutoModelForCausalLM.from_pretrained(
                model_source,
                device_map="auto",
                torch_dtype=dtype,
                trust_remote_code=True,
                resume_download=True,
                use_auth_token=os.environ["HF_AUTH_TOKEN"],
            )

    def deinitialize(self):
        # Drop the cached tokenizer/model so the weights can be freed.
        global GEMMA_TOKENIZER, GEMMA_MODEL
        GEMMA_TOKENIZER = None
        GEMMA_MODEL = None
        self.clear_cuda_cache()
        logging.info("GEMMA deinitialized")

    def _get_gemma_style_input(self):
        global GEMMA_TOKENIZER
        # Gemma's chat template has no system role, so the system prompt is
        # dropped and only the chat history is formatted.
        # messages = [{"role": "system", "content": self.system_prompt}, *self.history]
        messages = self.history
        prompt = GEMMA_TOKENIZER.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The chat template already prepends <bos>, so skip special tokens
        # here to avoid a doubled BOS token.
        inputs = GEMMA_TOKENIZER.encode(
            prompt, add_special_tokens=False, return_tensors="pt"
        )
        return inputs
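
    # Note: with Gemma's chat template, a one-turn history such as
    # [{"role": "user", "content": "Hi"}] renders roughly as
    #
    #     <bos><start_of_turn>user
    #     Hi<end_of_turn>
    #     <start_of_turn>model
    #
    # which is why get_answer_at_once() below recovers the reply by splitting
    # on "<start_of_turn>model\n".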

    def get_answer_at_once(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        inputs = self._get_gemma_style_input()
        outputs = GEMMA_MODEL.generate(
            input_ids=inputs.to(GEMMA_MODEL.device),
            max_new_tokens=self.max_generation_token,
        )
        # generate() returns prompt + completion, so the number of new tokens
        # is the difference in sequence length.
        generated_token_count = outputs.shape[1] - inputs.shape[1]
        outputs = GEMMA_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        # Keep only the model's turn and trim the trailing turn-marker residue.
        outputs = outputs.split("<start_of_turn>model\n")[-1][:-5]
        self.clear_cuda_cache()
        return outputs, generated_token_count

    def get_answer_stream_iter(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        inputs = self._get_gemma_style_input()
        streamer = TextIteratorStreamer(
            GEMMA_TOKENIZER, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        input_kwargs = dict(
            input_ids=inputs.to(GEMMA_MODEL.device),
            max_new_tokens=self.max_generation_token,
            streamer=streamer,
        )
        # Run generation on a background thread; the streamer yields decoded
        # text chunks as they become available.
        t = Thread(target=GEMMA_MODEL.generate, kwargs=input_kwargs)
        t.start()
        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text
        self.clear_cuda_cache()
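
# Minimal usage sketch (not runnable standalone because of the relative
# imports above; the model name "gemma-2b-it" and the history format are
# assumptions based on how this class is used):
#
#     client = GoogleGemmaClient("gemma-2b-it", api_key="", user_name="demo")
#     client.history = [{"role": "user", "content": "Hello!"}]
#     for partial_reply in client.get_answer_stream_iter():
#         print(partial_reply)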