# MobiLlama / fastchat/serve/custom_inference.py
# Uploaded by Ashmal to Hugging Face via huggingface_hub
# ("Upload folder using huggingface_hub", verified commit 5472531).
"""Inference for FastChat models."""
import abc
import gc
import json
import math
import os
import sys
import time
from typing import Iterable, Optional, Dict
import warnings
import psutil
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LlamaTokenizer,
LlamaForCausalLM,
AutoModel,
AutoModelForSeq2SeqLM,
T5Tokenizer,
AutoConfig,
)
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from fastchat.conversation import get_conv_template, SeparatorStyle
from fastchat.model.model_adapter import (
load_model,
get_conversation_template,
get_generate_stream_function,
)
from fastchat.modules.awq import AWQConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Assemble the logits processors/warpers used for sampling.

    A stage is included only when its setting would actually change the
    logits. Stages are applied in this order: temperature, repetition
    penalty, top-p, top-k.

    Args:
        temperature: Softmax temperature. Values below 1e-5 are skipped
            (``generate_stream`` treats them as greedy decoding, and
            ``TemperatureLogitsWarper`` rejects 0.0). Exactly 1.0 is a
            no-op and is also skipped.
        repetition_penalty: Added only when strictly greater than 1.0.
        top_p: Nucleus threshold. Added only when in ``[1e-8, 1.0)``.
        top_k: Added only when positive (``-1`` disables it).

    Returns:
        A ``LogitsProcessorList``. It may be empty, in which case it is falsy.
    """
    stages = []
    if temperature >= 1e-5 and temperature != 1.0:
        stages.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        stages.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        stages.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        stages.append(TopKLogitsWarper(top_k))
    return LogitsProcessorList(stages)
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    """Stream text generated token by token for ``params["prompt"]``.

    Runs a manual prefill + incremental-decoding loop with a KV cache.
    Yields dicts of the form
    ``{"text", "logprobs", "usage", "finish_reason"}``:
      - every ``stream_interval`` steps, on the last step, and when a stop
        condition is hit (with ``finish_reason=None``). Output ending in a
        partial stop string is withheld.
      - one final event whose ``finish_reason`` is ``"stop"`` or ``"length"``.

    Args:
        model: Model exposing ``config.is_encoder_decoder``. Encoder-decoder
            models must also expose ``encoder``, ``decoder`` and ``lm_head``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        params: Request dict. ``"prompt"`` is required. Optional keys are
            ``"temperature"``, ``"repetition_penalty"``, ``"top_p"``,
            ``"top_k"``, ``"max_new_tokens"``, ``"logprobs"``, ``"echo"``,
            ``"stop"`` (str or iterable of str) and ``"stop_token_ids"``.
            ``params`` is not modified.
        device: Fallback device. It is replaced by ``model.device`` when the
            model has that attribute.
        context_len: Context window. For decoder-only models the prompt is
            left-truncated to leave room for ``max_new_tokens``.
        stream_interval: Yield a partial result every this many steps.
        judge_sent_end: If True, a stop token that would end an incomplete
            sentence is replaced by the second sampled candidate (or
            dropped), and decoding continues.
    """
    if hasattr(model, "device"):
        device = model.device
    # NOTE(review): after the override above, ``device`` may be a
    # torch.device rather than a str; confirm the "mps"/"xpu"/"npu" string
    # comparisons below still match in that case.

    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    # Sampling defaults are tuned for this deployment. The original
    # (upstream) defaults were temperature=1.0, top_p=1.0 and
    # max_new_tokens=256.
    temperature = float(params.get("temperature", 0.9))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 0.8))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 1000))
    logprobs = params.get("logprobs", None)  # FIXME: Support logprobs>1.
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    # Copy the list. Appending EOS in place would mutate the caller's object
    # (chat_loop passes conv.stop_token_ids straight through).
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is not None and eos_token_id not in stop_token_ids:
        stop_token_ids.append(eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate: leave room for the new tokens in the same window
        max_src_len = context_len - max_new_tokens - 1

    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    if model.config.is_encoder_decoder:
        if logprobs is not None:  # FIXME: Support logprobs for encoder-decoder models.
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # The first token has no logprobs.
    sent_interrupt = False
    finish_reason = None
    stopped = False
    # Pre-bind the values read by the final event. With max_new_tokens <= 0
    # this yields an empty "length" result instead of raising
    # UnboundLocalError.
    i = 0
    output = ""
    ret_logprobs = None
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(input_ids=start_ids, use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt tokens.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:  # decoding
            # Normally only the newest token is fed with the KV cache. After a
            # sentence-end interrupt (judge_sent_end below), the full
            # output_ids sequence is re-fed without the cache.
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False

                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # Switch to CPU to avoid some bugs in the mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        # Two candidates are drawn so judge_sent_end can fall back to the
        # second one.
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(candidate) for candidate in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)

        if logprobs is not None:
            # Cannot use last_token_logits because logprobs is based on raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Yield the output tokens
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(tid)
                        for tid in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs
                    if echo
                    else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}]
                    * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding partial stop sequence
            if not partially_stopped:
                yield {
                    "text": output,
                    "logprobs": ret_logprobs,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        if stopped:
            break

    # Finish stream event, which contains finish reason
    else:
        finish_reason = "length"

    if stopped:
        finish_reason = "stop"

    yield {
        "text": output,
        "logprobs": ret_logprobs,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Clean up. This runs only if the consumer resumes the generator after
    # the final event.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
class ChatIO(abc.ABC):
    """Abstract I/O front-end that ``chat_loop`` talks to.

    Concrete subclasses decide how input is obtained and how model output is
    shown. All four methods are abstract, so the class cannot be instantiated
    until a subclass overrides every one of them.
    """

    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Obtain the next message for ``role`` and return it as text."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Announce that ``role`` is about to produce output."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Consume ``output_stream`` and return the complete generated text."""

    @abc.abstractmethod
    def print_output(self, text: str):
        """Display an already-complete piece of output ``text``."""
