Medical_GPT / model.py
YenJung's picture
Update model.py
061abd2
import sys
from collections import namedtuple
import click
import torch
from peft import PeftModel
from transformers import (
AutoModel,
AutoTokenizer,
BloomForCausalLM,
BloomTokenizerFast,
GenerationConfig,
LlamaForCausalLM,
LlamaTokenizer,
)
from utils import generate_prompt
def decide_model(args, device_map):
ModelClass = namedtuple("ModelClass", ('tokenizer', 'model'))
_MODEL_CLASSES = {
"llama": ModelClass(**{
"tokenizer": LlamaTokenizer,
"model": LlamaForCausalLM,
}),
"chatglm": ModelClass(**{
"tokenizer": AutoTokenizer, #ChatGLMTokenizer,
"model": AutoModel, #ChatGLMForConditionalGeneration,
}),
"bloom": ModelClass(**{
"tokenizer": BloomTokenizerFast,
"model": BloomForCausalLM,
}),
"Auto": ModelClass(**{
"tokenizer": AutoTokenizer,
"model": AutoModel,
})
}
model_type = "Auto" if args.model_type not in ["llama", "bloom", "chatglm"] else args.model_type
if model_type == "chatglm":
tokenizer = _MODEL_CLASSES[model_type].tokenizer.from_pretrained(
args.base_model,
trust_remote_code=True
)
# todo: ChatGLMForConditionalGeneration revision
model = _MODEL_CLASSES[model_type].model.from_pretrained(
args.base_model,
trust_remote_code=True,
device_map=device_map
)
else:
tokenizer = _MODEL_CLASSES[model_type].tokenizer.from_pretrained(args.base_model)
model = _MODEL_CLASSES[model_type].model.from_pretrained(
args.base_model,
load_in_8bit=False,
torch_dtype=torch.float16,
device_map=device_map
)
if model_type == "llama":
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left" # Allow batched inference
if device_map == "auto":
model = PeftModel.from_pretrained(
model,
args.finetuned_weights,
torch_dtype=torch.float16,
)
else:
model = PeftModel.from_pretrained(
model,
args.finetuned_weights,
device_map=device_map
)
return tokenizer, model
class ModelServe:
def __init__(
self,
load_8bit: bool = True,
model_type: str = "llama",
base_model: str = "linhvu/decapoda-research-llama-7b-hf",
finetuned_weights: str = "llama-7b-hf_alpaca-en-zh",
):
args = locals()
namedtupler = namedtuple("args", tuple(list(args.keys())))
local_args = namedtupler(**args)
if torch.cuda.is_available():
self.device = "cuda:0"
self.device_map = "auto"
#self.max_memory = {i: "12GB" for i in range(torch.cuda.device_count())}
#self.max_memory.update({"cpu": "30GB"})
else:
self.device = "cpu"
self.device_map = {"": self.device}
self.tokenizer, self.model = decide_model(args=local_args, device_map=self.device_map)
# unwind broken decapoda-research config
self.model.config.pad_token_id = self.tokenizer.pad_token_id = 0 # unk
self.model.config.bos_token_id = 1
self.model.config.eos_token_id = 2
if not load_8bit:
self.model.half() # seems to fix bugs for some users.
self.model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
self.model = torch.compile(self.model)
def generate(
self,
instruction: str,
input: str,
temperature: float = 0.7,
top_p: float = 0.75,
top_k: int = 40,
num_beams: int = 4,
max_new_tokens: int = 1024,
**kwargs
):
prompt = generate_prompt(instruction, input)
print(f"Prompt: {prompt}")
inputs = self.tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(self.device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
**kwargs,
)
print("generating...")
with torch.no_grad():
generation_output = self.model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = self.tokenizer.decode(s)
print(f"Output: {output}")
return output.split("### 回覆:")[1].strip()