|
"""Inference for FastChat models.""" |
|
import abc |
|
import gc |
|
import json |
|
import math |
|
import os |
|
import sys |
|
import time |
|
from typing import Iterable, Optional, Dict |
|
import warnings |
|
|
|
import psutil |
|
import torch |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForCausalLM, |
|
LlamaTokenizer, |
|
LlamaForCausalLM, |
|
AutoModel, |
|
AutoModelForSeq2SeqLM, |
|
T5Tokenizer, |
|
AutoConfig, |
|
) |
|
from transformers.generation.logits_process import ( |
|
LogitsProcessorList, |
|
RepetitionPenaltyLogitsProcessor, |
|
TemperatureLogitsWarper, |
|
TopKLogitsWarper, |
|
TopPLogitsWarper, |
|
) |
|
|
|
from fastchat.conversation import get_conv_template, SeparatorStyle |
|
from fastchat.model.model_adapter import ( |
|
load_model, |
|
get_conversation_template, |
|
get_generate_stream_function, |
|
) |
|
from fastchat.modules.awq import AWQConfig |
|
from fastchat.modules.gptq import GptqConfig |
|
from fastchat.modules.exllama import ExllamaConfig |
|
from fastchat.modules.xfastertransformer import XftConfig |
|
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length |
|
|
|
|
|
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Build the chain of logits processors for the given sampling settings.

    Each processor is included only when its setting is active:

    * temperature warper  -- ``1e-5 <= temperature`` and ``temperature != 1.0``
    * repetition penalty  -- ``repetition_penalty > 1.0``
    * top-p warper        -- ``1e-8 <= top_p < 1.0``
    * top-k warper        -- ``top_k > 0``

    The processors are applied in the order listed above. Each one is
    constructed only when enabled, so disabled values (e.g. ``top_k=-1``)
    are never passed to a constructor.

    Returns:
        A ``LogitsProcessorList``. It is empty when nothing is enabled.
    """
    stages = (
        (1e-5 <= temperature != 1.0, lambda: TemperatureLogitsWarper(temperature)),
        (
            repetition_penalty > 1.0,
            lambda: RepetitionPenaltyLogitsProcessor(repetition_penalty),
        ),
        (1e-8 <= top_p < 1.0, lambda: TopPLogitsWarper(top_p)),
        (top_k > 0, lambda: TopKLogitsWarper(top_k)),
    )
    return LogitsProcessorList(build() for enabled, build in stages if enabled)
|
|
|
|
|
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    """Generate a completion for ``params["prompt"]`` and yield it incrementally.

    Tokens are decoded one at a time. After the first (prefill) step the KV
    cache (``past_key_values``) is reused. Both decoder-only and
    encoder-decoder models are handled.

    Args:
        model: Model exposing ``config.is_encoder_decoder``. Encoder-decoder
            models must also expose ``encoder``, ``decoder``, ``lm_head`` and
            ``generation_config.decoder_start_token_id``.
        tokenizer: Tokenizer providing ``__call__``, ``decode`` and
            ``eos_token_id``.
        params: Request dict. ``"prompt"`` is required. Optional keys:
            ``temperature`` (0.9), ``repetition_penalty`` (1.0),
            ``top_p`` (0.8), ``top_k`` (-1, disabled), ``max_new_tokens``
            (1000), ``logprobs``, ``echo`` (True), ``stop`` (a str or an
            iterable of str) and ``stop_token_ids``. Neither the dict nor the
            lists it contains are modified.
        device: Device for input tensors. It is replaced by ``model.device``
            when the model has that attribute.
        context_len: Maximum number of tokens the model accepts.
        stream_interval: Emit a partial result every this many steps.
        judge_sent_end: When true, a stop token that would end the text
            mid-sentence (per ``is_sentence_complete``) is replaced by the
            second sampled candidate (or dropped), and generation continues.

    Yields:
        Dicts with keys ``text``, ``logprobs``, ``usage`` and
        ``finish_reason``. Intermediate chunks have ``finish_reason=None``.
        The last chunk has ``"stop"`` or ``"length"``.

    Raises:
        NotImplementedError: ``logprobs`` requested for an encoder-decoder
            model.
        ValueError: ``stop`` is neither a str nor an iterable.
    """
    if hasattr(model, "device"):
        device = model.device

    # Read request parameters.
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 0.9))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 0.8))
    top_k = int(params.get("top_k", -1))  # values <= 0 disable top-k
    max_new_tokens = int(params.get("max_new_tokens", 1000))
    logprobs = params.get("logprobs", None)
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    # Copy the list so that appending the EOS id below never mutates the
    # caller's list (e.g. a conversation template's ``stop_token_ids``).
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:
        # Reserve room in the context for the tokens to be generated.
        max_src_len = context_len - max_new_tokens - 1
    # A non-positive length would make the slice below keep the whole prompt
    # (``[-0:]``) or drop its *beginning* (``[-(-k):]``), rather than
    # keeping the most recent tokens.
    max_src_len = max(max_src_len, 1)

    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    if model.config.is_encoder_decoder:
        if logprobs is not None:
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # the first prompt token has no logprob
    sent_interrupt = False
    finish_reason = None
    stopped = False
    # Values reported by the final yield. They are pre-set so that
    # max_new_tokens <= 0 yields an empty "length" result instead of
    # raising UnboundLocalError.
    i = 0
    output = ""
    ret_logprobs = None

    for i in range(max_new_tokens):
        if i == 0:
            # Prefill step: run the full prompt once.
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(input_ids=start_ids, use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Logprob of each prompt token given the tokens before it.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:
            # Incremental step: feed only the last token and reuse the cache.
            # After a judge_sent_end rewrite, the whole sequence is re-run
            # without the cache instead.
            step_ids = torch.as_tensor(
                [[token] if not sent_interrupt else output_ids],
                device=device,
            )
            step_cache = past_key_values if not sent_interrupt else None
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=step_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=step_cache,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=step_ids,
                    use_cache=True,
                    past_key_values=step_cache,
                )
                logits = out.logits
            sent_interrupt = False
            past_key_values = out.past_key_values

        if logits_processor:
            # The repetition penalty is the only processor that needs the
            # token history.
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # NOTE(review): presumably works around MPS backend limitations;
            # sampling is done on CPU in float32.
            last_token_logits = last_token_logits.float().to("cpu")

        # Draw two candidates. The second one is used only by judge_sent_end.
        if temperature < 1e-5 or top_p < 1e-8:  # greedy decoding
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            # Sampling without replacement cannot draw more samples than
            # there are non-zero probabilities. Top-p/top-k filtering can
            # leave a single candidate, so ask for at most that many.
            num_candidates = min(2, max(1, int(torch.count_nonzero(probs))))
            indices = torch.multinomial(probs, num_samples=num_candidates)
            tokens = [int(candidate) for candidate in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)

        if logprobs is not None:
            # Reported logprobs come from the raw (unprocessed) logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        stopped = token in stop_token_ids

        # Emit a chunk every `stream_interval` steps, on the last step, or
        # when a stop token was produced.
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(tok_id)
                        for tok_id in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs
                    if echo
                    else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}]
                    * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Character offset of each decoded token within the text.
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            # Do not stop mid-sentence: swap the stop token for the runner-up
            # candidate (or drop it) and keep generating.
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Hold back chunks that end in a prefix of a stop string.
            if not partially_stopped:
                yield {
                    "text": output,
                    "logprobs": ret_logprobs,
                    # NOTE(review): ``i`` is the 0-based step index, one less
                    # than the number of tokens appended so far -- confirm
                    # whether consumers rely on this.
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        if stopped:
            break
    else:
        # The loop ran to completion without hitting a stop condition.
        finish_reason = "length"

    if stopped:
        finish_reason = "stop"

    yield {
        "text": output,
        "logprobs": ret_logprobs,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Release the KV cache and accelerator memory.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
|
|
|
|
|
class ChatIO(abc.ABC):
    """Abstract input/output channel used by the chat loop.

    Subclasses must implement all four methods below. The class cannot be
    instantiated until they do.
    """

    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Obtain the next message for ``role`` and return it as a string."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Signal that output attributed to ``role`` is about to be shown."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Display a streamed generation.

        The callers in this file use the return value as the generated text.
        """

    @abc.abstractmethod
    def print_output(self, text: str):
        """Display an already-complete piece of ``text``."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ragatouille import RAGPretrainedModel

# Pre-built ColBERT index that ``chat_loop`` searches for reference passages.
RAG_INDEX_PATH = '/mnt/beegfs/fahad.khan/GeoMinGPT/RAGatouille/examples/.ragatouille/colbert/indexes/UAE_Documents'

# ``from_index`` builds a new retriever from the index on its own. The
# previous ``RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")``
# result was overwritten immediately, so that extra full model load was
# removed.
RAG = RAGPretrainedModel.from_index(RAG_INDEX_PATH)
|
|
|
# System prompts; one is chosen at random for each evaluated question.
_GEOLOGY_MINERAL_PROMPTS = (
    "You are a geology and mineral expert, provide detailed and insightful answers to the user's questions.",
    "As an expert in geology and minerals, give comprehensive and expert answers to the user's questions.",
    "You are an experienced Geologist specializing in minerals, provide detailed and informative answers to the user's questions.",
    "Your role as a geology and mineral expert is to provide detailed and informative answers to the user's questions. Ensure that your responses are helpful and relevant to the topic.",
    "As a geology and mineral expert, your task is to give comprehensive and expert answers to the user's questions. Please be detailed and informative in your responses.",
    "Your duty as a geology and mineral expert is to offer valuable and informative answers to the user's questions. Please ensure that your responses are relevant and helpful.",
    "As a geology and mineral expert, it is important that you provide detailed and insightful responses to the user's questions. Please be as helpful and informative as possible.",
    "Your responsibility as a geology and mineral expert is to offer detailed and informative answers to the user's questions. Please make sure your responses are relevant and helpful.",
    "As a geology and mineral expert, your task is to provide the user with helpful and informative answers to their questions. Please be as detailed and thorough as possible.",
    "Your job as a geology and mineral expert is to provide the user with comprehensive and informative answers to their questions. Please make sure your responses are relevant and helpful.",
    "As a geology and mineral expert, it is crucial that you offer detailed and insightful answers to the user's questions. Please ensure that your responses are informative and helpful.",
    "Your role as a geology and mineral expert is to provide the user with valuable and informative answers to their questions. Please make sure your responses are detailed and relevant.",
    "As a geology and mineral expert, your responsibility is to offer the user helpful and informative answers to their questions. Please be as detailed and thorough as possible.",
)


def _load_jsonl(path: str) -> list:
    """Return the JSON objects stored one per line in ``path``, skipping blank lines."""
    records = []
    with open(path, "r", encoding="utf-8") as fh:
        for line in fh:
            if line.strip():
                records.append(json.loads(line))
    return records


def _build_user_input(
    question: str, system_prompt: str, result_docs, score_threshold: float
) -> str:
    """Compose the user turn for ``question``.

    The top retrieved passage is inlined as reference data only when there is
    one and its ``"score"`` exceeds ``score_threshold``. Otherwise the bare
    question is used.
    """
    if result_docs and result_docs[0]["score"] > score_threshold:
        full_prompt = "\nUsing the below information as a reference, answer the following question.\n### Reference Data: {}\n".format(result_docs[0]['content']) + \
            "---------------------\n" + \
            "\n###Question: " + \
            "{}".format(question)
        return f"{system_prompt}\n{full_prompt}"
    return f"{system_prompt}\nQuestion: {question}"


def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    dtype: Optional[torch.dtype],
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    conv_system_msg: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    xft_config: Optional[XftConfig] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
    history: bool = True,
    data_path: str = "/mnt/beegfs/fahad.khan/GeoMinGPT/Evaluation/Sample_600_UAE_qs.json",
    output_path: str = "/mnt/beegfs/fahad.khan/GeoMinGPT/Evaluation/UAE_Eval_Colbert_600qs.json",
    rag_score_threshold: float = 10.0,
):
    """Answer every question in a JSONL file with RAG context and log the results.

    The function loads the model, then reads ``data_path`` (one JSON object
    per line, each with a ``"Sample_qs"`` list of questions). For each
    question it:

    1. retrieves the top passage from the module-level ``RAG`` index;
    2. builds a prompt from a random system prompt plus the question, and
       includes the passage only if its score exceeds ``rag_score_threshold``;
    3. generates an answer;
    4. appends ``{"Question", "Context_Ques", "Answer"}`` as one JSON line to
       ``output_path``.

    Args:
        model_path ... history: Same meaning as in FastChat's interactive
            ``chat_loop``. When ``history`` is true, one conversation is reused
            across all questions, so earlier Q&A turns stay in later prompts.
        data_path: Input JSONL file of question groups.
        output_path: JSONL file that results are appended to.
        rag_score_threshold: Minimum retrieval score (exclusive) for the
            retrieved passage to be included in the prompt.
    """
    from tqdm import tqdm
    import random

    model, tokenizer = load_model(
        model_path,
        device=device,
        num_gpus=num_gpus,
        max_gpu_memory=max_gpu_memory,
        dtype=dtype,
        load_8bit=load_8bit,
        cpu_offloading=cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        revision=revision,
        debug=debug,
    )
    generate_stream_func = get_generate_stream_function(model, model_path)

    model_type = str(type(model)).lower()
    is_t5 = "t5" in model_type
    is_codet5p = "codet5p" in model_type

    # T5 models default to a repetition penalty of 1.2 unless one was given.
    if is_t5 and repetition_penalty == 1.0:
        repetition_penalty = 1.2

    context_len = get_context_length(model.config)

    def new_chat():
        """Create a fresh conversation from the configured template."""
        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        if conv_system_msg is not None:
            conv.set_system_message(conv_system_msg)
        return conv

    def reload_conv(conv):
        """Re-print the conversation so far, starting at ``conv.offset``."""
        for message in conv.messages[conv.offset :]:
            chatio.prompt_for_output(message[0])
            chatio.print_output(message[1])

    conv = None
    combined_data = _load_jsonl(data_path)

    for data in tqdm(combined_data):
        questions = data["Sample_qs"]

        for index, question in enumerate(questions):
            print(f"\n----------------------------\n{index}\n----------------------------\n")
            selected_prompt = random.choice(_GEOLOGY_MINERAL_PROMPTS)

            if not history or not conv:
                conv = new_chat()

            result_docs = RAG.search(query=question, k=1)
            inp = _build_user_input(
                question, selected_prompt, result_docs, rag_score_threshold
            )

            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            # CodeT5+ is given the raw input without the chat template.
            if is_codet5p:
                prompt = inp

            gen_params = {
                "model": model_path,
                "prompt": prompt,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
                "max_new_tokens": max_new_tokens,
                "stop": conv.stop_str,
                "stop_token_ids": conv.stop_token_ids,
                "echo": False,
            }

            try:
                chatio.prompt_for_output(conv.roles[1])
                output_stream = generate_stream_func(
                    model,
                    tokenizer,
                    gen_params,
                    device,
                    context_len=context_len,
                    judge_sent_end=judge_sent_end,
                )

                t = time.time()
                outputs = chatio.stream_output(output_stream)
                duration = time.time() - t
                conv.update_last_message(outputs.strip())

                # Log the (user input, answer) pair just completed.
                output = conv.messages[-2:]
                json_obj = {
                    "Question": question,
                    "Context_Ques": output[0][1],
                    "Answer": output[1][1],
                }
                with open(output_path, "a", encoding="utf-8") as out_file:
                    json.dump(json_obj, out_file, indent=None)
                    out_file.write('\n')

                if debug:
                    num_tokens = len(tokenizer.encode(outputs))
                    msg = {
                        "conv_template": conv.name,
                        "prompt": prompt,
                        "outputs": outputs,
                        "speed (token/s)": round(num_tokens / duration, 2),
                    }
                    print(f"\n{msg}\n")

            except KeyboardInterrupt:
                print("stopped generation.")
                # Roll back the unfinished assistant slot and its user turn.
                if conv.messages[-1][1] is None:
                    conv.messages.pop()
                    if conv.messages[-1][0] == conv.roles[0]:
                        conv.messages.pop()
                    reload_conv(conv)