# Retrieval component: a ColBERT index built with RAGatouille over the UAE
# documents. It is loaded once at import time and queried by chat_loop via
# RAG.search(). The index location can be overridden with the RAG_INDEX_PATH
# environment variable; the default is the original hard-coded path.
from ragatouille import RAGPretrainedModel

RAG_INDEX_PATH = os.environ.get(
    "RAG_INDEX_PATH",
    "/mnt/beegfs/fahad.khan/GeoMinGPT/RAGatouille/examples/.ragatouille/colbert/indexes/UAE_Documents",
)
# from_index() constructs its own model instance (per the RAGatouille docs it
# loads the checkpoint the index was built with). A separate
# from_pretrained("colbert-ir/colbertv2.0") call beforehand was overwritten
# immediately and only cost load time and memory, so it has been removed.
RAG = RAGPretrainedModel.from_index(RAG_INDEX_PATH)
# Instruction prefixes; one is chosen at random for every evaluation question.
_GEOLOGY_MINERAL_PROMPTS = [
    "You are a geology and mineral expert, provide detailed and insightful answers to the user's questions.",
    "As an expert in geology and minerals, give comprehensive and expert answers to the user's questions.",
    "You are an experienced Geologist specializing in minerals, provide detailed and informative answers to the user's questions.",
    "Your role as a geology and mineral expert is to provide detailed and informative answers to the user's questions. Ensure that your responses are helpful and relevant to the topic.",
    "As a geology and mineral expert, your task is to give comprehensive and expert answers to the user's questions. Please be detailed and informative in your responses.",
    "Your duty as a geology and mineral expert is to offer valuable and informative answers to the user's questions. Please ensure that your responses are relevant and helpful.",
    "As a geology and mineral expert, it is important that you provide detailed and insightful responses to the user's questions. Please be as helpful and informative as possible.",
    "Your responsibility as a geology and mineral expert is to offer detailed and informative answers to the user's questions. Please make sure your responses are relevant and helpful.",
    "As a geology and mineral expert, your task is to provide the user with helpful and informative answers to their questions. Please be as detailed and thorough as possible.",
    "Your job as a geology and mineral expert is to provide the user with comprehensive and informative answers to their questions. Please make sure your responses are relevant and helpful.",
    "As a geology and mineral expert, it is crucial that you offer detailed and insightful answers to the user's questions. Please ensure that your responses are informative and helpful.",
    "Your role as a geology and mineral expert is to provide the user with valuable and informative answers to their questions. Please make sure your responses are detailed and relevant.",
    "As a geology and mineral expert, your responsibility is to offer the user helpful and informative answers to their questions. Please be as detailed and thorough as possible.",
]

# Default evaluation input. It is JSON Lines: each record is expected to
# hold a "Sample_qs" list of questions.
_EVAL_DATA_PATH = "/mnt/beegfs/fahad.khan/GeoMinGPT/Evaluation/Sample_600_UAE_qs.json"
# Default evaluation output. It is JSON Lines and is appended to, never
# truncated.
_EVAL_OUTPUT_PATH = "/mnt/beegfs/fahad.khan/GeoMinGPT/Evaluation/UAE_Eval_Colbert_600qs.json"


def _read_jsonl(path: str) -> list:
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]


def _build_eval_input(question: str, system_prompt: str, score_threshold: float) -> str:
    """Build the user turn for *question*, adding retrieved context when relevant.

    The top hit from the module-level ``RAG`` index is included as reference
    data only if its score is strictly greater than *score_threshold*. If the
    score is not high enough, or the search returns nothing, only the
    question is asked.
    """
    result_docs = RAG.search(query=question, k=1)
    if result_docs and result_docs[0]["score"] > score_threshold:
        reference = result_docs[0]["content"]
        full_prompt = (
            "\nUsing the below information as a reference, answer the following question.\n"
            f"### Reference Data: {reference}\n"
            "---------------------\n"
            f"\n###Question: {question}"
        )
        return f"{system_prompt}\n{full_prompt}"
    return f"{system_prompt}\nQuestion: {question}"


