"""Inference for FastChat models."""
import abc
import gc
import json
import math
import os
import sys
import time
from typing import Iterable, Optional, Dict
import warnings

import psutil
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaTokenizer,
    LlamaForCausalLM,
    AutoModel,
    AutoModelForSeq2SeqLM,
    T5Tokenizer,
    AutoConfig,
)
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

from fastchat.conversation import get_conv_template, SeparatorStyle
from fastchat.model.model_adapter import (
    load_model,
    get_conversation_template,
    get_generate_stream_function,
)
from fastchat.modules.awq import AWQConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Build the list of logits processors/warpers for sampling.

    Each processor is added only when its setting actually changes the
    distribution, so an all-default configuration yields an empty list.

    Args:
        temperature: Softmax temperature; skipped when < 1e-5 or exactly 1.0.
        repetition_penalty: Applied only when > 1.0.
        top_p: Nucleus-sampling mass; applied only when 1e-8 <= top_p < 1.0.
        top_k: Top-k cutoff; applied only when > 0.

    Returns:
        A (possibly empty) ``LogitsProcessorList``.
    """
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list


@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    """Generate text token by token, yielding partial results as a stream.

    Args:
        model: A HF-style causal LM or encoder-decoder model. If it has a
            ``device`` attribute, that device overrides ``device``.
        tokenizer: Tokenizer matching ``model``.
        params: Generation request. Reads ``prompt`` (required) and the
            optional keys ``temperature``, ``repetition_penalty``, ``top_p``,
            ``top_k``, ``max_new_tokens``, ``logprobs``, ``echo``, ``stop`` and
            ``stop_token_ids``. The caller's ``stop_token_ids`` list is not
            modified.
        device: Device used for the input tensors.
        context_len: Model context window, used to truncate the prompt.
        stream_interval: Yield a partial result every this many steps.
        judge_sent_end: If true, a stop token that would end the output
            mid-sentence is replaced by the runner-up candidate token.

    Yields:
        Dicts with keys ``text``, ``logprobs``, ``usage`` and
        ``finish_reason``. Intermediate events have ``finish_reason=None``.
        The final event has ``"stop"`` or ``"length"``.

    Raises:
        NotImplementedError: ``logprobs`` requested for an encoder-decoder model.
        ValueError: ``stop`` is neither a string nor an iterable of strings.
    """
    if hasattr(model, "device"):
        device = model.device

    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    # temperature = float(params.get("temperature", 1.0))
    # repetition_penalty = float(params.get("repetition_penalty", 1.0))
    # top_p = float(params.get("top_p", 1.0))
    # top_k = int(params.get("top_k", -1))  # -1 means disable
    # max_new_tokens = int(params.get("max_new_tokens", 256))
    ###ADD
    # uncomment the original above, play params below
    temperature = float(params.get("temperature", 0.9))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 0.8))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 1000))
    ###ADD
    logprobs = params.get("logprobs", None)  # FIXME: Support logprobs>1.
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    # Copy before appending EOS. Callers pass conv.stop_token_ids, which is
    # shared with the conversation template; appending in place would leak
    # this tokenizer's EOS id into it. Copying also accepts tuples.
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate
        # Keep only the most recent tokens, leaving room for the reply.
        # NOTE(review): if max_new_tokens >= context_len this length is <= 0,
        # and the slice below no longer truncates as intended.
        max_src_len = context_len - max_new_tokens - 1

    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    if model.config.is_encoder_decoder:
        if logprobs is not None:  # FIXME: Support logprobs for encoder-decoder models.
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # The first token has no logprobs.
    sent_interrupt = False
    finish_reason = None
    stopped = False
    # Defaults so the final event below is well-defined even when the loop
    # body never runs (max_new_tokens <= 0). Previously this raised NameError.
    i = 0
    output = ""
    ret_logprobs = None
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(input_ids=start_ids, use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt: logits at position t score token t+1.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:  # decoding
            # After a sentence-end interrupt, the KV cache no longer matches
            # output_ids. Re-run the whole sequence without the cache.
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False

                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # Switch to CPU by avoiding some bugs in mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        # Draw two candidates. The second is a fallback for judge_sent_end.
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)
        if logprobs is not None:
            # Cannot use last_token_logits because logprobs is based on raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        stopped = token in stop_token_ids

        # Yield the output tokens
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            ret_logprobs = None
            if logprobs is not None:
                visible_logprobs = (
                    token_logprobs if echo else token_logprobs[input_echo_len:]
                )
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": visible_logprobs,
                    # One distinct (empty) dict per token, not N aliases of one dict.
                    "top_logprobs": [{} for _ in visible_logprobs],
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding partial stop sequence
            if not partially_stopped:
                yield {
                    "text": output,
                    "logprobs": ret_logprobs,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        if stopped:
            break

    # Finish stream event, which contains finish reason
    else:
        finish_reason = "length"

    if stopped:
        finish_reason = "stop"

    yield {
        "text": output,
        "logprobs": ret_logprobs,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Clean
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()


class ChatIO(abc.ABC):
    """Abstract terminal/UI front end used by ``chat_loop`` for I/O."""

    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Stream output."""

    @abc.abstractmethod
    def print_output(self, text: str):
        """Print output."""


# adding RAG component here
# from langchain_community.vectorstores import Chroma
# from langchain_community.embeddings import SentenceTransformerEmbeddings
# def chroma_search():
#     # directory = "UAE_Docs_Embeddings"
#     directory = "/mnt/beegfs/fahad.khan/GeoMinGPT/VectorDB/UAE_Specific_Docs_Embeddings"
#     embeddings = SentenceTransformerEmbeddings(model_name ="sentence-transformers/all-MiniLM-L6-v2")
#     vectorDB = Chroma(persist_directory=directory, embedding_function=embeddings)
#     return vectorDB
# vectorDB = chroma_search()

from ragatouille import RAGPretrainedModel

# Prebuilt ColBERT index over the UAE documents.
RAG_INDEX_PATH = "/mnt/beegfs/fahad.khan/GeoMinGPT/RAGatouille/examples/.ragatouille/colbert/indexes/UAE_Documents"
# Retrieved context is only added to the prompt when the top hit's score
# exceeds this value. Otherwise the question is asked without references.
RAG_SCORE_THRESHOLD = 10
# from_index() loads the index together with its ColBERT encoder. The earlier
# RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0") result was
# overwritten immediately, so it only cost an extra model load at import.
RAG = RAGPretrainedModel.from_index(RAG_INDEX_PATH)