def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    dtype: Optional[torch.dtype],
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    conv_system_msg: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    xft_config: Optional[XftConfig] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
    history: bool = True,
    data_path: str = _EVAL_DATA_PATH,
    output_path: str = _EVAL_OUTPUT_PATH,
    rag_score_threshold: float = 10,
):
    """Run the retrieval-augmented batch evaluation over a file of questions.

    Despite its name, this loop does not read interactive input. It works
    as follows:
      1. Load the model.
      2. For every question in *data_path*, build a prompt from a random
         expert instruction, prefixed with the top retrieved document when
         its score exceeds *rag_score_threshold*.
      3. Stream the answer through *chatio*.
      4. Append ``{"Question", "Context_Ques", "Answer"}`` as one JSON line
         to *output_path*.

    Args:
        model_path ... revision: Forwarded to ``load_model``.
        conv_template: Name of a conversation template. If None, the
            template is chosen from *model_path*.
        conv_system_msg: Optional system message set on each new conversation.
        temperature, repetition_penalty, max_new_tokens: Sampling settings
            sent to the generate function. For T5 models, a
            repetition_penalty of 1.0 is replaced with 1.2.
        chatio: Front-end used to display and collect the streamed output.
        judge_sent_end: Forwarded to the generate function.
        debug: Print per-question prompt, output and tokens/s.
        history: Reuse one conversation across questions (see NOTE below).
        data_path: JSON Lines file of records with a ``"Sample_qs"`` list.
        output_path: JSON Lines file that results are appended to.
        rag_score_threshold: Minimum (exclusive) retrieval score for the
            retrieved document to be included in the prompt.
    """
    from tqdm import tqdm
    import random

    # Model
    model, tokenizer = load_model(
        model_path,
        device=device,
        num_gpus=num_gpus,
        max_gpu_memory=max_gpu_memory,
        dtype=dtype,
        load_8bit=load_8bit,
        cpu_offloading=cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        revision=revision,
        debug=debug,
    )
    generate_stream_func = get_generate_stream_function(model, model_path)

    model_type = str(type(model)).lower()
    is_t5 = "t5" in model_type
    is_codet5p = "codet5p" in model_type

    # Hardcode T5's default repetition penalty to be 1.2
    if is_t5 and repetition_penalty == 1.0:
        repetition_penalty = 1.2

    # Set context length
    context_len = get_context_length(model.config)

    def new_chat():
        """Create a fresh conversation from the configured template."""
        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        if conv_system_msg is not None:
            conv.set_system_message(conv_system_msg)
        return conv

    def reload_conv(conv):
        """Reprint the conversation from the start."""
        for message in conv.messages[conv.offset :]:
            chatio.prompt_for_output(message[0])
            chatio.print_output(message[1])

    conv = None
    combined_data = _read_jsonl(data_path)

    for data in tqdm(combined_data):
        questions = data["Sample_qs"]
        for index, question in enumerate(questions):
            print(f"\n----------------------------\n{index}\n----------------------------\n")
            selected_prompt = random.choice(_GEOLOGY_MINERAL_PROMPTS)

            # NOTE(review): with history=True (the default), every question
            # is appended to the same conversation. Each prompt therefore
            # also contains all earlier questions and answers, and may be
            # truncated to fit the context window. Pass history=False if the
            # questions should be answered independently.
            if not history or not conv:
                conv = new_chat()

            inp = _build_eval_input(question, selected_prompt, rag_score_threshold)

            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            if is_codet5p:  # codet5p is a code completion model.
                prompt = inp

            gen_params = {
                "model": model_path,
                "prompt": prompt,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
                "max_new_tokens": max_new_tokens,
                "stop": conv.stop_str,
                "stop_token_ids": conv.stop_token_ids,
                "echo": False,
            }

            try:
                chatio.prompt_for_output(conv.roles[1])
                output_stream = generate_stream_func(
                    model,
                    tokenizer,
                    gen_params,
                    device,
                    context_len=context_len,
                    judge_sent_end=judge_sent_end,
                )
                t = time.time()
                outputs = chatio.stream_output(output_stream)
                duration = time.time() - t
                conv.update_last_message(outputs.strip())

                user_msg, assistant_msg = conv.messages[-2:]
                json_obj = {
                    "Question": question,
                    "Context_Ques": user_msg[1],
                    "Answer": assistant_msg[1],
                }
                # Appending after every question keeps the results written
                # so far if the run is interrupted.
                with open(output_path, "a", encoding="utf-8") as file:
                    json.dump(json_obj, file, indent=None)
                    file.write("\n")

                if debug:
                    num_tokens = len(tokenizer.encode(outputs))
                    msg = {
                        "conv_template": conv.name,
                        "prompt": prompt,
                        "outputs": outputs,
                        "speed (token/s)": round(num_tokens / duration, 2),
                    }
                    print(f"\n{msg}\n")

            except KeyboardInterrupt:
                print("stopped generation.")
                # If generation didn't finish
                if conv.messages[-1][1] is None:
                    conv.messages.pop()
                    # Remove last user message, so there isn't a double up
                    if conv.messages[-1][0] == conv.roles[0]:
                        conv.messages.pop()

                    reload_conv(conv)