def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    dtype: Optional[torch.dtype],
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    conv_system_msg: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    xft_config: Optional[XftConfig] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
    history: bool = True,
):
    """Run the model over a fixed question set with optional RAG context.

    Loads the model and reads question records (JSON lines, each with a
    ``Sample_qs`` list) from the hard-coded evaluation file. For each
    question it does the following:

    1. Retrieves the top ColBERT passage.
    2. Builds a prompt from a random expert system line, adding the passage
       when its score exceeds ``RAG_SCORE_THRESHOLD``.
    3. Streams the answer through ``chatio``.
    4. Appends a ``{"Question", "Context_Ques", "Answer"}`` JSON line to the
       hard-coded results file.

    With ``history`` true (the default), earlier Q/A turns stay in the
    conversation for later questions.
    """
    # Model
    model, tokenizer = load_model(
        model_path,
        device=device,
        num_gpus=num_gpus,
        max_gpu_memory=max_gpu_memory,
        dtype=dtype,
        load_8bit=load_8bit,
        cpu_offloading=cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        revision=revision,
        debug=debug,
    )
    generate_stream_func = get_generate_stream_function(model, model_path)

    model_type = str(type(model)).lower()
    is_t5 = "t5" in model_type
    is_codet5p = "codet5p" in model_type

    # Hardcode T5's default repetition penalty to be 1.2
    if is_t5 and repetition_penalty == 1.0:
        repetition_penalty = 1.2

    # Set context length
    context_len = get_context_length(model.config)

    # Chat
    def new_chat():
        """Create a fresh conversation from the configured template."""
        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        if conv_system_msg is not None:
            conv.set_system_message(conv_system_msg)
        return conv

    def reload_conv(conv):
        """
        Reprints the conversation from the start.
        """
        for message in conv.messages[conv.offset :]:
            chatio.prompt_for_output(message[0])
            chatio.print_output(message[1])

    conv = None

    # writing custom code here
    # (json is already imported at module level.)
    from tqdm import tqdm
    import random

    data_path = "/mnt/beegfs/fahad.khan/GeoMinGPT/Evaluation/Sample_600_UAE_qs.json"
    results_path = "/mnt/beegfs/fahad.khan/GeoMinGPT/Evaluation/UAE_Eval_Colbert_600qs.json"

    # One JSON object per line.
    combined_data = []
    with open(data_path, "r", encoding="utf-8") as file:
        for line in file:
            combined_data.append(json.loads(line))

    geology_mineral_prompts = [
        "You are a geology and mineral expert, provide detailed and insightful answers to the user's questions.",
        "As an expert in geology and minerals, give comprehensive and expert answers to the user's questions.",
        "You are an experienced Geologist specializing in minerals, provide detailed and informative answers to the user's questions.",
        "Your role as a geology and mineral expert is to provide detailed and informative answers to the user's questions. Ensure that your responses are helpful and relevant to the topic.",
        "As a geology and mineral expert, your task is to give comprehensive and expert answers to the user's questions. Please be detailed and informative in your responses.",
        "Your duty as a geology and mineral expert is to offer valuable and informative answers to the user's questions. Please ensure that your responses are relevant and helpful.",
        "As a geology and mineral expert, it is important that you provide detailed and insightful responses to the user's questions. Please be as helpful and informative as possible.",
        "Your responsibility as a geology and mineral expert is to offer detailed and informative answers to the user's questions. Please make sure your responses are relevant and helpful.",
        "As a geology and mineral expert, your task is to provide the user with helpful and informative answers to their questions. Please be as detailed and thorough as possible.",
        "Your job as a geology and mineral expert is to provide the user with comprehensive and informative answers to their questions. Please make sure your responses are relevant and helpful.",
        "As a geology and mineral expert, it is crucial that you offer detailed and insightful answers to the user's questions. Please ensure that your responses are informative and helpful.",
        "Your role as a geology and mineral expert is to provide the user with valuable and informative answers to their questions. Please make sure your responses are detailed and relevant.",
        "As a geology and mineral expert, your responsibility is to offer the user helpful and informative answers to their questions. Please be as detailed and thorough as possible.",
    ]

    for data in tqdm(combined_data):
        questions = data["Sample_qs"]
        # random.shuffle(questions)
        for index, question in enumerate(questions):
            print(f"\n----------------------------\n{index}\n----------------------------\n")
            selected_prompt = random.choice(geology_mineral_prompts)
            if not history or not conv:
                conv = new_chat()

            try:
                # inp = chatio.prompt_for_input(conv.roles[0])
                # inp = f"{selected_prompt}\nQuestion: {question} Explain the response in detail, considering that you are geology and mineral expert. Provide examples or evidence to support your response. Discuss the sub domains and analyze the all implications. Your response should be expressive and limited to 1000 words"
                result_docs = RAG.search(query=question, k=1)
                # An empty result falls back to the plain question instead of
                # raising IndexError and aborting the whole evaluation run.
                top_doc = result_docs[0] if result_docs else None
                if top_doc is not None and top_doc["score"] > RAG_SCORE_THRESHOLD:
                    full_prompt = (
                        "\nUsing the below information as a reference, answer the following question.\n"
                        f"### Reference Data: {top_doc['content']}\n"
                        "---------------------\n"
                        f"\n###Question: {question}"
                    )
                    inp = f"{selected_prompt}\n{full_prompt}"
                else:
                    inp = f"{selected_prompt}\nQuestion: {question}"
            except EOFError:
                # Holdover from the interactive prompt_for_input path above.
                inp = ""

            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            # print(prompt)

            if is_codet5p:  # codet5p is a code completion model.
                prompt = inp

            gen_params = {
                "model": model_path,
                "prompt": prompt,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
                "max_new_tokens": max_new_tokens,
                "stop": conv.stop_str,
                "stop_token_ids": conv.stop_token_ids,
                "echo": False,
            }

            try:
                chatio.prompt_for_output(conv.roles[1])
                output_stream = generate_stream_func(
                    model,
                    tokenizer,
                    gen_params,
                    device,
                    context_len=context_len,
                    judge_sent_end=judge_sent_end,
                )
                t = time.time()
                outputs = chatio.stream_output(output_stream)
                duration = time.time() - t
                conv.update_last_message(outputs.strip())

                # Last two messages are (user prompt, assistant answer).
                user_msg, assistant_msg = conv.messages[-2:]
                json_obj = {
                    "Question": question,
                    "Context_Ques": user_msg[1],
                    "Answer": assistant_msg[1],
                }
                # Append per question so partial results survive an abort.
                with open(results_path, "a") as file:
                    json.dump(json_obj, file, indent=None)
                    file.write("\n")

                if debug:
                    num_tokens = len(tokenizer.encode(outputs))
                    msg = {
                        "conv_template": conv.name,
                        "prompt": prompt,
                        "outputs": outputs,
                        "speed (token/s)": round(num_tokens / duration, 2),
                    }
                    print(f"\n{msg}\n")

            except KeyboardInterrupt:
                print("stopped generation.")
                # If generation didn't finish
                if conv.messages[-1][1] is None:
                    conv.messages.pop()
                    # Remove last user message, so there isn't a double up
                    if conv.messages[-1][0] == conv.roles[0]:
                        conv.messages.pop()

                    reload_conv(conv)