"
+
+super_source_prefix = f""" '):
+ prompt = prompt[:-4]
+ prompt = prompt.replace(' ".join(
+ answer_sources)
+ if verbose:
+ if int(t_run):
+ sorted_sources_urls += 'Total Time: %d [s] ' % t_run
+ if count_input_tokens and count_output_tokens:
+ sorted_sources_urls += 'Input Tokens: %s | Output Tokens: %d ' % (
+ count_input_tokens, count_output_tokens)
+ sorted_sources_urls += f"
+ Sources:
+ {0}
+ {0}
+ Exceptions: DISCLAIMERS: Sources
Sources [Score | Link]:"""
+super_source_postfix = f"""End Sources
', chat_turn_sep)
+ if not prompt.endswith(chat_turn_sep):
+ prompt += chat_turn_sep
+ # most recent first, add older if can
+ # only include desired chat history
+ if len(prompt + context1) > max_prompt_length:
+ break
+ context1 += prompt
+
+ _, pre_response, terminate_response, chat_sep, chat_turn_sep = \
+ generate_prompt({}, prompt_type, prompt_dict,
+ chat, reduced=True,
+ making_context=True,
+ system_prompt=system_prompt,
+ histi=-1)
+ if context1 and not context1.endswith(chat_turn_sep):
+ context1 += chat_turn_sep # ensure if terminates abruptly, then human continues on next line
+ return context1
+
+
+def get_limited_prompt(instruction,
+ iinput,
+ tokenizer,
+ prompter=None,
+ inference_server=None,
+ prompt_type=None, prompt_dict=None, chat=False, max_new_tokens=None,
+ system_prompt='',
+ context='', chat_conversation=None, text_context_list=None,
+ keep_sources_in_context=False,
+ model_max_length=None, memory_restriction_level=0,
+ langchain_mode=None, add_chat_history_to_context=True,
+ verbose=False,
+ doc_importance=0.5,
+ min_max_new_tokens=256,
+ ):
+ if prompter:
+ prompt_type = prompter.prompt_type
+ prompt_dict = prompter.prompt_dict
+ chat = prompter.chat
+ stream_output = prompter.stream_output
+ system_prompt = prompter.system_prompt
+
+ # merge handles if chat_conversation is None
+ history = []
+ history = merge_chat_conversation_history(chat_conversation, history)
+ history_to_context_func = functools.partial(history_to_context,
+ langchain_mode=langchain_mode,
+ add_chat_history_to_context=add_chat_history_to_context,
+ prompt_type=prompt_type,
+ prompt_dict=prompt_dict,
+ chat=chat,
+ model_max_length=model_max_length,
+ memory_restriction_level=memory_restriction_level,
+ keep_sources_in_context=keep_sources_in_context,
+ system_prompt=system_prompt)
+ context2 = history_to_context_func(history)
+ context1 = context
+ if context1 is None:
+ context1 = ''
+
+ from h2oai_pipeline import H2OTextGenerationPipeline
+ data_point_just_instruction = dict(context='', instruction=instruction, input='')
+ prompt_just_instruction = prompter.generate_prompt(data_point_just_instruction)
+ instruction, num_instruction_tokens = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer)
+ num_instruction_tokens_real = get_token_count(prompt_just_instruction, tokenizer)
+ num_instruction_tokens += (num_instruction_tokens_real - num_instruction_tokens)
+
+ context1, num_context1_tokens = H2OTextGenerationPipeline.limit_prompt(context1, tokenizer)
+ context2, num_context2_tokens = H2OTextGenerationPipeline.limit_prompt(context2, tokenizer)
+ iinput, num_iinput_tokens = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer)
+ if text_context_list is None:
+ text_context_list = []
+ num_doc_tokens = sum([get_token_count(x + '\n\n', tokenizer) for x in text_context_list])
+
+ num_prompt_tokens0 = (num_instruction_tokens or 0) + \
+ (num_context1_tokens or 0) + \
+ (num_context2_tokens or 0) + \
+ (num_iinput_tokens or 0) + \
+ (num_doc_tokens or 0)
+
+    # go down to no less than 256 tokens, about 1 paragraph
+    # use max_new_tokens before using num_prompt_tokens0, else it would be negative or ~0
+ min_max_new_tokens = min(min_max_new_tokens, max_new_tokens)
+ # by default assume can handle all chat and docs
+ chat_index = 0
+
+    # allowed budget for each side is whichever is larger: its doc_importance (or 1 - doc_importance)
+    # share of the model limit, or whatever the other side didn't consume
+    num_non_doc_tokens = num_prompt_tokens0 - num_doc_tokens
+    # allocate to docs first, then non-docs; shouldn't matter much either way
+ doc_max_length = max(model_max_length - num_non_doc_tokens, doc_importance * model_max_length)
+ top_k_docs, one_doc_size, num_doc_tokens = get_docs_tokens(tokenizer, text_context_list=text_context_list,
+ max_input_tokens=doc_max_length)
+ non_doc_max_length = max(model_max_length - num_doc_tokens, (1.0 - doc_importance) * model_max_length)
+
+ if num_non_doc_tokens > non_doc_max_length:
+ # need to limit in some way, keep portion of history but all of context and instruction
+ # 1) drop iinput (unusual to include anyways)
+ # 2) reduce history
+ # 3) reduce context1
+ # 4) limit instruction so will fit
+ diff1 = non_doc_max_length - (
+ num_instruction_tokens + num_context1_tokens + num_context2_tokens + min_max_new_tokens)
+ diff2 = non_doc_max_length - (num_instruction_tokens + num_context1_tokens + min_max_new_tokens)
+ diff3 = non_doc_max_length - (num_instruction_tokens + min_max_new_tokens)
+ diff4 = non_doc_max_length - min_max_new_tokens
+ if diff1 > 0:
+ # then should be able to do #1
+ iinput = ''
+ num_iinput_tokens = 0
+ elif diff2 > 0 > diff1:
+ # then may be able to do #1 + #2
+ iinput = ''
+ num_iinput_tokens = 0
+ chat_index_final = len(history)
+ for chat_index in range(len(history)):
+ # NOTE: history and chat_conversation are older for first entries
+                # FIXME: This is slow for many short conversations
+ context2 = history_to_context_func(history[chat_index:])
+ num_context2_tokens = get_token_count(context2, tokenizer)
+ diff1 = non_doc_max_length - (
+ num_instruction_tokens + num_context1_tokens + num_context2_tokens + min_max_new_tokens)
+ if diff1 > 0:
+ chat_index_final = chat_index
+ if verbose:
+ print("chat_conversation used %d out of %d" % (chat_index, len(history)), flush=True)
+ break
+ chat_index = chat_index_final # i.e. if chat_index == len(history), then nothing can be consumed
+ elif diff3 > 0 > diff2:
+ # then may be able to do #1 + #2 + #3
+ iinput = ''
+ num_iinput_tokens = 0
+ context2 = ''
+ num_context2_tokens = 0
+ context1, num_context1_tokens = H2OTextGenerationPipeline.limit_prompt(context1, tokenizer,
+ max_prompt_length=diff3)
+ if num_context1_tokens <= diff3:
+ pass
+ else:
+ print("failed to reduce", flush=True)
+ else:
+ # then must be able to do #1 + #2 + #3 + #4
+ iinput = ''
+ num_iinput_tokens = 0
+ context2 = ''
+ num_context2_tokens = 0
+ context1 = ''
+ num_context1_tokens = 0
+ # diff4 accounts for real prompting for instruction
+            # FIXME: history_to_context could include the instruction; if the system prompt is long, we overcount and could have more free tokens
+ instruction, num_instruction_tokens = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer,
+ max_prompt_length=diff4)
+ # get actual tokens
+ data_point_just_instruction = dict(context='', instruction=instruction, input='')
+ prompt_just_instruction = prompter.generate_prompt(data_point_just_instruction)
+ num_instruction_tokens_real = get_token_count(prompt_just_instruction, tokenizer)
+ num_instruction_tokens += (num_instruction_tokens_real - num_instruction_tokens)
+
+ # update full context
+ context = context1 + context2
+ # update token counts (docs + non-docs, all tokens)
+ num_prompt_tokens = (num_instruction_tokens or 0) + \
+ (num_context1_tokens or 0) + \
+ (num_context2_tokens or 0) + \
+ (num_iinput_tokens or 0) + \
+ (num_doc_tokens or 0)
+
+ # update max_new_tokens
+ if inference_server and inference_server.startswith('http'):
+        # assume the TGI/Gradio server can consume the tokens and produce long outputs too, even if that exceeds model capacity.
+ pass
+ else:
+ # limit so max_new_tokens = prompt + new < max
+ # otherwise model can fail etc. e.g. for distilgpt2 asking for 1024 tokens is enough to fail if prompt=1 token
+ max_new_tokens = min(max_new_tokens, model_max_length - num_prompt_tokens)
+
+ if prompter is None:
+ # get prompter
+ debug = False
+ stream_output = False # doesn't matter
+ prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output,
+ system_prompt=system_prompt)
+
+ data_point = dict(context=context, instruction=instruction, input=iinput)
+    # handle promptA/promptB addition if really from history.
+    # if not from history, then reduced=False inside is correct
+    # if mixed, there is no single correct thing to do, so treat like history and promptA/B will still come first
+ context_from_history = len(history) > 0 and len(context1) > 0
+ prompt = prompter.generate_prompt(data_point, context_from_history=context_from_history)
+ num_prompt_tokens_actual = get_token_count(prompt, tokenizer)
+
+ return prompt, \
+ instruction, iinput, context, \
+ num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
+ chat_index, top_k_docs, one_doc_size
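+
+
+# Worked sketch (hypothetical helper, illustrative numbers) of the doc vs. non-doc token budget
+# split computed inside get_limited_prompt() above:
+def _example_budget_split():
+    model_max_length = 2048
+    doc_importance = 0.5
+    num_non_doc_tokens = 600  # instruction + context1 + context2 + iinput
+    doc_max_length = max(model_max_length - num_non_doc_tokens, doc_importance * model_max_length)
+    assert doc_max_length == 1448  # docs may use what non-docs left over, but at least their share
+    num_doc_tokens = 1200  # suppose the selected docs then consume this many tokens
+    non_doc_max_length = max(model_max_length - num_doc_tokens, (1.0 - doc_importance) * model_max_length)
+    assert non_doc_max_length == 1024  # non-docs keep at least their (1 - doc_importance) share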
+
+
+def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
+ if text_context_list is None or len(text_context_list) == 0:
+ return 0, None, 0
+ if max_input_tokens is None:
+ max_input_tokens = tokenizer.model_max_length
+ tokens = [get_token_count(x + '\n\n', tokenizer) for x in text_context_list]
+ tokens_cumsum = np.cumsum(tokens)
+ where_res = np.where(tokens_cumsum < max_input_tokens)[0]
+ # if below condition fails, then keep top_k_docs=-1 and trigger special handling next
+ if where_res.shape[0] > 0:
+ top_k_docs = 1 + where_res[-1]
+ one_doc_size = None
+ num_doc_tokens = tokens_cumsum[top_k_docs - 1] # by index
+ else:
+ # if here, means 0 and just do best with 1 doc
+ top_k_docs = 1
+ text_context_list = text_context_list[:top_k_docs]
+ # critical protection
+ from src.h2oai_pipeline import H2OTextGenerationPipeline
+ doc_content = text_context_list[0]
+ doc_content, new_tokens0 = H2OTextGenerationPipeline.limit_prompt(doc_content,
+ tokenizer,
+ max_prompt_length=max_input_tokens)
+ text_context_list[0] = doc_content
+ one_doc_size = len(doc_content)
+ num_doc_tokens = get_token_count(doc_content + '\n\n', tokenizer)
+ print("Unexpected large chunks and can't add to context, will add 1 anyways. Tokens %s -> %s" % (
+ tokens[0], new_tokens0), flush=True)
+ return top_k_docs, one_doc_size, num_doc_tokens
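+
+
+# Minimal sketch (made-up token counts) of the cumulative-sum selection get_docs_tokens() performs:
+def _example_docs_selection():
+    import numpy as np
+    tokens = [120, 300, 500, 900]                   # per-document token counts
+    tokens_cumsum = np.cumsum(tokens)               # [120, 420, 920, 1820]
+    max_input_tokens = 1000
+    where_res = np.where(tokens_cumsum < max_input_tokens)[0]
+    top_k_docs = 1 + where_res[-1]                  # 3 docs fit under the budget
+    num_doc_tokens = tokens_cumsum[top_k_docs - 1]  # 920 tokens consumed
+    assert (top_k_docs, num_doc_tokens) == (3, 920)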
+
+
+def entrypoint_main():
+ """
+ Examples:
+
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
+ python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
+
+ # generate without lora weights, no prompt
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
+
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
+ # OpenChatKit settings:
+    python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
+
+ python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
+ python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
+ python generate.py --base_model='philschmid/bart-large-cnn-samsum'
+ python generate.py --base_model='philschmid/flan-t5-base-samsum'
+ python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
+
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
+
+ must have 4*48GB GPU and run without 8bit in order for sharding to work with use_gpu_id=False
+ can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
+ python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --use_gpu_id=False --prompt_type='human_bot'
+
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
+ """
+ H2O_Fire(main)
+
+
+if __name__ == "__main__":
+ entrypoint_main()
diff --git a/src/gpt4all_llm.py b/src/gpt4all_llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f973d42a7775d7f3e5a9c27e725429ca6d607e1
--- /dev/null
+++ b/src/gpt4all_llm.py
@@ -0,0 +1,403 @@
+import inspect
+import os
+from typing import Dict, Any, Optional, List, Iterator
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.schema.output import GenerationChunk
+from pydantic import root_validator
+from langchain.llms import gpt4all
+
+from utils import FakeTokenizer, get_ngpus_vis, url_alive, download_simple
+
+
+def get_model_tokenizer_gpt4all(base_model, n_jobs=None, max_seq_len=None, llamacpp_dict=None):
+ assert llamacpp_dict is not None
+ # defaults (some of these are generation parameters, so need to be passed in at generation time)
+ model_name = base_model.lower()
+ model = get_llm_gpt4all(model_name, model=None,
+ # max_new_tokens=max_new_tokens,
+ # temperature=temperature,
+ # repetition_penalty=repetition_penalty,
+ # top_k=top_k,
+ # top_p=top_p,
+ # callbacks=callbacks,
+ n_jobs=n_jobs,
+ # verbose=verbose,
+ # streaming=stream_output,
+ # prompter=prompter,
+ # context=context,
+ # iinput=iinput,
+ inner_class=True,
+ max_seq_len=max_seq_len,
+ llamacpp_dict=llamacpp_dict,
+ )
+ return model, FakeTokenizer(), 'cpu'
+
+
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+
+class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run on new LLM token. Only available when streaming is enabled."""
+ # streaming to std already occurs without this
+ # sys.stdout.write(token)
+ # sys.stdout.flush()
+ pass
+
+
+def get_model_kwargs(llamacpp_dict, default_kwargs, cls, exclude_list=[]):
+ # default from class
+ model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list}
+ # from our defaults
+ model_kwargs.update(default_kwargs)
+ # from user defaults
+ model_kwargs.update(llamacpp_dict)
+ # ensure only valid keys
+ func_names = list(inspect.signature(cls).parameters)
+ model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
+ # make int or float if can to satisfy types for class
+ for k, v in model_kwargs.items():
+ try:
+ if float(v) == int(v):
+ model_kwargs[k] = int(v)
+ else:
+ model_kwargs[k] = float(v)
+ except:
+ pass
+ return model_kwargs
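+
+
+# Small sketch (toy class, made-up values) of the precedence get_model_kwargs() applies:
+# class signature defaults < our defaults < user llamacpp_dict, restricted to valid signature keys.
+def _example_kwargs_precedence():
+    import inspect
+
+    class Toy:
+        def __init__(self, n_ctx=512, temp=0.8, unused=None):
+            self.n_ctx, self.temp = n_ctx, temp
+
+    merged = {k: v.default for k, v in inspect.signature(Toy).parameters.items()}
+    merged.update(dict(n_ctx=2048))         # our defaults
+    merged.update(dict(temp=0.7, bogus=1))  # user dict may contain unknown keys
+    merged = {k: v for k, v in merged.items() if k in inspect.signature(Toy).parameters}
+    assert merged['n_ctx'] == 2048 and merged['temp'] == 0.7 and 'bogus' not in merged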
+
+
+def get_gpt4all_default_kwargs(max_new_tokens=256,
+ temperature=0.1,
+ repetition_penalty=1.0,
+ top_k=40,
+ top_p=0.7,
+ n_jobs=None,
+ verbose=False,
+ max_seq_len=None,
+ ):
+ if n_jobs in [None, -1]:
+ n_jobs = int(os.getenv('OMP_NUM_THREADS', str(os.cpu_count()//2)))
+ n_jobs = max(1, min(20, n_jobs)) # hurts beyond some point
+ n_gpus = get_ngpus_vis()
+ default_kwargs = dict(context_erase=0.5,
+ n_batch=1,
+ max_tokens=max_seq_len - max_new_tokens,
+ n_predict=max_new_tokens,
+ repeat_last_n=64 if repetition_penalty != 1.0 else 0,
+ repeat_penalty=repetition_penalty,
+ temp=temperature,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ use_mlock=True,
+ n_ctx=max_seq_len,
+ n_threads=n_jobs,
+ verbose=verbose)
+ if n_gpus != 0:
+ default_kwargs.update(dict(n_gpu_layers=100))
+ return default_kwargs
+
+
+def get_llm_gpt4all(model_name,
+ model=None,
+ max_new_tokens=256,
+ temperature=0.1,
+ repetition_penalty=1.0,
+ top_k=40,
+ top_p=0.7,
+ streaming=False,
+ callbacks=None,
+ prompter=None,
+ context='',
+ iinput='',
+ n_jobs=None,
+ verbose=False,
+ inner_class=False,
+ max_seq_len=None,
+ llamacpp_dict=None,
+ ):
+ if not inner_class:
+ assert prompter is not None
+
+ default_kwargs = \
+ get_gpt4all_default_kwargs(max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ repetition_penalty=repetition_penalty,
+ top_k=top_k,
+ top_p=top_p,
+ n_jobs=n_jobs,
+ verbose=verbose,
+ max_seq_len=max_seq_len,
+ )
+ if model_name == 'llama':
+ cls = H2OLlamaCpp
+ if model is None:
+ llamacpp_dict = llamacpp_dict.copy()
+ model_path = llamacpp_dict.pop('model_path_llama')
+ if os.path.isfile(os.path.basename(model_path)):
+ # e.g. if offline but previously downloaded
+ model_path = os.path.basename(model_path)
+ elif url_alive(model_path):
+ # online
+ ggml_path = os.getenv('GGML_PATH')
+ dest = os.path.join(ggml_path, os.path.basename(model_path)) if ggml_path else None
+ model_path = download_simple(model_path, dest=dest)
+ else:
+ model_path = model
+ model_kwargs = get_model_kwargs(llamacpp_dict, default_kwargs, cls, exclude_list=['lc_kwargs'])
+ model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
+ prompter=prompter, context=context, iinput=iinput))
+
+ # migration to new langchain fix:
+ odd_keys = ['model_kwargs', 'grammar_path', 'grammar']
+ for key in odd_keys:
+ model_kwargs.pop(key, None)
+
+ llm = cls(**model_kwargs)
+ llm.client.verbose = verbose
+ inner_model = llm.client
+ elif model_name == 'gpt4all_llama':
+ cls = H2OGPT4All
+ if model is None:
+ llamacpp_dict = llamacpp_dict.copy()
+ model_path = llamacpp_dict.pop('model_name_gpt4all_llama')
+ if url_alive(model_path):
+ # online
+ ggml_path = os.getenv('GGML_PATH')
+ dest = os.path.join(ggml_path, os.path.basename(model_path)) if ggml_path else None
+ model_path = download_simple(model_path, dest=dest)
+ else:
+ model_path = model
+ model_kwargs = get_model_kwargs(llamacpp_dict, default_kwargs, cls, exclude_list=['lc_kwargs'])
+ model_kwargs.update(
+ dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
+ prompter=prompter, context=context, iinput=iinput))
+ llm = cls(**model_kwargs)
+ inner_model = llm.client
+ elif model_name == 'gptj':
+ cls = H2OGPT4All
+ if model is None:
+ llamacpp_dict = llamacpp_dict.copy()
+ model_path = llamacpp_dict.pop('model_name_gptj') if model is None else model
+ if url_alive(model_path):
+ ggml_path = os.getenv('GGML_PATH')
+ dest = os.path.join(ggml_path, os.path.basename(model_path)) if ggml_path else None
+ model_path = download_simple(model_path, dest=dest)
+ else:
+ model_path = model
+ model_kwargs = get_model_kwargs(llamacpp_dict, default_kwargs, cls, exclude_list=['lc_kwargs'])
+ model_kwargs.update(
+ dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
+ prompter=prompter, context=context, iinput=iinput))
+ llm = cls(**model_kwargs)
+ inner_model = llm.client
+ else:
+ raise RuntimeError("No such model_name %s" % model_name)
+ if inner_class:
+ return inner_model
+ else:
+ return llm
+
+
+class H2OGPT4All(gpt4all.GPT4All):
+    model: Any
+    """Path to the pre-trained GPT4All model file."""
+    prompter: Any
+    context: Any = ''
+    iinput: Any = ''
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that the python package exists in the environment."""
+ try:
+ if isinstance(values["model"], str):
+ from gpt4all import GPT4All as GPT4AllModel
+
+ full_path = values["model"]
+ model_path, delimiter, model_name = full_path.rpartition("/")
+ model_path += delimiter
+
+ values["client"] = GPT4AllModel(
+ model_name=model_name,
+ model_path=model_path or None,
+ model_type=values["backend"],
+ allow_download=True,
+ )
+ if values["n_threads"] is not None:
+ # set n_threads
+ values["client"].model.set_thread_count(values["n_threads"])
+ else:
+ values["client"] = values["model"]
+ if values["n_threads"] is not None:
+ # set n_threads
+ values["client"].model.set_thread_count(values["n_threads"])
+ try:
+ values["backend"] = values["client"].model_type
+ except AttributeError:
+ # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
+ values["backend"] = values["client"].model.model_type
+
+ except ImportError:
+ raise ValueError(
+ "Could not import gpt4all python package. "
+ "Please install it with `pip install gpt4all`."
+ )
+ return values
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs,
+ ) -> str:
+ # Roughly 4 chars per token if natural language
+ n_ctx = 2048
+ prompt = prompt[-self.max_tokens * 4:]
+
+ # use instruct prompting
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
+ prompt = self.prompter.generate_prompt(data_point)
+
+ verbose = False
+ if verbose:
+ print("_call prompt: %s" % prompt, flush=True)
+        # FIXME: GPT4All doesn't support yield during generate, so cannot support streaming except via itself to stdout
+ return super()._call(prompt, stop=stop, run_manager=run_manager)
+
+    # FIXME: Unsure what uses this
+ #def get_token_ids(self, text: str) -> List[int]:
+ # return self.client.tokenize(b" " + text.encode("utf-8"))
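+
+
+# Minimal sketch (made-up numbers) of the character-based truncation heuristic used by the
+# _call() methods in this file: assume roughly 4 chars per token and keep the prompt's tail.
+def _example_char_truncation():
+    max_tokens = 6
+    prompt = "x" * 100
+    truncated = prompt[-max_tokens * 4:]  # keep roughly the last max_tokens worth of text
+    assert len(truncated) == 24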
+
+
+from langchain.llms import LlamaCpp
+
+
+class H2OLlamaCpp(LlamaCpp):
+    model_path: Any
+    """Path to the pre-trained llama.cpp model file."""
+    prompter: Any
+    context: Any
+    iinput: Any
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that llama-cpp-python library is installed."""
+ if isinstance(values["model_path"], str):
+ model_path = values["model_path"]
+ model_param_names = [
+ "lora_path",
+ "lora_base",
+ "n_ctx",
+ "n_parts",
+ "seed",
+ "f16_kv",
+ "logits_all",
+ "vocab_only",
+ "use_mlock",
+ "n_threads",
+ "n_batch",
+ "use_mmap",
+ "last_n_tokens_size",
+ ]
+ model_params = {k: values[k] for k in model_param_names}
+ # For backwards compatibility, only include if non-null.
+ if values["n_gpu_layers"] is not None:
+ model_params["n_gpu_layers"] = values["n_gpu_layers"]
+
+ try:
+ try:
+ from llama_cpp import Llama
+ except ImportError:
+ from llama_cpp_cuda import Llama
+
+ values["client"] = Llama(model_path, **model_params)
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import llama-cpp-python library. "
+ "Please install the llama-cpp-python library to "
+ "use this embedding model: pip install llama-cpp-python"
+ )
+ except Exception as e:
+ raise ValueError(
+ f"Could not load Llama model from path: {model_path}. "
+ f"Received error {e}"
+ )
+ else:
+ values["client"] = values["model_path"]
+ return values
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs,
+ ) -> str:
+ verbose = False
+ # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
+ # still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
+ prompt = prompt[-self.n_ctx * 4:]
+ prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
+ num_prompt_tokens = len(prompt_tokens)
+ if num_prompt_tokens > self.n_ctx:
+ # conservative by using int()
+ chars_per_token = int(len(prompt) / num_prompt_tokens)
+ prompt = prompt[-self.n_ctx * chars_per_token:]
+ if verbose:
+                print("reducing tokens, assuming average of %s chars/token" % chars_per_token, flush=True)
+ prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
+ num_prompt_tokens2 = len(prompt_tokens2)
+ print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
+
+ # use instruct prompting
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
+ prompt = self.prompter.generate_prompt(data_point)
+
+ if verbose:
+ print("_call prompt: %s" % prompt, flush=True)
+
+ if self.streaming:
+ # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
+ text = ""
+ for token in self.stream(input=prompt, stop=stop):
+ # for token in self.stream(input=prompt, stop=stop, run_manager=run_manager):
+ text_chunk = token # ["choices"][0]["text"]
+ # self.stream already calls text_callback
+ # if text_callback:
+ # text_callback(text_chunk)
+ text += text_chunk
+ # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
+ return text[len(prompt):]
+ else:
+ params = self._get_parameters(stop)
+ params = {**params, **kwargs}
+ result = self.client(prompt=prompt, **params)
+ return result["choices"][0]["text"]
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
+ logprobs = 0
+ chunk = GenerationChunk(
+ text=prompt,
+ generation_info={"logprobs": logprobs},
+ )
+ yield chunk
+ if run_manager:
+ run_manager.on_llm_new_token(
+ token=chunk.text, verbose=self.verbose, log_probs=logprobs
+ )
+ # actual new tokens
+ for chunk in super()._stream(prompt, stop=stop, run_manager=run_manager, **kwargs):
+ yield chunk
+
+ def get_token_ids(self, text: str) -> List[int]:
+ return self.client.tokenize(b" " + text.encode("utf-8"))
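+
+
+# Illustrative usage sketch (hypothetical local GGML file name; assumes llama-cpp-python is
+# installed and the file exists in the working directory):
+def _example_get_llm_gpt4all():
+    llm = get_llm_gpt4all('llama', max_new_tokens=128, max_seq_len=2048, inner_class=True,
+                          llamacpp_dict=dict(model_path_llama='llama-2-7b-chat.ggmlv3.q4_0.bin'))
+    return llm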
diff --git a/src/gpt_langchain.py b/src/gpt_langchain.py
new file mode 100644
index 0000000000000000000000000000000000000000..144d9ec5c3783430db8c0714828028137ceac94d
--- /dev/null
+++ b/src/gpt_langchain.py
@@ -0,0 +1,5394 @@
+import ast
+import asyncio
+import copy
+import functools
+import glob
+import gzip
+import inspect
+import json
+import os
+import pathlib
+import pickle
+import shutil
+import subprocess
+import tempfile
+import time
+import traceback
+import types
+import typing
+import urllib.error
+import uuid
+import zipfile
+from collections import defaultdict
+from datetime import datetime
+from functools import reduce
+from operator import concat
+import filelock
+import tabulate
+import yaml
+
+from joblib import delayed
+from langchain.callbacks import streaming_stdout
+from langchain.embeddings import HuggingFaceInstructEmbeddings
+from langchain.llms.huggingface_pipeline import VALID_TASKS
+from langchain.llms.utils import enforce_stop_tokens
+from langchain.schema import LLMResult, Generation
+from langchain.tools import PythonREPLTool
+from langchain.tools.json.tool import JsonSpec
+from tqdm import tqdm
+
+from src.db_utils import length_db1, set_dbid, set_userid, get_dbid, get_userid_direct, get_username_direct, \
+ set_userid_direct
+from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
+ get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
+ have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \
+ get_list_or_str, have_pillow, only_selenium, only_playwright, only_unstructured_urls, get_sha, get_short_name, \
+ get_accordion, have_jq, get_doc, get_source, have_chromamigdb, get_token_count, reverse_ucurve_list
+from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
+ LangChainAction, LangChainMode, DocumentChoice, LangChainTypes, font_size, head_acc, super_source_prefix, \
+ super_source_postfix, langchain_modes_intrinsic, get_langchain_prompts, LangChainAgent
+from evaluate_params import gen_hyper, gen_hyper0
+from gen import get_model, SEED, get_limited_prompt, get_docs_tokens
+from prompter import non_hf_types, PromptType, Prompter
+from src.serpapi import H2OSerpAPIWrapper
+from utils_langchain import StreamingGradioCallbackHandler, _chunk_sources, _add_meta, add_parser, fix_json_meta
+
+import_matplotlib()
+
+import numpy as np
+import pandas as pd
+import requests
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+# , GCSDirectoryLoader, GCSFileLoader
+# , OutlookMessageLoader # GPL3
+# ImageCaptionLoader, # use our own wrapper
+# ReadTheDocsLoader, # no special file, some path, so have to give as special option
+from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
+ UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
+ EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
+ UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \
+ UnstructuredExcelLoader, JSONLoader
+from langchain.text_splitter import Language
+from langchain.chains.question_answering import load_qa_chain
+from langchain.docstore.document import Document
+from langchain import PromptTemplate, HuggingFaceTextGenInference, HuggingFacePipeline
+from langchain.vectorstores import Chroma
+from chromamig import ChromaMig
+
+
+def split_list(input_list, split_size):
+ for i in range(0, len(input_list), split_size):
+ yield input_list[i:i + split_size]
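+
+
+# Example: split_list(range(5), 2) yields [0, 1], [2, 3], [4]; used below to respect Chroma's
+# max_batch_size when adding documents.
+def _example_split_list():
+    assert [batch for batch in split_list(list(range(5)), 2)] == [[0, 1], [2, 3], [4]]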
+
+
+def get_db(sources, use_openai_embedding=False, db_type='faiss',
+ persist_directory=None, load_db_if_exists=True,
+ langchain_mode='notset',
+ langchain_mode_paths={},
+ langchain_mode_types={},
+ collection_name=None,
+ hf_embedding_model=None,
+ migrate_embedding_model=False,
+ auto_migrate_db=False,
+ n_jobs=-1):
+ if not sources:
+ return None
+ user_path = langchain_mode_paths.get(langchain_mode)
+ if persist_directory is None:
+ langchain_type = langchain_mode_types.get(langchain_mode, LangChainTypes.EITHER.value)
+ persist_directory, langchain_type = get_persist_directory(langchain_mode, langchain_type=langchain_type)
+ langchain_mode_types[langchain_mode] = langchain_type
+ assert hf_embedding_model is not None
+
+ # get freshly-determined embedding model
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
+ assert collection_name is not None or langchain_mode != 'notset'
+ if collection_name is None:
+ collection_name = langchain_mode.replace(' ', '_')
+
+ # Create vector database
+ if db_type == 'faiss':
+ from langchain.vectorstores import FAISS
+ db = FAISS.from_documents(sources, embedding)
+ elif db_type == 'weaviate':
+ import weaviate
+ from weaviate.embedded import EmbeddedOptions
+ from langchain.vectorstores import Weaviate
+
+ if os.getenv('WEAVIATE_URL', None):
+ client = _create_local_weaviate_client()
+ else:
+ client = weaviate.Client(
+ embedded_options=EmbeddedOptions(persistence_data_path=persist_directory)
+ )
+ index_name = collection_name.capitalize()
+ db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
+ index_name=index_name)
+ elif db_type in ['chroma', 'chroma_old']:
+ assert persist_directory is not None
+ # use_base already handled when making persist_directory, unless was passed into get_db()
+ makedirs(persist_directory, exist_ok=True)
+
+ # see if already actually have persistent db, and deal with possible changes in embedding
+ db, use_openai_embedding, hf_embedding_model = \
+ get_existing_db(None, persist_directory, load_db_if_exists, db_type,
+ use_openai_embedding,
+ langchain_mode, langchain_mode_paths, langchain_mode_types,
+ hf_embedding_model, migrate_embedding_model, auto_migrate_db,
+ verbose=False,
+ n_jobs=n_jobs)
+ if db is None:
+ import logging
+ logging.getLogger("chromadb").setLevel(logging.ERROR)
+ if db_type == 'chroma':
+ from chromadb.config import Settings
+ settings_extra_kwargs = dict(is_persistent=True)
+ else:
+ from chromamigdb.config import Settings
+ settings_extra_kwargs = dict(chroma_db_impl="duckdb+parquet")
+ client_settings = Settings(anonymized_telemetry=False,
+ persist_directory=persist_directory,
+ **settings_extra_kwargs)
+ if n_jobs in [None, -1]:
+ n_jobs = int(os.getenv('OMP_NUM_THREADS', str(os.cpu_count() // 2)))
+ num_threads = max(1, min(n_jobs, 8))
+ else:
+ num_threads = max(1, n_jobs)
+ collection_metadata = {"hnsw:num_threads": num_threads}
+ from_kwargs = dict(embedding=embedding,
+ persist_directory=persist_directory,
+ collection_name=collection_name,
+ client_settings=client_settings,
+ collection_metadata=collection_metadata)
+ if db_type == 'chroma':
+ import chromadb
+ api = chromadb.PersistentClient(path=persist_directory)
+ max_batch_size = api._producer.max_batch_size
+ sources_batches = split_list(sources, max_batch_size)
+ for sources_batch in sources_batches:
+ db = Chroma.from_documents(documents=sources_batch, **from_kwargs)
+ db.persist()
+ else:
+ db = ChromaMig.from_documents(documents=sources, **from_kwargs)
+ clear_embedding(db)
+ save_embed(db, use_openai_embedding, hf_embedding_model)
+ else:
+ # then just add
+ # doesn't check or change embedding, just saves it in case not saved yet, after persisting
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
+ use_openai_embedding=use_openai_embedding,
+ hf_embedding_model=hf_embedding_model)
+ else:
+ raise RuntimeError("No such db_type=%s" % db_type)
+
+ # once here, db is not changing and embedding choices in calling functions does not matter
+ return db
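+
+
+# Illustrative usage sketch (hypothetical helper; assumes faiss is installed and uses the 'fake'
+# embedding model defined later in this file to keep it cheap):
+def _example_get_db():
+    from langchain.docstore.document import Document
+    sources = [Document(page_content="hello world", metadata=dict(source="inline"))]
+    return get_db(sources, db_type='faiss', persist_directory='example_db',
+                  collection_name='Example', hf_embedding_model='fake')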
+
+
+def _get_unique_sources_in_weaviate(db):
+ batch_size = 100
+ id_source_list = []
+ result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)
+
+ while result['objects']:
+ id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
+ last_id = id_source_list[-1][0]
+ result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)
+
+ unique_sources = {source for _, source in id_source_list}
+ return unique_sources
+
+
+def del_from_db(db, sources, db_type=None):
+ if db_type in ['chroma', 'chroma_old'] and db is not None:
+ # sources should be list of x.metadata['source'] from document metadatas
+ if isinstance(sources, str):
+ sources = [sources]
+ else:
+ assert isinstance(sources, (list, tuple, types.GeneratorType))
+ metadatas = set(sources)
+ client_collection = db._client.get_collection(name=db._collection.name,
+ embedding_function=db._collection._embedding_function)
+ for source in metadatas:
+ meta = dict(source=source)
+ try:
+ client_collection.delete(where=meta)
+ except KeyError:
+ pass
+
+
+def add_to_db(db, sources, db_type='faiss',
+ avoid_dup_by_file=False,
+ avoid_dup_by_content=True,
+ use_openai_embedding=False,
+ hf_embedding_model=None):
+ assert hf_embedding_model is not None
+ num_new_sources = len(sources)
+ if not sources:
+ return db, num_new_sources, []
+ if db_type == 'faiss':
+ db.add_documents(sources)
+ elif db_type == 'weaviate':
+ # FIXME: only control by file name, not hash yet
+ if avoid_dup_by_file or avoid_dup_by_content:
+ unique_sources = _get_unique_sources_in_weaviate(db)
+ sources = [x for x in sources if x.metadata['source'] not in unique_sources]
+ num_new_sources = len(sources)
+ if num_new_sources == 0:
+ return db, num_new_sources, []
+ db.add_documents(documents=sources)
+ elif db_type in ['chroma', 'chroma_old']:
+ collection = get_documents(db)
+ # files we already have:
+ metadata_files = set([x['source'] for x in collection['metadatas']])
+ if avoid_dup_by_file:
+ # Too weak in case file changed content, assume parent shouldn't pass true for this for now
+ raise RuntimeError("Not desired code path")
+ if avoid_dup_by_content:
+ # look at hash, instead of page_content
+ # migration: If no hash previously, avoid updating,
+ # since don't know if need to update and may be expensive to redo all unhashed files
+ metadata_hash_ids = set(
+ [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
+ # avoid sources with same hash
+ sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
+ num_nohash = len([x for x in sources if not x.metadata.get('hashid')])
+ print("Found %s new sources (%d have no hash in original source,"
+ " so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True)
+        # get new file names that match existing file names; delete existing files we are overriding
+ dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
+ print("Removing %s duplicate files from db because ingesting those as new documents" % len(
+ dup_metadata_files), flush=True)
+ client_collection = db._client.get_collection(name=db._collection.name,
+ embedding_function=db._collection._embedding_function)
+ for dup_file in dup_metadata_files:
+ dup_file_meta = dict(source=dup_file)
+ try:
+ client_collection.delete(where=dup_file_meta)
+ except KeyError:
+ pass
+ num_new_sources = len(sources)
+ if num_new_sources == 0:
+ return db, num_new_sources, []
+ if hasattr(db, '_persist_directory'):
+ print("Existing db, adding to %s" % db._persist_directory, flush=True)
+ # chroma only
+ lock_file = get_db_lock_file(db)
+ context = filelock.FileLock
+ else:
+ lock_file = None
+ context = NullContext
+ with context(lock_file):
+            # this is the place where we add to the db, but others may be accessing the db, so lock access.
+ # else see RuntimeError: Index seems to be corrupted or unsupported
+ import chromadb
+ api = chromadb.PersistentClient(path=db._persist_directory)
+ max_batch_size = api._producer.max_batch_size
+ sources_batches = split_list(sources, max_batch_size)
+ for sources_batch in sources_batches:
+ db.add_documents(documents=sources_batch)
+ db.persist()
+ clear_embedding(db)
+            # saving here is for migration, in case of an old db directory without the embedding saved
+ save_embed(db, use_openai_embedding, hf_embedding_model)
+ else:
+ raise RuntimeError("No such db_type=%s" % db_type)
+
+ new_sources_metadata = [x.metadata for x in sources]
+
+ return db, num_new_sources, new_sources_metadata
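+
+
+# Minimal sketch (toy metadata, no db) of the hash-based dedup filter add_to_db() applies:
+def _example_hash_dedup():
+    existing = [dict(source='a.txt', hashid='h1'), dict(source='b.txt', hashid=None)]
+    new = [dict(source='a.txt', hashid='h1'), dict(source='c.txt', hashid='h3')]
+    metadata_hash_ids = set(x['hashid'] for x in existing
+                            if 'hashid' in x and x['hashid'] not in ["None", None])
+    kept = [x for x in new if x.get('hashid') not in metadata_hash_ids]
+    assert kept == [dict(source='c.txt', hashid='h3')]  # duplicate of a.txt is skipped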
+
+
+def create_or_update_db(db_type, persist_directory, collection_name,
+ user_path, langchain_type,
+ sources, use_openai_embedding, add_if_exists, verbose,
+ hf_embedding_model, migrate_embedding_model, auto_migrate_db,
+ n_jobs=-1):
+ if not os.path.isdir(persist_directory) or not add_if_exists:
+ if os.path.isdir(persist_directory):
+ if verbose:
+ print("Removing %s" % persist_directory, flush=True)
+ remove(persist_directory)
+ if verbose:
+ print("Generating db", flush=True)
+ if db_type == 'weaviate':
+ import weaviate
+ from weaviate.embedded import EmbeddedOptions
+
+ if os.getenv('WEAVIATE_URL', None):
+ client = _create_local_weaviate_client()
+ else:
+ client = weaviate.Client(
+ embedded_options=EmbeddedOptions(persistence_data_path=persist_directory)
+ )
+
+ index_name = collection_name.replace(' ', '_').capitalize()
+ if client.schema.exists(index_name) and not add_if_exists:
+ client.schema.delete_class(index_name)
+ if verbose:
+ print("Removing %s" % index_name, flush=True)
+ elif db_type in ['chroma', 'chroma_old']:
+ pass
+
+ if not add_if_exists:
+ if verbose:
+ print("Generating db", flush=True)
+ else:
+ if verbose:
+ print("Loading and updating db", flush=True)
+
+ db = get_db(sources,
+ use_openai_embedding=use_openai_embedding,
+ db_type=db_type,
+ persist_directory=persist_directory,
+ langchain_mode=collection_name,
+ langchain_mode_paths={collection_name: user_path},
+ langchain_mode_types={collection_name: langchain_type},
+ hf_embedding_model=hf_embedding_model,
+ migrate_embedding_model=migrate_embedding_model,
+ auto_migrate_db=auto_migrate_db,
+ n_jobs=n_jobs)
+
+ return db
+
+
+from langchain.embeddings import FakeEmbeddings
+
+
+class H2OFakeEmbeddings(FakeEmbeddings):
+ """Fake embedding model, but constant instead of random"""
+
+ size: int
+ """The size of the embedding vector."""
+
+ def _get_embedding(self) -> typing.List[float]:
+ return [1] * self.size
+
+ def embed_documents(self, texts: typing.List[str]) -> typing.List[typing.List[float]]:
+ return [self._get_embedding() for _ in texts]
+
+ def embed_query(self, text: str) -> typing.List[float]:
+ return self._get_embedding()
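+
+
+# Example: the fake embedder returns the same constant vector for any input, which is enough to
+# exercise vector-db plumbing without loading a real embedding model.
+def _example_fake_embeddings():
+    emb = H2OFakeEmbeddings(size=3)
+    assert emb.embed_query("anything") == [1, 1, 1]
+    assert emb.embed_documents(["a", "b"]) == [[1, 1, 1], [1, 1, 1]]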
+
+
+def get_embedding(use_openai_embedding, hf_embedding_model=None, preload=False):
+ assert hf_embedding_model is not None
+ # Get embedding model
+ if use_openai_embedding:
+ assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
+ from langchain.embeddings import OpenAIEmbeddings
+ embedding = OpenAIEmbeddings(disallowed_special=())
+ elif hf_embedding_model == 'fake':
+ embedding = H2OFakeEmbeddings(size=1)
+ else:
+ if isinstance(hf_embedding_model, str):
+ pass
+ elif isinstance(hf_embedding_model, dict):
+ # embedding itself preloaded globally
+ return hf_embedding_model['model']
+ else:
+ # object
+ return hf_embedding_model
+ # to ensure can fork without deadlock
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+ device, torch_dtype, context_class = get_device_dtype()
+ model_kwargs = dict(device=device)
+ if 'instructor' in hf_embedding_model:
+ encode_kwargs = {'normalize_embeddings': True}
+ embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model,
+ model_kwargs=model_kwargs,
+ encode_kwargs=encode_kwargs)
+ else:
+ embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
+ embedding.client.preload = preload
+ return embedding
+
+
+def get_answer_from_sources(chain, sources, question):
+ return chain(
+ {
+ "input_documents": sources,
+ "question": question,
+ },
+ return_only_outputs=True,
+ )["output_text"]
+
+
+"""Wrapper around Huggingface text generation inference API."""
+from functools import partial
+from typing import Any, Dict, List, Optional, Set, Iterable
+
+from pydantic import Extra, Field, root_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
+from langchain.llms.base import LLM
+
+
+class GradioInference(LLM):
+ """
+ Gradio generation inference API.
+ """
+ inference_server_url: str = ""
+
+ temperature: float = 0.8
+ top_p: Optional[float] = 0.95
+ top_k: Optional[int] = None
+ num_beams: Optional[int] = 1
+ max_new_tokens: int = 512
+ min_new_tokens: int = 1
+ early_stopping: bool = False
+ max_time: int = 180
+ repetition_penalty: Optional[float] = None
+ num_return_sequences: Optional[int] = 1
+ do_sample: bool = False
+ chat_client: bool = False
+
+ return_full_text: bool = False
+ stream_output: bool = False
+ sanitize_bot_response: bool = False
+
+ prompter: Any = None
+ context: Any = ''
+ iinput: Any = ''
+ client: Any = None
+ tokenizer: Any = None
+
+ system_prompt: Any = None
+ visible_models: Any = None
+ h2ogpt_key: Any = None
+
+ count_input_tokens: Any = 0
+ count_output_tokens: Any = 0
+
+ min_max_new_tokens: Any = 256
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+
+ try:
+ if values['client'] is None:
+ import gradio_client
+ values["client"] = gradio_client.Client(
+ values["inference_server_url"]
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import gradio_client python package. "
+ "Please install it with `pip install gradio_client`."
+ )
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "gradio_inference"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection,
+ # so server should get prompt_type or '', not plain
+ # This is good, so gradio server can also handle stopping.py conditions
+ # this is different than TGI server that uses prompter to inject prompt_type prompting
+ stream_output = self.stream_output
+ gr_client = self.client
+ client_langchain_mode = 'Disabled'
+ client_add_chat_history_to_context = True
+ client_add_search_to_context = False
+ client_chat_conversation = []
+ client_langchain_action = LangChainAction.QUERY.value
+ client_langchain_agents = []
+ top_k_docs = 1
+ chunk = True
+ chunk_size = 512
+ client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
+ iinput=self.iinput if self.chat_client else '', # only for chat=True
+ context=self.context,
+ # streaming output is supported, loops over and outputs each generation in streaming mode
+ # but leave stream_output=False for simple input/output mode
+ stream_output=stream_output,
+ prompt_type=self.prompter.prompt_type,
+ prompt_dict='',
+
+ temperature=self.temperature,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ num_beams=self.num_beams,
+ max_new_tokens=self.max_new_tokens,
+ min_new_tokens=self.min_new_tokens,
+ early_stopping=self.early_stopping,
+ max_time=self.max_time,
+ repetition_penalty=self.repetition_penalty,
+ num_return_sequences=self.num_return_sequences,
+ do_sample=self.do_sample,
+ chat=self.chat_client,
+
+ instruction_nochat=prompt if not self.chat_client else '',
+ iinput_nochat=self.iinput if not self.chat_client else '',
+ langchain_mode=client_langchain_mode,
+ add_chat_history_to_context=client_add_chat_history_to_context,
+ langchain_action=client_langchain_action,
+ langchain_agents=client_langchain_agents,
+ top_k_docs=top_k_docs,
+ chunk=chunk,
+ chunk_size=chunk_size,
+ document_subset=DocumentSubset.Relevant.name,
+ document_choice=[DocumentChoice.ALL.value],
+ pre_prompt_query=None,
+ prompt_query=None,
+ pre_prompt_summary=None,
+ prompt_summary=None,
+ system_prompt=self.system_prompt,
+ image_loaders=None, # don't need to further do doc specific things
+ pdf_loaders=None, # don't need to further do doc specific things
+ url_loaders=None, # don't need to further do doc specific things
+ jq_schema=None, # don't need to further do doc specific things
+ visible_models=self.visible_models,
+ h2ogpt_key=self.h2ogpt_key,
+ add_search_to_context=client_add_search_to_context,
+ chat_conversation=client_chat_conversation,
+ text_context_list=None,
+ docs_ordering_type=None,
+ min_max_new_tokens=self.min_max_new_tokens,
+ )
+ api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
+ self.count_input_tokens += self.get_num_tokens(prompt)
+
+ if not stream_output:
+ res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
+ res_dict = ast.literal_eval(res)
+ text = res_dict['response']
+ ret = self.prompter.get_response(prompt + text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ self.count_output_tokens += self.get_num_tokens(ret)
+ return ret
+ else:
+ text_callback = None
+ if run_manager:
+ text_callback = partial(
+ run_manager.on_llm_new_token, verbose=self.verbose
+ )
+
+ job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
+ text0 = ''
+ while not job.done():
+ if job.communicator.job.latest_status.code.name == 'FINISHED':
+ break
+ e = job.future._exception
+ if e is not None:
+ break
+ outputs_list = job.communicator.job.outputs
+ if outputs_list:
+ res = job.communicator.job.outputs[-1]
+ res_dict = ast.literal_eval(res)
+ text = res_dict['response']
+ text = self.prompter.get_response(prompt + text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ # FIXME: derive chunk from full for now
+ text_chunk = text[len(text0):]
+ if not text_chunk:
+ continue
+ # save old
+ text0 = text
+
+ if text_callback:
+ text_callback(text_chunk)
+
+ time.sleep(0.01)
+
+ # ensure get last output to avoid race
+ res_all = job.outputs()
+ if len(res_all) > 0:
+ res = res_all[-1]
+ res_dict = ast.literal_eval(res)
+ text = res_dict['response']
+ # FIXME: derive chunk from full for now
+ else:
+ # go with old if failure
+ text = text0
+ text_chunk = text[len(text0):]
+ if text_callback:
+ text_callback(text_chunk)
+ ret = self.prompter.get_response(prompt + text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ self.count_output_tokens += self.get_num_tokens(ret)
+ return ret
+
+ def get_token_ids(self, text: str) -> List[int]:
+ return self.tokenizer.encode(text)
+ # avoid base method that is not aware of how to properly tokenize (uses GPT2)
+ # return _get_token_ids_default_method(text)
+
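+
+# Minimal sketch (hypothetical server URL) of the non-streaming path above: send the stringified
+# kwargs dict to the stable '/submit_nochat_api' endpoint and parse the reply.
+def _example_gradio_client_call():
+    import gradio_client
+    client = gradio_client.Client("http://localhost:7860")  # assumed running h2oGPT gradio server
+    res = client.predict(str(dict(instruction_nochat="Who are you?", prompt_type='plain')),
+                         api_name='/submit_nochat_api')
+    return ast.literal_eval(res)['response']
+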
+
+class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
+ max_new_tokens: int = 512
+ do_sample: bool = False
+ top_k: Optional[int] = None
+ top_p: Optional[float] = 0.95
+ typical_p: Optional[float] = 0.95
+ temperature: float = 0.8
+ repetition_penalty: Optional[float] = None
+ return_full_text: bool = False
+ stop_sequences: List[str] = Field(default_factory=list)
+ seed: Optional[int] = None
+ inference_server_url: str = ""
+ timeout: int = 300
+ headers: dict = None
+ stream_output: bool = False
+ sanitize_bot_response: bool = False
+ prompter: Any = None
+ context: Any = ''
+ iinput: Any = ''
+ tokenizer: Any = None
+ async_sem: Any = None
+ count_input_tokens: Any = 0
+ count_output_tokens: Any = 0
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ if stop is None:
+ stop = self.stop_sequences.copy()
+ else:
+ stop += self.stop_sequences.copy()
+ stop_tmp = stop.copy()
+ stop = []
+ [stop.append(x) for x in stop_tmp if x not in stop]
+
+ # HF inference server needs control over input tokens
+ assert self.tokenizer is not None
+ from h2oai_pipeline import H2OTextGenerationPipeline
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
+
+ # NOTE: TGI server does not add prompting, so must do here
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
+ prompt = self.prompter.generate_prompt(data_point)
+ self.count_input_tokens += self.get_num_tokens(prompt)
+
+ gen_server_kwargs = dict(do_sample=self.do_sample,
+ stop_sequences=stop,
+ max_new_tokens=self.max_new_tokens,
+ top_k=self.top_k,
+ top_p=self.top_p,
+ typical_p=self.typical_p,
+ temperature=self.temperature,
+ repetition_penalty=self.repetition_penalty,
+ return_full_text=self.return_full_text,
+ seed=self.seed,
+ )
+ gen_server_kwargs.update(kwargs)
+
+ # lower bound because client is re-used if multi-threading
+ self.client.timeout = max(300, self.timeout)
+
+ if not self.stream_output:
+ res = self.client.generate(
+ prompt,
+ **gen_server_kwargs,
+ )
+ if self.return_full_text:
+ gen_text = res.generated_text[len(prompt):]
+ else:
+ gen_text = res.generated_text
+ # remove stop sequences from the end of the generated text
+ for stop_seq in stop:
+ if stop_seq in gen_text:
+ gen_text = gen_text[:gen_text.index(stop_seq)]
+ text = prompt + gen_text
+ text = self.prompter.get_response(text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ else:
+ text_callback = None
+ if run_manager:
+ text_callback = partial(
+ run_manager.on_llm_new_token, verbose=self.verbose
+ )
+ # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
+ if text_callback:
+ text_callback(prompt)
+ text = ""
+ # Note: Streaming ignores return_full_text=True
+ for response in self.client.generate_stream(prompt, **gen_server_kwargs):
+ text_chunk = response.token.text
+ text += text_chunk
+ text = self.prompter.get_response(prompt + text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ # stream part
+ is_stop = False
+ for stop_seq in stop:
+ if stop_seq in text_chunk:
+ is_stop = True
+ break
+ if is_stop:
+ break
+ if not response.token.special:
+ if text_callback:
+ text_callback(text_chunk)
+ self.count_output_tokens += self.get_num_tokens(text)
+ return text
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ # print("acall", flush=True)
+ if stop is None:
+ stop = self.stop_sequences.copy()
+ else:
+ stop += self.stop_sequences.copy()
+ stop_tmp = stop.copy()
+ stop = []
+ [stop.append(x) for x in stop_tmp if x not in stop]
+
+ # HF inference server needs control over input tokens
+ assert self.tokenizer is not None
+ from h2oai_pipeline import H2OTextGenerationPipeline
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
+
+ # NOTE: TGI server does not add prompting, so must do here
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
+ prompt = self.prompter.generate_prompt(data_point)
+
+ gen_text = await super()._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
+
+ # remove stop sequences from the end of the generated text
+ for stop_seq in stop:
+ if stop_seq in gen_text:
+ gen_text = gen_text[:gen_text.index(stop_seq)]
+ text = prompt + gen_text
+ text = self.prompter.get_response(text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ # print("acall done", flush=True)
+ return text
+
+ async def _agenerate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ """Run the LLM on the given prompt and input."""
+ generations = []
+ new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
+ self.count_input_tokens += sum([self.get_num_tokens(prompt) for prompt in prompts])
+ tasks = [
+ asyncio.ensure_future(self._agenerate_one(prompt, stop=stop, run_manager=run_manager,
+ new_arg_supported=new_arg_supported, **kwargs))
+ for prompt in prompts
+ ]
+ texts = await asyncio.gather(*tasks)
+ self.count_output_tokens += sum([self.get_num_tokens(text) for text in texts])
+ [generations.append([Generation(text=text)]) for text in texts]
+ return LLMResult(generations=generations)
+
+ async def _agenerate_one(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ new_arg_supported=None,
+ **kwargs: Any,
+ ) -> str:
+        async with self.async_sem:  # semaphore limits number of simultaneous generation calls
+ return await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) \
+ if new_arg_supported else \
+ await self._acall(prompt, stop=stop, **kwargs)
+
+ def get_token_ids(self, text: str) -> List[int]:
+ return self.tokenizer.encode(text)
+ # avoid base method that is not aware of how to properly tokenize (uses GPT2)
+ # return _get_token_ids_default_method(text)
+
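+
+# Small sketch of the order-preserving stop-sequence dedup used by the _call()/_acall() methods above:
+def _example_stop_dedup():
+    stop_tmp = ['</s>', 'Human:', '</s>']
+    stop = []
+    [stop.append(x) for x in stop_tmp if x not in stop]
+    assert stop == ['</s>', 'Human:']
+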
+
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
+from langchain.llms import OpenAI, AzureOpenAI, Replicate
+from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
+ update_token_usage
+
+
+class H2OOpenAI(OpenAI):
+ """
+    Class to handle vLLM's use of the OpenAI API; vllm_chat is not supported, so it is only needed here.
+    Handles prompting that OpenAI itself doesn't need, as well as stop sequences.
+ """
+ stop_sequences: Any = None
+ sanitize_bot_response: bool = False
+ prompter: Any = None
+ context: Any = ''
+ iinput: Any = ''
+ tokenizer: Any = None
+
+ @classmethod
+ def _all_required_field_names(cls) -> Set:
+ _all_required_field_names = super(OpenAI, cls)._all_required_field_names()
+ _all_required_field_names.update(
+ {'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter',
+ 'tokenizer', 'logit_bias'})
+ return _all_required_field_names
+
+ def _generate(
+ self,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ stop_tmp = self.stop_sequences if not stop else self.stop_sequences + stop
+ stop = []
+ [stop.append(x) for x in stop_tmp if x not in stop]
+
+ # HF inference server needs control over input tokens
+ assert self.tokenizer is not None
+ from h2oai_pipeline import H2OTextGenerationPipeline
+ for prompti, prompt in enumerate(prompts):
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
+ # NOTE: OpenAI/vLLM server does not add prompting, so must do here
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
+ prompt = self.prompter.generate_prompt(data_point)
+ prompts[prompti] = prompt
+
+ params = self._invocation_params
+ params = {**params, **kwargs}
+ sub_prompts = self.get_sub_prompts(params, prompts, stop)
+ choices = []
+ token_usage: Dict[str, int] = {}
+ # Get the token usage from the response.
+ # Includes prompt, completion, and total tokens used.
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+ text = ''
+ for _prompts in sub_prompts:
+ if self.streaming:
+ text_with_prompt = ""
+ prompt = _prompts[0]
+ if len(_prompts) > 1:
+ raise ValueError("Cannot stream results with multiple prompts.")
+ params["stream"] = True
+ response = _streaming_response_template()
+ first = True
+ for stream_resp in completion_with_retry(
+ self, prompt=_prompts, **params
+ ):
+ if first:
+ stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"]
+ first = False
+ text_chunk = stream_resp["choices"][0]["text"]
+ text_with_prompt += text_chunk
+ text = self.prompter.get_response(text_with_prompt, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ if run_manager:
+ run_manager.on_llm_new_token(
+ text_chunk,
+ verbose=self.verbose,
+ logprobs=stream_resp["choices"][0]["logprobs"],
+ )
+ _update_response(response, stream_resp)
+ choices.extend(response["choices"])
+ else:
+ response = completion_with_retry(self, prompt=_prompts, **params)
+ choices.extend(response["choices"])
+ if not self.streaming:
+ # Can't update token usage if streaming
+ update_token_usage(_keys, response, token_usage)
+ if self.streaming:
+ choices[0]['text'] = text
+ return self.create_llm_result(choices, prompts, token_usage)
+
+ def get_token_ids(self, text: str) -> List[int]:
+ if self.tokenizer is not None:
+ return self.tokenizer.encode(text)
+ else:
+ # OpenAI uses tiktoken
+ return super().get_token_ids(text)
+
+
+class H2OReplicate(Replicate):
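+ """Replicate wrapper that limits prompt tokens via our tokenizer and merges configured stop sequences; Replicate itself applies the model's prompt template."""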
+ stop_sequences: Any = None
+ sanitize_bot_response: bool = False
+ prompter: Any = None
+ context: Any = ''
+ iinput: Any = ''
+ tokenizer: Any = None
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call to replicate endpoint."""
+ stop_tmp = self.stop_sequences if not stop else self.stop_sequences + stop
+ stop = []
+ [stop.append(x) for x in stop_tmp if x not in stop]
+
+ # HF inference server needs control over input tokens
+ assert self.tokenizer is not None
+ from h2oai_pipeline import H2OTextGenerationPipeline
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
+ # Note Replicate handles the prompting of the specific model
+ return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
+
+ def get_token_ids(self, text: str) -> List[int]:
+ return self.tokenizer.encode(text)
+ # avoid base method that is not aware of how to properly tokenize (uses GPT2)
+ # return _get_token_ids_default_method(text)
+
+
+class H2OChatOpenAI(ChatOpenAI):
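+ """ChatOpenAI that adds top_p, frequency_penalty, presence_penalty, and logit_bias to the required field names."""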
+ @classmethod
+ def _all_required_field_names(cls) -> Set:
+ _all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names()
+ _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
+ return _all_required_field_names
+
+
+class H2OAzureChatOpenAI(AzureChatOpenAI):
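+ """AzureChatOpenAI with the same extra required field names as H2OChatOpenAI."""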
+ @classmethod
+ def _all_required_field_names(cls) -> Set:
+ _all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names()
+ _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
+ return _all_required_field_names
+
+
+class H2OAzureOpenAI(AzureOpenAI):
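+ """AzureOpenAI with the same extra required field names as H2OChatOpenAI."""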
+ @classmethod
+ def _all_required_field_names(cls) -> Set:
+ _all_required_field_names = super(AzureOpenAI, cls)._all_required_field_names()
+ _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
+ return _all_required_field_names
+
+
+class H2OHuggingFacePipeline(HuggingFacePipeline):
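+ """HuggingFacePipeline that passes stop sequences through to the pipeline call and enforces them on the generated text."""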
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ response = self.pipeline(prompt, stop=stop)
+ if self.pipeline.task == "text-generation":
+ # Text generation return includes the starter text.
+ text = response[0]["generated_text"][len(prompt):]
+ elif self.pipeline.task == "text2text-generation":
+ text = response[0]["generated_text"]
+ elif self.pipeline.task == "summarization":
+ text = response[0]["summary_text"]
+ else:
+ raise ValueError(
+ f"Got invalid task {self.pipeline.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ if stop:
+ # This is a bit hacky, but I can't figure out a better way to enforce
+ # stop tokens when making calls to huggingface_hub.
+ text = enforce_stop_tokens(text, stop)
+ return text
+
+
+def get_llm(use_openai_model=False,
+ model_name=None,
+ model=None,
+ tokenizer=None,
+ inference_server=None,
+ langchain_only_model=None,
+ stream_output=False,
+ async_output=True,
+ num_async=3,
+ do_sample=False,
+ temperature=0.1,
+ top_k=40,
+ top_p=0.7,
+ num_beams=1,
+ max_new_tokens=512,
+ min_new_tokens=1,
+ early_stopping=False,
+ max_time=180,
+ repetition_penalty=1.0,
+ num_return_sequences=1,
+ prompt_type=None,
+ prompt_dict=None,
+ prompter=None,
+ context=None,
+ iinput=None,
+ sanitize_bot_response=False,
+ system_prompt='',
+ visible_models=0,
+ h2ogpt_key=None,
+ min_max_new_tokens=None,
+ n_jobs=None,
+ cli=False,
+ llamacpp_dict=None,
+ verbose=False,
+ ):
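+ """
+ Construct the LangChain LLM object (and optional streamer) for the selected inference backend.
+ Returns (llm, model_name, streamer, prompt_type, async_output, only_new_text).
+ """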
+ # currently all but h2oai_pipeline case return prompt + new text, but could change
+ only_new_text = False
+
+ if n_jobs in [None, -1]:
+ n_jobs = int(os.getenv('OMP_NUM_THREADS', str(os.cpu_count() // 2)))
+ if inference_server is None:
+ inference_server = ''
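+ # dispatch on backend: Replicate, OpenAI/vLLM, SageMaker, raw HTTP (Gradio or HF TGI server),
+ # non-HF local models (gpt4all/llama.cpp), exllama, else a local HF transformers pipeline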
+ if inference_server.startswith('replicate'):
+ model_string = ':'.join(inference_server.split(':')[1:])
+ if 'meta/llama' in model_string:
+ temperature = max(0.01, temperature if do_sample else 0)
+ else:
+ temperature = temperature if do_sample else 0
+ gen_kwargs = dict(temperature=temperature,
+ seed=1234,
+ max_length=max_new_tokens, # langchain
+ max_new_tokens=max_new_tokens, # replicate docs
+ top_p=top_p if do_sample else 1,
+ top_k=top_k, # not always supported
+ repetition_penalty=repetition_penalty)
+ if system_prompt in [None, 'None', 'auto']:
+ if prompter.system_prompt:
+ system_prompt = prompter.system_prompt
+ else:
+ system_prompt = ''
+ if system_prompt:
+ gen_kwargs.update(dict(system_prompt=system_prompt))
+
+ # replicate handles prompting, so avoid get_response() filter
+ prompter.prompt_type = 'plain'
+ if stream_output:
+ callbacks = [StreamingGradioCallbackHandler()]
+ streamer = callbacks[0] if stream_output else None
+ llm = H2OReplicate(
+ streaming=True,
+ callbacks=callbacks,
+ model=model_string,
+ input=gen_kwargs,
+ stop=prompter.stop_sequences,
+ stop_sequences=prompter.stop_sequences,
+ sanitize_bot_response=sanitize_bot_response,
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ tokenizer=tokenizer,
+ )
+ else:
+ streamer = None
+ llm = H2OReplicate(
+ model=model_string,
+ input=gen_kwargs,
+ stop=prompter.stop_sequences,
+ stop_sequences=prompter.stop_sequences,
+ sanitize_bot_response=sanitize_bot_response,
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ tokenizer=tokenizer,
+ )
+ elif use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
+ if use_openai_model and model_name is None:
+ model_name = "gpt-3.5-turbo"
+ # FIXME: Will later import be ignored? I think so, so should be fine
+ openai, inf_type, deployment_name, base_url, api_version = set_openai(inference_server)
+ kwargs_extra = {}
+ if inf_type == 'openai_chat' or inf_type == 'vllm_chat':
+ cls = H2OChatOpenAI
+ # FIXME: Support context, iinput
+ # if inf_type == 'vllm_chat':
+ # kwargs_extra.update(dict(tokenizer=tokenizer))
+ openai_api_key = openai.api_key
+ elif inf_type == 'openai_azure_chat':
+ cls = H2OAzureChatOpenAI
+ kwargs_extra.update(dict(openai_api_type='azure'))
+ # FIXME: Support context, iinput
+ if os.getenv('OPENAI_AZURE_KEY') is not None:
+ openai_api_key = os.getenv('OPENAI_AZURE_KEY')
+ else:
+ openai_api_key = openai.api_key
+ elif inf_type == 'openai_azure':
+ cls = H2OAzureOpenAI
+ kwargs_extra.update(dict(openai_api_type='azure'))
+ # FIXME: Support context, iinput
+ if os.getenv('OPENAI_AZURE_KEY') is not None:
+ openai_api_key = os.getenv('OPENAI_AZURE_KEY')
+ else:
+ openai_api_key = openai.api_key
+ else:
+ cls = H2OOpenAI
+ if inf_type == 'vllm':
+ kwargs_extra.update(dict(stop_sequences=prompter.stop_sequences,
+ sanitize_bot_response=sanitize_bot_response,
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ tokenizer=tokenizer,
+ openai_api_base=openai.api_base,
+ client=None))
+ else:
+ assert inf_type == 'openai' or use_openai_model
+ openai_api_key = openai.api_key
+
+ if deployment_name:
+ kwargs_extra.update(dict(deployment_name=deployment_name))
+ if api_version:
+ kwargs_extra.update(dict(openai_api_version=api_version))
+ elif openai.api_version:
+ kwargs_extra.update(dict(openai_api_version=openai.api_version))
+ elif inf_type in ['openai_azure', 'openai_azure_chat']:
+ kwargs_extra.update(dict(openai_api_version="2023-05-15"))
+ if base_url:
+ kwargs_extra.update(dict(openai_api_base=base_url))
+ else:
+ kwargs_extra.update(dict(openai_api_base=openai.api_base))
+
+ callbacks = [StreamingGradioCallbackHandler()]
+ llm = cls(model_name=model_name,
+ temperature=temperature if do_sample else 0,
+ # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py
+ max_tokens=max_new_tokens,
+ top_p=top_p if do_sample else 1,
+ frequency_penalty=0,
+ presence_penalty=1.07 - repetition_penalty + 0.6, # maps repetition_penalty=1.07 to presence_penalty=0.6, a reasonable default
+ callbacks=callbacks if stream_output else None,
+ openai_api_key=openai_api_key,
+ logit_bias=None if inf_type == 'vllm' else {},
+ max_retries=6,
+ streaming=stream_output,
+ **kwargs_extra
+ )
+ streamer = callbacks[0] if stream_output else None
+ if inf_type in ['openai', 'openai_chat', 'openai_azure', 'openai_azure_chat']:
+ prompt_type = inference_server
+ else:
+ # vllm goes here
+ prompt_type = prompt_type or 'plain'
+ elif inference_server and inference_server.startswith('sagemaker'):
+ callbacks = [StreamingGradioCallbackHandler()] # FIXME
+ streamer = None
+
+ endpoint_name = ':'.join(inference_server.split(':')[1:2])
+ region_name = ':'.join(inference_server.split(':')[2:])
+
+ from sagemaker import H2OSagemakerEndpoint, ChatContentHandler, BaseContentHandler
+ if inference_server.startswith('sagemaker_chat'):
+ content_handler = ChatContentHandler()
+ else:
+ content_handler = BaseContentHandler()
+ model_kwargs = dict(temperature=temperature if do_sample else 1E-10,
+ return_full_text=False, top_p=top_p, max_new_tokens=max_new_tokens)
+ llm = H2OSagemakerEndpoint(
+ endpoint_name=endpoint_name,
+ region_name=region_name,
+ aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
+ aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
+ model_kwargs=model_kwargs,
+ content_handler=content_handler,
+ endpoint_kwargs={'CustomAttributes': 'accept_eula=true'},
+ )
+ elif inference_server:
+ assert inference_server.startswith(
+ 'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server
+
+ from gradio_utils.grclient import GradioClient
+ from text_generation import Client as HFClient
+ if isinstance(model, GradioClient):
+ gr_client = model
+ hf_client = None
+ else:
+ gr_client = None
+ hf_client = model
+ assert isinstance(hf_client, HFClient)
+
+ inference_server, headers = get_hf_server(inference_server)
+
+ # quick sanity check to avoid long timeouts, just see if can reach server
+ requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
+ callbacks = [StreamingGradioCallbackHandler()]
+
+ if gr_client:
+ async_output = False # FIXME: not implemented yet
+ chat_client = False
+ llm = GradioInference(
+ inference_server_url=inference_server,
+ return_full_text=False,
+
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ repetition_penalty=repetition_penalty,
+ num_return_sequences=num_return_sequences,
+ do_sample=do_sample,
+ chat_client=chat_client,
+
+ callbacks=callbacks if stream_output else None,
+ stream_output=stream_output,
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ client=gr_client,
+ sanitize_bot_response=sanitize_bot_response,
+ tokenizer=tokenizer,
+ system_prompt=system_prompt,
+ visible_models=visible_models,
+ h2ogpt_key=h2ogpt_key,
+ min_max_new_tokens=min_max_new_tokens,
+ )
+ elif hf_client:
+ # no need to pass original client, no state and fast, so can use same validate_environment from base class
+ async_sem = asyncio.Semaphore(num_async) if async_output else NullContext()
+ llm = H2OHuggingFaceTextGenInference(
+ inference_server_url=inference_server,
+ do_sample=do_sample,
+ max_new_tokens=max_new_tokens,
+ repetition_penalty=repetition_penalty,
+ return_full_text=False, # this only controls internal behavior, still returns processed text
+ seed=SEED,
+
+ stop_sequences=prompter.stop_sequences,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ # typical_p=top_p,
+ callbacks=callbacks if stream_output else None,
+ stream_output=stream_output,
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ tokenizer=tokenizer,
+ timeout=max_time,
+ sanitize_bot_response=sanitize_bot_response,
+ async_sem=async_sem,
+ )
+ else:
+ raise RuntimeError("No defined client")
+ streamer = callbacks[0] if stream_output else None
+ elif model_name in non_hf_types:
+ async_output = False # FIXME: not implemented yet
+ assert langchain_only_model
+ if model_name == 'llama':
+ callbacks = [StreamingGradioCallbackHandler()]
+ streamer = callbacks[0] if stream_output else None
+ else:
+ # stream_output = False
+ # doesn't stream properly as a generator, but at least streams tokens to stdout
+ callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
+ streamer = None
+ if prompter:
+ prompt_type = prompter.prompt_type
+ else:
+ prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output)
+ pass # assume inputted prompt_type is correct
+ from gpt4all_llm import get_llm_gpt4all
+ max_max_tokens = tokenizer.model_max_length
+ llm = get_llm_gpt4all(model_name,
+ model=model,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ repetition_penalty=repetition_penalty,
+ top_k=top_k,
+ top_p=top_p,
+ callbacks=callbacks,
+ n_jobs=n_jobs,
+ verbose=verbose,
+ streaming=stream_output,
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ max_seq_len=max_max_tokens,
+ llamacpp_dict=llamacpp_dict,
+ )
+ elif hasattr(model, 'is_exlama') and model.is_exlama():
+ async_output = False # FIXME: not implemented yet
+ assert langchain_only_model
+ callbacks = [StreamingGradioCallbackHandler()]
+ streamer = callbacks[0] if stream_output else None
+ max_max_tokens = tokenizer.model_max_length
+
+ from src.llm_exllama import Exllama
+ llm = Exllama(streaming=stream_output,
+ model_path=None,
+ model=model,
+ lora_path=None,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ typical=.7,
+ beams=1,
+ # beam_length = 40,
+ stop_sequences=prompter.stop_sequences,
+ callbacks=callbacks,
+ verbose=verbose,
+ max_seq_len=max_max_tokens,
+ fused_attn=False,
+ # alpha_value = 1.0, #For use with any models
+ # compress_pos_emb = 4.0, #For use with superhot
+ # set_auto_map = "3, 2" #Gpu split, this will split 3gigs/2gigs
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ )
+ else:
+ async_output = False # FIXME: not implemented yet
+ if model is None:
+ # only used if didn't pass model in
+ assert tokenizer is None
+ prompt_type = 'human_bot'
+ if model_name is None:
+ model_name = 'h2oai/h2ogpt-oasst1-512-12b'
+ # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
+ # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
+ inference_server = ''
+ model, tokenizer, device = get_model(load_8bit=True, base_model=model_name,
+ inference_server=inference_server, gpu_id=0)
+
+ max_max_tokens = tokenizer.model_max_length
+ only_new_text = True
+ gen_kwargs = dict(do_sample=do_sample,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ repetition_penalty=repetition_penalty,
+ num_return_sequences=num_return_sequences,
+ return_full_text=not only_new_text,
+ handle_long_generation=None)
+ if do_sample:
+ gen_kwargs.update(dict(temperature=temperature,
+ top_k=top_k,
+ top_p=top_p))
+ assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
+ else:
+ assert len(set(gen_hyper0).difference(gen_kwargs.keys())) == 0
+
+ if stream_output:
+ skip_prompt = only_new_text
+ from gen import H2OTextIteratorStreamer
+ decoder_kwargs = {}
+ streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
+ gen_kwargs.update(dict(streamer=streamer))
+ else:
+ streamer = None
+
+ from h2oai_pipeline import H2OTextGenerationPipeline
+ pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ prompt_type=prompt_type,
+ prompt_dict=prompt_dict,
+ sanitize_bot_response=sanitize_bot_response,
+ chat=False, stream_output=stream_output,
+ tokenizer=tokenizer,
+ # leave some room for 1 paragraph, even if min_new_tokens=0
+ max_input_tokens=max_max_tokens - max(min_new_tokens, 256),
+ base_model=model_name,
+ **gen_kwargs)
+ # pipe.task = "text-generation"
+ # below makes it listen only to our prompt removal,
+ # not the built-in prompt removal, which is less general and not specific to our model
+ pipe.task = "text2text-generation"
+
+ llm = H2OHuggingFacePipeline(pipeline=pipe)
+ return llm, model_name, streamer, prompt_type, async_output, only_new_text
+
+
+def get_device_dtype():
+ # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
+ import torch
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+ device = 'cpu' if n_gpus == 0 else 'cuda'
+ # from utils import NullContext
+ # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
+ context_class = torch.device
+ torch_dtype = torch.float16 if device == 'cuda' else torch.float32
+ return device, torch_dtype, context_class
+
+
+def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
+ """
+ Get wikipedia data from online
+ :param title:
+ :param first_paragraph_only:
+ :param text_limit:
+ :param take_head:
+ :return:
+ """
+ filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
+ url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
+ if first_paragraph_only:
+ url += "&exintro=1"
+ import json
+ if not os.path.isfile(filename):
+ data = requests.get(url).json()
+ json.dump(data, open(filename, 'wt'))
+ else:
+ data = json.load(open(filename, "rt"))
+ page_content = list(data["query"]["pages"].values())[0]["extract"]
+ if take_head is not None and text_limit is not None:
+ page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
+ title_url = str(title).replace(' ', '_')
+ return Document(
+ page_content=str(page_content),
+ metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
+ )
+
+
+def get_wiki_sources(first_para=True, text_limit=None):
+ """
+ Get specific named sources from wikipedia
+ :param first_para:
+ :param text_limit:
+ :return:
+ """
+ default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
+ wiki_sources = list(os.getenv('WIKI_SOURCES', default_wiki_sources))
+ return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]
+
+
+def get_github_docs(repo_owner, repo_name):
+ """
+ Access github from specific repo
+ :param repo_owner:
+ :param repo_name:
+ :return:
+ """
+ with tempfile.TemporaryDirectory() as d:
+ subprocess.check_call(
+ f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
+ cwd=d,
+ shell=True,
+ )
+ git_sha = (
+ subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
+ .decode("utf-8")
+ .strip()
+ )
+ repo_path = pathlib.Path(d)
+ markdown_files = list(repo_path.glob("*/*.md")) + list(
+ repo_path.glob("*/*.mdx")
+ )
+ for markdown_file in markdown_files:
+ with open(markdown_file, "r") as f:
+ relative_path = markdown_file.relative_to(repo_path)
+ github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
+ yield Document(page_content=str(f.read()), metadata={"source": github_url})
+
+
+def get_dai_pickle(dest="."):
+ from huggingface_hub import hf_hub_download
+ # True for case when locally already logged in with correct token, so don't have to set key
+ token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
+ path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
+ shutil.copy(path_to_zip_file, dest)
+
+
+def get_dai_docs(from_hf=False, get_pickle=True):
+ """
+ Consume DAI documentation, or consume from public pickle
+ :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
+ :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
+ :return:
+ """
+ import pickle
+
+ if get_pickle:
+ get_dai_pickle()
+
+ dai_store = 'dai_docs.pickle'
+ dst = "working_dir_docs"
+ if not os.path.isfile(dai_store):
+ from create_data import setup_dai_docs
+ dst = setup_dai_docs(dst=dst, from_hf=from_hf)
+
+ import glob
+ files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
+
+ basedir = os.path.abspath(os.getcwd())
+ from create_data import rst_to_outputs
+ new_outputs = rst_to_outputs(files)
+ os.chdir(basedir)
+
+ pickle.dump(new_outputs, open(dai_store, 'wb'))
+ else:
+ new_outputs = pickle.load(open(dai_store, 'rb'))
+
+ sources = []
+ for line, file in new_outputs:
+ # gradio requires any linked file to be with app.py
+ sym_src = os.path.abspath(os.path.join(dst, file))
+ sym_dst = os.path.abspath(os.path.join(os.getcwd(), file))
+ if os.path.lexists(sym_dst):
+ os.remove(sym_dst)
+ os.symlink(sym_src, sym_dst)
+ itm = Document(page_content=str(line), metadata={"source": file})
+ # NOTE: yield has issues when going into db, loses metadata
+ # yield itm
+ sources.append(itm)
+ return sources
+
+
+def get_supported_types():
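+ """Return (non_image_types, image_types, video_types) extension lists used to decide which files can be ingested."""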
+ non_image_types0 = ["pdf", "txt", "csv", "toml", "py", "rst", "xml", "rtf",
+ "md",
+ "html", "mhtml", "htm",
+ "enex", "eml", "epub", "odt", "pptx", "ppt",
+ "zip",
+ "gz",
+ "gzip",
+ "urls",
+ ]
+ # "msg", GPL3
+
+ video_types0 = ['WEBM',
+ 'MPG', 'MP2', 'MPEG', 'MPE', '.PV',
+ 'OGG',
+ 'MP4', 'M4P', 'M4V',
+ 'AVI', 'WMV',
+ 'MOV', 'QT',
+ 'FLV', 'SWF',
+ 'AVCHD']
+ video_types0 = [x.lower() for x in video_types0]
+ if have_pillow:
+ from PIL import Image
+ exts = Image.registered_extensions()
+ image_types0 = {ex for ex, f in exts.items() if f in Image.OPEN if ex not in video_types0 + non_image_types0}
+ image_types0 = sorted(image_types0)
+ image_types0 = [x[1:] if x.startswith('.') else x for x in image_types0]
+ else:
+ image_types0 = []
+ return non_image_types0, image_types0, video_types0
+
+
+non_image_types, image_types, video_types = get_supported_types()
+set_image_types = set(image_types)
+
+if have_libreoffice or True:
+ # 'or True' so these types are still attempted, e.g. on Mac/Windows without libreoffice, since loading can work without it
+ non_image_types.extend(["docx", "doc", "xls", "xlsx"])
+if have_jq:
+ non_image_types.extend(["json", "jsonl"])
+
+file_types = non_image_types + image_types
+
+
+def try_as_html(file):
+ # try treating as html as occurs when scraping websites
+ from bs4 import BeautifulSoup
+ with open(file, "rt") as f:
+ try:
+ is_html = bool(BeautifulSoup(f.read(), "html.parser").find())
+ except Exception: # FIXME: narrow to the specific parse exceptions
+ is_html = False
+ if is_html:
+ file_url = 'file://' + file
+ doc1 = UnstructuredURLLoader(urls=[file_url]).load()
+ doc1 = [x for x in doc1 if x.page_content]
+ else:
+ doc1 = []
+ return doc1
+
+
+def json_metadata_func(record: dict, metadata: dict) -> dict:
+ # Define the metadata extraction function.
+
+ if isinstance(record, dict):
+ metadata["sender_name"] = record.get("sender_name")
+ metadata["timestamp_ms"] = record.get("timestamp_ms")
+
+ if "source" in metadata:
+ metadata["source_json"] = metadata['source']
+ if "seq_num" in metadata:
+ metadata["seq_num_json"] = metadata['seq_num']
+
+ return metadata
+
+
+def file_to_doc(file,
+ filei=0,
+ base_path=None, verbose=False, fail_any_exception=False,
+ chunk=True, chunk_size=512, n_jobs=-1,
+ is_url=False, is_txt=False,
+
+ # urls
+ use_unstructured=True,
+ use_playwright=False,
+ use_selenium=False,
+
+ # pdfs
+ use_pymupdf='auto',
+ use_unstructured_pdf='auto',
+ use_pypdf='auto',
+ enable_pdf_ocr='auto',
+ try_pdf_as_html='auto',
+ enable_pdf_doctr='auto',
+
+ # images
+ enable_ocr=False,
+ enable_doctr=False,
+ enable_pix2struct=False,
+ enable_captions=True,
+ captions_model=None,
+ model_loaders=None,
+
+ # json
+ jq_schema='.[]',
+
+ headsize=50, # see also H2OSerpAPIWrapper
+ db_type=None,
+ selected_file_types=None):
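+ """
+ Convert a single file, URL, or pasted text into a list of LangChain Documents,
+ choosing a loader based on the file extension and the enabled parsers, then chunking the result.
+ """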
+ assert isinstance(model_loaders, dict)
+ if selected_file_types is not None:
+ set_image_types1 = set_image_types.intersection(set(selected_file_types))
+ else:
+ set_image_types1 = set_image_types
+
+ assert db_type is not None
+ chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type)
+ add_meta = functools.partial(_add_meta, headsize=headsize, filei=filei)
+ # FIXME: if zip, file index order will not be correct if other files involved
+ path_to_docs_func = functools.partial(path_to_docs,
+ verbose=verbose,
+ fail_any_exception=fail_any_exception,
+ n_jobs=n_jobs,
+ chunk=chunk, chunk_size=chunk_size,
+ # url=file if is_url else None,
+ # text=file if is_txt else None,
+
+ # urls
+ use_unstructured=use_unstructured,
+ use_playwright=use_playwright,
+ use_selenium=use_selenium,
+
+ # pdfs
+ use_pymupdf=use_pymupdf,
+ use_unstructured_pdf=use_unstructured_pdf,
+ use_pypdf=use_pypdf,
+ enable_pdf_ocr=enable_pdf_ocr,
+ enable_pdf_doctr=enable_pdf_doctr,
+ try_pdf_as_html=try_pdf_as_html,
+
+ # images
+ enable_ocr=enable_ocr,
+ enable_doctr=enable_doctr,
+ enable_pix2struct=enable_pix2struct,
+ enable_captions=enable_captions,
+ captions_model=captions_model,
+
+ caption_loader=model_loaders['caption'],
+ doctr_loader=model_loaders['doctr'],
+ pix2struct_loader=model_loaders['pix2struct'],
+
+ # json
+ jq_schema=jq_schema,
+
+ db_type=db_type,
+ )
+
+ if file is None:
+ if fail_any_exception:
+ raise RuntimeError("Unexpected None file")
+ else:
+ return []
+ doc1 = [] # in case no support, or disabled support
+ if base_path is None and not is_txt and not is_url:
+ # then assume want to persist but don't care which path used
+ # can't be in base_path
+ dir_name = os.path.dirname(file)
+ base_name = os.path.basename(file)
+ # if from gradio, will have its own temp uuid too, but that's ok
+ base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
+ base_path = os.path.join(dir_name, base_name)
+ if is_url:
+ file = file.strip() # in case accidental spaces in front or at end
+ file_lower = file.lower()
+ case1 = file_lower.startswith('arxiv:') and len(file_lower.split('arxiv:')) == 2
+ case2 = file_lower.startswith('https://arxiv.org/abs') and len(file_lower.split('https://arxiv.org/abs')) == 2
+ case3 = file_lower.startswith('http://arxiv.org/abs') and len(file_lower.split('http://arxiv.org/abs')) == 2
+ case4 = file_lower.startswith('arxiv.org/abs/') and len(file_lower.split('arxiv.org/abs/')) == 2
+ if case1 or case2 or case3 or case4:
+ if case1:
+ query = file.lower().split('arxiv:')[1].strip()
+ elif case2:
+ query = file.lower().split('https://arxiv.org/abs/')[1].strip()
+ elif case3:
+ query = file.lower().split('http://arxiv.org/abs/')[1].strip()
+ elif case4:
+ query = file.lower().split('arxiv.org/abs/')[1].strip()
+ else:
+ raise RuntimeError("Unexpected arxiv error for %s" % file)
+ if have_arxiv:
+ trials = 3
+ docs1 = []
+ for trial in range(trials):
+ try:
+ docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
+ break
+ except urllib.error.URLError:
+ pass
+ if not docs1:
+ print("Failed to get arxiv %s" % query, flush=True)
+ # ensure string, sometimes None
+ [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
+ query_url = f"https://arxiv.org/abs/{query}"
+ [x.metadata.update(
+ dict(source=x.metadata.get('entry_id', query_url), query=query_url,
+ input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now()))) for x in
+ docs1]
+ else:
+ docs1 = []
+ else:
+ if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
+ file = 'http://' + file
+ docs1 = []
+ do_unstructured = only_unstructured_urls or use_unstructured
+ if only_selenium or only_playwright:
+ do_unstructured = False
+ do_playwright = have_playwright and (use_playwright or only_playwright)
+ if only_unstructured_urls or only_selenium:
+ do_playwright = False
+ do_selenium = have_selenium and (use_selenium or only_selenium)
+ if only_unstructured_urls or only_playwright:
+ do_selenium = False
+ if do_unstructured or use_unstructured:
+ docs1a = UnstructuredURLLoader(urls=[file]).load()
+ docs1a = [x for x in docs1a if x.page_content]
+ add_parser(docs1a, 'UnstructuredURLLoader')
+ docs1.extend(docs1a)
+ if (len(docs1) == 0 and have_playwright) or do_playwright:
+ # then something went wrong, try another loader:
+ from langchain.document_loaders import PlaywrightURLLoader
+ docs1a = asyncio.run(PlaywrightURLLoader(urls=[file]).aload())
+ # docs1 = PlaywrightURLLoader(urls=[file]).load()
+ docs1a = [x for x in docs1a if x.page_content]
+ add_parser(docs1a, 'PlaywrightURLLoader')
+ docs1.extend(docs1a)
+ if (len(docs1) == 0 and have_selenium) or do_selenium:
+ # then something went wrong, try another loader:
+ # but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException:
+ # Message: unknown error: cannot find Chrome binary
+ from langchain.document_loaders import SeleniumURLLoader
+ from selenium.common.exceptions import WebDriverException
+ try:
+ docs1a = SeleniumURLLoader(urls=[file]).load()
+ docs1a = [x for x in docs1a if x.page_content]
+ add_parser(docs1a, 'SeleniumURLLoader')
+ docs1.extend(docs1a)
+ except WebDriverException as e:
+ print("No web driver: %s" % str(e), flush=True)
+ [x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
+ add_meta(docs1, file, parser="is_url")
+ docs1 = clean_doc(docs1)
+ doc1 = chunk_sources(docs1)
+ elif is_txt:
+ base_path = "user_paste"
+ base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
+ source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
+ with open(source_file, "wt") as f:
+ f.write(file)
+ metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
+ doc1 = Document(page_content=str(file), metadata=metadata)
+ add_meta(doc1, file, parser="f.write")
+ # Bit odd to change if was original text
+ # doc1 = clean_doc(doc1)
+ elif file.lower().endswith('.html') or file.lower().endswith('.mhtml') or file.lower().endswith('.htm'):
+ docs1 = UnstructuredHTMLLoader(file_path=file).load()
+ add_meta(docs1, file, parser='UnstructuredHTMLLoader')
+ docs1 = clean_doc(docs1)
+ doc1 = chunk_sources(docs1, language=Language.HTML)
+ elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True):
+ docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
+ add_meta(docs1, file, parser='UnstructuredWordDocumentLoader')
+ doc1 = chunk_sources(docs1)
+ elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True):
+ docs1 = UnstructuredExcelLoader(file_path=file).load()
+ add_meta(docs1, file, parser='UnstructuredExcelLoader')
+ doc1 = chunk_sources(docs1)
+ elif file.lower().endswith('.odt'):
+ docs1 = UnstructuredODTLoader(file_path=file).load()
+ add_meta(docs1, file, parser='UnstructuredODTLoader')
+ doc1 = chunk_sources(docs1)
+ elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
+ docs1 = UnstructuredPowerPointLoader(file_path=file).load()
+ add_meta(docs1, file, parser='UnstructuredPowerPointLoader')
+ docs1 = clean_doc(docs1)
+ doc1 = chunk_sources(docs1)
+ elif file.lower().endswith('.txt'):
+ # use UnstructuredFileLoader ?
+ docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
+ # makes just one, but big one
+ doc1 = chunk_sources(docs1)
+ # Bit odd to change if was original text
+ # doc1 = clean_doc(doc1)
+ add_meta(doc1, file, parser='TextLoader')
+ elif file.lower().endswith('.rtf'):
+ docs1 = UnstructuredRTFLoader(file).load()
+ add_meta(docs1, file, parser='UnstructuredRTFLoader')
+ doc1 = chunk_sources(docs1)
+ elif file.lower().endswith('.md'):
+ docs1 = UnstructuredMarkdownLoader(file).load()
+ add_meta(docs1, file, parser='UnstructuredMarkdownLoader')
+ docs1 = clean_doc(docs1)
+ doc1 = chunk_sources(docs1, language=Language.MARKDOWN)
+ elif file.lower().endswith('.enex'):
+ docs1 = EverNoteLoader(file).load()
+ add_meta(docs1, file, parser='EverNoteLoader')
+ doc1 = chunk_sources(docs1)
+ elif file.lower().endswith('.epub'):
+ docs1 = UnstructuredEPubLoader(file).load()
+ add_meta(docs1, file, parser='UnstructuredEPubLoader')
+ doc1 = chunk_sources(docs1)
+ elif any(file.lower().endswith(x) for x in set_image_types1):
+ docs1 = []
+ if verbose:
+ print("BEGIN: Tesseract", flush=True)
+ if have_tesseract and enable_ocr:
+ # OCR, somewhat works, but not great
+ docs1a = UnstructuredImageLoader(file, strategy='ocr_only').load()
+ # docs1a = UnstructuredImageLoader(file, strategy='hi_res').load()
+ docs1a = [x for x in docs1a if x.page_content]
+ add_meta(docs1a, file, parser='UnstructuredImageLoader')
+ docs1.extend(docs1a)
+ if verbose:
+ print("END: Tesseract", flush=True)
+ if have_doctr and enable_doctr:
+ if verbose:
+ print("BEGIN: DocTR", flush=True)
+ if model_loaders['doctr'] is not None and not isinstance(model_loaders['doctr'], (str, bool)):
+ if verbose:
+ print("Reuse DocTR", flush=True)
+ model_loaders['doctr'].load_model()
+ else:
+ if verbose:
+ print("Fresh DocTR", flush=True)
+ from image_doctr import H2OOCRLoader
+ model_loaders['doctr'] = H2OOCRLoader()
+ model_loaders['doctr'].set_document_paths([file])
+ docs1c = model_loaders['doctr'].load()
+ docs1c = [x for x in docs1c if x.page_content]
+ add_meta(docs1c, file, parser='H2OOCRLoader: %s' % 'DocTR')
+ # DocTR loader didn't set source, so fix up metadata
+ for doci in docs1c:
+ doci.metadata['source'] = doci.metadata.get('document_path', file)
+ doci.metadata['hashid'] = hash_file(doci.metadata['source'])
+ docs1.extend(docs1c)
+ if verbose:
+ print("END: DocTR", flush=True)
+ if enable_captions:
+ # BLIP
+ if verbose:
+ print("BEGIN: BLIP", flush=True)
+ if model_loaders['caption'] is not None and not isinstance(model_loaders['caption'], (str, bool)):
+ # assumes didn't fork into this process with joblib, else can deadlock
+ if verbose:
+ print("Reuse BLIP", flush=True)
+ model_loaders['caption'].load_model()
+ else:
+ if verbose:
+ print("Fresh BLIP", flush=True)
+ from image_captions import H2OImageCaptionLoader
+ model_loaders['caption'] = H2OImageCaptionLoader(caption_gpu=model_loaders['caption'] == 'gpu',
+ blip_model=captions_model,
+ blip_processor=captions_model)
+ model_loaders['caption'].set_image_paths([file])
+ docs1c = model_loaders['caption'].load()
+ docs1c = [x for x in docs1c if x.page_content]
+ add_meta(docs1c, file, parser='H2OImageCaptionLoader: %s' % captions_model)
+ # caption didn't set source, so fix-up meta
+ for doci in docs1c:
+ doci.metadata['source'] = doci.metadata.get('image_path', file)
+ doci.metadata['hashid'] = hash_file(doci.metadata['source'])
+ docs1.extend(docs1c)
+
+ if verbose:
+ print("END: BLIP", flush=True)
+ if enable_pix2struct:
+ # Pix2Struct
+ if verbose:
+ print("BEGIN: Pix2Struct", flush=True)
+ if model_loaders['pix2struct'] is not None and not isinstance(model_loaders['pix2struct'], (str, bool)):
+ if verbose:
+ print("Reuse pix2struct", flush=True)
+ model_loaders['pix2struct'].load_model()
+ else:
+ if verbose:
+ print("Fresh pix2struct", flush=True)
+ from image_pix2struct import H2OPix2StructLoader
+ model_loaders['pix2struct'] = H2OPix2StructLoader()
+ model_loaders['pix2struct'].set_image_paths([file])
+ docs1c = model_loaders['pix2struct'].load()
+ docs1c = [x for x in docs1c if x.page_content]
+ add_meta(docs1c, file, parser='H2OPix2StructLoader: %s' % model_loaders['pix2struct'])
+ # pix2struct loader didn't set source, so fix up metadata
+ for doci in docs1c:
+ doci.metadata['source'] = doci.metadata.get('image_path', file)
+ doci.metadata['hashid'] = hash_file(doci.metadata['source'])
+ docs1.extend(docs1c)
+ if verbose:
+ print("END: Pix2Struct", flush=True)
+ doc1 = chunk_sources(docs1)
+ elif file.lower().endswith('.msg'):
+ raise RuntimeError("Not supported, GPL3 license")
+ # docs1 = OutlookMessageLoader(file).load()
+ # docs1[0].metadata['source'] = file
+ elif file.lower().endswith('.eml'):
+ try:
+ docs1 = UnstructuredEmailLoader(file).load()
+ add_meta(docs1, file, parser='UnstructuredEmailLoader')
+ doc1 = chunk_sources(docs1)
+ except ValueError as e:
+ if 'text/html content not found in email' in str(e):
+ pass
+ else:
+ raise
+ doc1 = [x for x in doc1 if x.page_content]
+ if len(doc1) == 0:
+ # e.g. the text/plain part exists, but text/html did not yield content
+ # doc1 = TextLoader(file, encoding="utf8").load()
+ docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
+ docs1 = [x for x in docs1 if x.page_content]
+ add_meta(docs1, file, parser='UnstructuredEmailLoader text/plain')
+ doc1 = chunk_sources(docs1)
+ # elif file.lower().endswith('.gcsdir'):
+ # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
+ # elif file.lower().endswith('.gcsfile'):
+ # doc1 = GCSFileLoader(project_name, bucket, blob).load()
+ elif file.lower().endswith('.rst'):
+ with open(file, "r") as f:
+ doc1 = Document(page_content=str(f.read()), metadata={"source": file})
+ add_meta(doc1, file, parser='f.read()')
+ doc1 = chunk_sources(doc1, language=Language.RST)
+ elif file.lower().endswith('.json'):
+ # default ~4MB: roughly 10k rows with ~100 small fields of ~4 bytes each
+ JSON_SIZE_LIMIT = int(os.getenv('JSON_SIZE_LIMIT', str(10 * 10 * 1024 * 10 * 4)))
+ if os.path.getsize(file) > JSON_SIZE_LIMIT:
+ raise ValueError(
+ "JSON file sizes > %s not supported for naive parsing and embedding, requires Agents enabled" % JSON_SIZE_LIMIT)
+ loader = JSONLoader(
+ file_path=file,
+ # jq_schema='.messages[].content',
+ jq_schema=jq_schema,
+ text_content=False,
+ metadata_func=json_metadata_func)
+ doc1 = loader.load()
+ add_meta(doc1, file, parser='JSONLoader: %s' % jq_schema)
+ fix_json_meta(doc1)
+ elif file.lower().endswith('.jsonl'):
+ loader = JSONLoader(
+ file_path=file,
+ # jq_schema='.messages[].content',
+ jq_schema=jq_schema,
+ json_lines=True,
+ text_content=False,
+ metadata_func=json_metadata_func)
+ doc1 = loader.load()
+ add_meta(doc1, file, parser='JSONLoader: %s' % jq_schema)
+ fix_json_meta(doc1)
+ elif file.lower().endswith('.pdf'):
+ # migration: map legacy boolean settings onto the 'on'/'off' string settings
+ if isinstance(use_pymupdf, bool):
+ use_pymupdf = 'on' if use_pymupdf else 'off'
+ if isinstance(use_unstructured_pdf, bool):
+ use_unstructured_pdf = 'on' if use_unstructured_pdf else 'off'
+ if isinstance(use_pypdf, bool):
+ use_pypdf = 'on' if use_pypdf else 'off'
+ if isinstance(enable_pdf_ocr, bool):
+ enable_pdf_ocr = 'on' if enable_pdf_ocr else 'off'
+ if isinstance(try_pdf_as_html, bool):
+ try_pdf_as_html = 'on' if try_pdf_as_html else 'off'
+
+ doc1 = []
+ tried_others = False
+ handled = False
+ did_pymupdf = False
+ did_unstructured = False
+ e = None
+ if have_pymupdf and ((len(doc1) == 0 and use_pymupdf == 'auto') or use_pymupdf == 'on'):
+ # GPL, only use if installed
+ from langchain.document_loaders import PyMuPDFLoader
+ # load() still chunks by pages, but every page has title at start to help
+ try:
+ doc1a = PyMuPDFLoader(file).load()
+ did_pymupdf = True
+ except BaseException as e0:
+ doc1a = []
+ print("PyMuPDFLoader: %s" % str(e0), flush=True)
+ e = e0
+ # remove empty documents
+ handled |= len(doc1a) > 0
+ doc1a = [x for x in doc1a if x.page_content]
+ doc1a = clean_doc(doc1a)
+ add_parser(doc1a, 'PyMuPDFLoader')
+ doc1.extend(doc1a)
+ if (len(doc1) == 0 and use_unstructured_pdf == 'auto') or use_unstructured_pdf == 'on':
+ tried_others = True
+ try:
+ doc1a = UnstructuredPDFLoader(file).load()
+ did_unstructured = True
+ except BaseException as e0:
+ doc1a = []
+ print("UnstructuredPDFLoader: %s" % str(e0), flush=True)
+ e = e0
+ handled |= len(doc1a) > 0
+ # remove empty documents
+ doc1a = [x for x in doc1a if x.page_content]
+ add_parser(doc1a, 'UnstructuredPDFLoader')
+ # seems to not need cleaning in most cases
+ doc1.extend(doc1a)
+ if (len(doc1) == 0 and use_pypdf == 'auto') or use_pypdf == 'on':
+ tried_others = True
+ # open-source fallback
+ # load() still chunks by pages, but every page has title at start to help
+ try:
+ doc1a = PyPDFLoader(file).load()
+ except BaseException as e0:
+ doc1a = []
+ print("PyPDFLoader: %s" % str(e0), flush=True)
+ e = e0
+ handled |= len(doc1a) > 0
+ # remove empty documents
+ doc1a = [x for x in doc1a if x.page_content]
+ doc1a = clean_doc(doc1a)
+ add_parser(doc1a, 'PyPDFLoader')
+ doc1.extend(doc1a)
+ if not did_pymupdf and ((have_pymupdf and len(doc1) == 0) and tried_others):
+ # try again in case only others used, but only if didn't already try (2nd part of and)
+ # GPL, only use if installed
+ from langchain.document_loaders import PyMuPDFLoader
+ # load() still chunks by pages, but every page has title at start to help
+ try:
+ doc1a = PyMuPDFLoader(file).load()
+ except BaseException as e0:
+ doc1a = []
+ print("PyMuPDFLoader: %s" % str(e0), flush=True)
+ e = e0
+ handled |= len(doc1a) > 0
+ # remove empty documents
+ doc1a = [x for x in doc1a if x.page_content]
+ doc1a = clean_doc(doc1a)
+ add_parser(doc1a, 'PyMuPDFLoader2')
+ doc1.extend(doc1a)
+ did_pdf_ocr = False
+ if (len(doc1) == 0 and enable_pdf_ocr == 'auto' and enable_pdf_doctr != 'on') or enable_pdf_ocr == 'on':
+ did_pdf_ocr = True
+ # no did_unstructured condition here because here we do OCR, and before we did not
+ # try OCR in end since slowest, but works on pure image pages well
+ doc1a = UnstructuredPDFLoader(file, strategy='ocr_only').load()
+ handled |= len(doc1a) > 0
+ # remove empty documents
+ doc1a = [x for x in doc1a if x.page_content]
+ add_parser(doc1a, 'UnstructuredPDFLoader ocr_only')
+ # seems to not need cleaning in most cases
+ doc1.extend(doc1a)
+ # Some PDFs return nothing or junk from PDFMinerLoader
+ if (len(doc1) == 0 and enable_pdf_doctr == 'auto') or enable_pdf_doctr == 'on':
+ if verbose:
+ print("BEGIN: DocTR", flush=True)
+ if model_loaders['doctr'] is not None and not isinstance(model_loaders['doctr'], (str, bool)):
+ model_loaders['doctr'].load_model()
+ else:
+ from image_doctr import H2OOCRLoader
+ model_loaders['doctr'] = H2OOCRLoader()
+ model_loaders['doctr'].set_document_paths([file])
+ doc1a = model_loaders['doctr'].load()
+ doc1a = [x for x in doc1a if x.page_content]
+ add_meta(doc1a, file, parser='H2OOCRLoader: %s' % 'DocTR')
+ handled |= len(doc1a) > 0
+ # DocTR loader didn't set source, so fix up metadata
+ for doci in doc1a:
+ doci.metadata['source'] = doci.metadata.get('document_path', file)
+ doci.metadata['hashid'] = hash_file(doci.metadata['source'])
+ doc1.extend(doc1a)
+ if verbose:
+ print("END: DocTR", flush=True)
+ if try_pdf_as_html in ['auto', 'on']:
+ doc1a = try_as_html(file)
+ add_parser(doc1a, 'try_as_html')
+ doc1.extend(doc1a)
+
+ if len(doc1) == 0:
+ # if literally nothing, show failed to parse so user knows, since unlikely nothing in PDF at all.
+ if handled:
+ raise ValueError("%s had no valid text, but meta data was parsed" % file)
+ else:
+ raise ValueError("%s had no valid text and no meta data was parsed: %s" % (file, str(e)))
+ add_meta(doc1, file, parser='pdf')
+ doc1 = chunk_sources(doc1)
+ elif file.lower().endswith('.csv'):
+ CSV_SIZE_LIMIT = int(os.getenv('CSV_SIZE_LIMIT', str(10 * 1024 * 10 * 4)))
+ if os.path.getsize(file) > CSV_SIZE_LIMIT:
+ raise ValueError(
+ "CSV file sizes > %s not supported for naive parsing and embedding, requires Agents enabled" % CSV_SIZE_LIMIT)
+ doc1 = CSVLoader(file).load()
+ add_meta(doc1, file, parser='CSVLoader')
+ if isinstance(doc1, list):
+ # each row is a Document, identify
+ [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(doc1)]
+ if db_type in ['chroma', 'chroma_old']:
+ # then separate summarize list
+ sdoc1 = clone_documents(doc1)
+ [x.metadata.update(dict(chunk_id=-1)) for x in sdoc1]
+ doc1 = sdoc1 + doc1
+ elif file.lower().endswith('.py'):
+ doc1 = PythonLoader(file).load()
+ add_meta(doc1, file, parser='PythonLoader')
+ doc1 = chunk_sources(doc1, language=Language.PYTHON)
+ elif file.lower().endswith('.toml'):
+ doc1 = TomlLoader(file).load()
+ add_meta(doc1, file, parser='TomlLoader')
+ doc1 = chunk_sources(doc1)
+ elif file.lower().endswith('.xml'):
+ from langchain.document_loaders import UnstructuredXMLLoader
+ loader = UnstructuredXMLLoader(file_path=file)
+ doc1 = loader.load()
+ add_meta(doc1, file, parser='UnstructuredXMLLoader')
+ elif file.lower().endswith('.urls'):
+ with open(file, "r") as f:
+ urls = f.readlines()
+ # recurse
+ doc1 = path_to_docs_func(None, url=urls)
+ elif file.lower().endswith('.zip'):
+ with zipfile.ZipFile(file, 'r') as zip_ref:
+ # don't put into temporary path, since want to keep references to docs inside zip
+ # so just extract into base_path so the extracted docs remain referenceable
+ zip_ref.extractall(base_path)
+ # recurse
+ doc1 = path_to_docs_func(base_path)
+ elif file.lower().endswith('.gz') or file.lower().endswith('.gzip'):
+ if file.lower().endswith('.gz'):
+ de_file = file.lower().replace('.gz', '')
+ else:
+ de_file = file.lower().replace('.gzip', '')
+ with gzip.open(file, 'rb') as f_in:
+ with open(de_file, 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ # recurse
+ doc1 = file_to_doc(de_file,
+ filei=filei, # single file, same file index as outside caller
+ base_path=base_path, verbose=verbose, fail_any_exception=fail_any_exception,
+ chunk=chunk, chunk_size=chunk_size, n_jobs=n_jobs,
+ is_url=is_url, is_txt=is_txt,
+
+ # urls
+ use_unstructured=use_unstructured,
+ use_playwright=use_playwright,
+ use_selenium=use_selenium,
+
+ # pdfs
+ use_pymupdf=use_pymupdf,
+ use_unstructured_pdf=use_unstructured_pdf,
+ use_pypdf=use_pypdf,
+ enable_pdf_ocr=enable_pdf_ocr,
+ enable_pdf_doctr=enable_pdf_doctr,
+ try_pdf_as_html=try_pdf_as_html,
+
+ # images
+ enable_ocr=enable_ocr,
+ enable_doctr=enable_doctr,
+ enable_pix2struct=enable_pix2struct,
+ enable_captions=enable_captions,
+ captions_model=captions_model,
+ model_loaders=model_loaders,
+
+ # json
+ jq_schema=jq_schema,
+
+ headsize=headsize,
+ db_type=db_type,
+ selected_file_types=selected_file_types)
+ else:
+ raise RuntimeError("No file handler for %s" % os.path.basename(file))
+
+ # allow doc1 to be list or not.
+ if not isinstance(doc1, list):
+ # If not list, did not chunk yet, so chunk now
+ docs = chunk_sources([doc1])
+ elif isinstance(doc1, list) and len(doc1) == 1:
+ # if list of length one, don't trust and chunk it, chunk_id's will still be correct if repeat
+ docs = chunk_sources(doc1)
+ else:
+ docs = doc1
+
+ assert isinstance(docs, list)
+ return docs
+
+
+def path_to_doc1(file,
+ filei=0,
+ verbose=False, fail_any_exception=False, return_file=True,
+ chunk=True, chunk_size=512,
+ n_jobs=-1,
+ is_url=False, is_txt=False,
+
+ # urls
+ use_unstructured=True,
+ use_playwright=False,
+ use_selenium=False,
+
+ # pdfs
+ use_pymupdf='auto',
+ use_unstructured_pdf='auto',
+ use_pypdf='auto',
+ enable_pdf_ocr='auto',
+ enable_pdf_doctr='auto',
+ try_pdf_as_html='auto',
+
+ # images
+ enable_ocr=False,
+ enable_doctr=False,
+ enable_pix2struct=False,
+ enable_captions=True,
+ captions_model=None,
+ model_loaders=None,
+
+ # json
+ jq_schema='.[]',
+
+ db_type=None,
+ selected_file_types=None):
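+ """
+ Wrapper around file_to_doc for a single file/URL/text: exceptions become an exception Document,
+ and with return_file=True the result is pickled to a temp file (path returned) for multiprocessing.
+ """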
+ assert db_type is not None
+ if verbose:
+ if is_url:
+ print("Ingesting URL: %s" % file, flush=True)
+ elif is_txt:
+ print("Ingesting Text: %s" % file, flush=True)
+ else:
+ print("Ingesting file: %s" % file, flush=True)
+ res = None
+ try:
+ # don't pass base_path=path, would infinitely recurse
+ res = file_to_doc(file,
+ filei=filei,
+ base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
+ chunk=chunk, chunk_size=chunk_size,
+ n_jobs=n_jobs,
+ is_url=is_url, is_txt=is_txt,
+
+ # urls
+ use_unstructured=use_unstructured,
+ use_playwright=use_playwright,
+ use_selenium=use_selenium,
+
+ # pdfs
+ use_pymupdf=use_pymupdf,
+ use_unstructured_pdf=use_unstructured_pdf,
+ use_pypdf=use_pypdf,
+ enable_pdf_ocr=enable_pdf_ocr,
+ enable_pdf_doctr=enable_pdf_doctr,
+ try_pdf_as_html=try_pdf_as_html,
+
+ # images
+ enable_ocr=enable_ocr,
+ enable_doctr=enable_doctr,
+ enable_pix2struct=enable_pix2struct,
+ enable_captions=enable_captions,
+ captions_model=captions_model,
+ model_loaders=model_loaders,
+
+ # json
+ jq_schema=jq_schema,
+
+ db_type=db_type,
+ selected_file_types=selected_file_types)
+ except BaseException as e:
+ print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
+ if fail_any_exception:
+ raise
+ else:
+ exception_doc = Document(
+ page_content='',
+ metadata={"source": file, "exception": '%s Exception: %s' % (file, str(e)),
+ "traceback": traceback.format_exc()})
+ res = [exception_doc]
+ if verbose:
+ if is_url:
+ print("DONE Ingesting URL: %s" % file, flush=True)
+ elif is_txt:
+ print("DONE Ingesting Text: %s" % file, flush=True)
+ else:
+ print("DONE Ingesting file: %s" % file, flush=True)
+ if return_file:
+ base_tmp = "temp_path_to_doc1"
+ if not os.path.isdir(base_tmp):
+ base_tmp = makedirs(base_tmp, exist_ok=True, tmp_ok=True, use_base=True)
+ filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
+ with open(filename, 'wb') as f:
+ pickle.dump(res, f)
+ return filename
+ return res
+
+
+def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1,
+ chunk=True, chunk_size=512,
+ url=None, text=None,
+
+ # urls
+ use_unstructured=True,
+ use_playwright=False,
+ use_selenium=False,
+
+ # pdfs
+ use_pymupdf='auto',
+ use_unstructured_pdf='auto',
+ use_pypdf='auto',
+ enable_pdf_ocr='auto',
+ enable_pdf_doctr='auto',
+ try_pdf_as_html='auto',
+
+ # images
+ enable_ocr=False,
+ enable_doctr=False,
+ enable_pix2struct=False,
+ enable_captions=True,
+ captions_model=None,
+
+ caption_loader=None,
+ doctr_loader=None,
+ pix2struct_loader=None,
+
+ # json
+ jq_schema='.[]',
+
+ existing_files=[],
+ existing_hash_ids={},
+ db_type=None,
+ selected_file_types=None,
+ ):
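+ """
+ Resolve path_or_paths/url/text into individual items and ingest them via path_to_doc1,
+ optionally in parallel, handling images separately so CUDA-loaded models are not forked.
+ """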
+ if verbose:
+ print("BEGIN Consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True)
+ if selected_file_types is not None:
+ non_image_types1 = [x for x in non_image_types if x in selected_file_types]
+ image_types1 = [x for x in image_types if x in selected_file_types]
+ else:
+ non_image_types1 = non_image_types.copy()
+ image_types1 = image_types.copy()
+
+ assert db_type is not None
+ # path_or_paths could be str, list, tuple, generator
+ globs_image_types = []
+ globs_non_image_types = []
+ if not path_or_paths and not url and not text:
+ return []
+ elif url:
+ url = get_list_or_str(url)
+ globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url]
+ elif text:
+ globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text]
+ elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths):
+ # single path, only consume allowed files
+ path = path_or_paths
+ # Below globs should match patterns in file_to_doc()
+ [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
+ for ftype in image_types1]
+ globs_image_types = [os.path.normpath(x) for x in globs_image_types]
+ [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
+ for ftype in non_image_types1]
+ globs_non_image_types = [os.path.normpath(x) for x in globs_non_image_types]
+ else:
+ if isinstance(path_or_paths, str):
+ if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths):
+ path_or_paths = [path_or_paths]
+ else:
+ # path was deleted etc.
+ return []
+ # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
+ assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \
+ "Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths))
+ # reform out of allowed types
+ globs_image_types.extend(
+ flatten_list([[os.path.normpath(x) for x in path_or_paths if x.endswith(y)] for y in image_types1]))
+ # could do below:
+ # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types1])
+ # But instead, allow fail so can collect unsupported too
+ set_globs_image_types = set(globs_image_types)
+ globs_non_image_types.extend([os.path.normpath(x) for x in path_or_paths if x not in set_globs_image_types])
+
+ # filter out any files to skip (e.g. if already processed them)
+ # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[]
+ assert not existing_files, "DEV: assume not using this approach"
+ if existing_files:
+ set_skip_files = set(existing_files)
+ globs_image_types = [x for x in globs_image_types if x not in set_skip_files]
+ globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files]
+ if existing_hash_ids:
+ # assume consistent with add_meta() use of hash_file(file)
+ # also assume consistent with get_existing_hash_ids for dict creation
+ # assume hashable values
+ existing_hash_ids_set = set(existing_hash_ids.items())
+ hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items())
+ hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items())
+ # don't use symmetric diff. If file is gone, ignore and don't remove or something
+ # just consider existing files (key) having new hash or not (value)
+ new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys())
+ new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys())
+ globs_image_types = [x for x in globs_image_types if x in new_files_image]
+ globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image]
+
+ # could use generator, but messes up metadata handling in recursive case
+ if (caption_loader and not isinstance(caption_loader, (bool, str)) and caption_loader.device != 'cpu') or \
+ get_device() == 'cuda':
+ # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
+ # get_device() == 'cuda' because presume faster to process image from (temporarily) preloaded model
+ n_jobs_image = 1
+ else:
+ n_jobs_image = n_jobs
+ if enable_doctr or enable_pdf_doctr in [True, 'auto', 'on']:
+ if doctr_loader and not isinstance(doctr_loader, (bool, str)) and doctr_loader.device != 'cpu':
+ # can't fork cuda context
+ n_jobs = 1
+
+ return_file = True # local choice
+ is_url = url is not None
+ is_txt = text is not None
+ model_loaders = dict(caption=caption_loader,
+ doctr=doctr_loader,
+ pix2struct=pix2struct_loader)
+ model_loaders0 = model_loaders.copy()
+ kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
+ return_file=return_file,
+ chunk=chunk, chunk_size=chunk_size,
+ n_jobs=n_jobs,
+ is_url=is_url,
+ is_txt=is_txt,
+
+ # urls
+ use_unstructured=use_unstructured,
+ use_playwright=use_playwright,
+ use_selenium=use_selenium,
+
+ # pdfs
+ use_pymupdf=use_pymupdf,
+ use_unstructured_pdf=use_unstructured_pdf,
+ use_pypdf=use_pypdf,
+ enable_pdf_ocr=enable_pdf_ocr,
+ enable_pdf_doctr=enable_pdf_doctr,
+ try_pdf_as_html=try_pdf_as_html,
+
+ # images
+ enable_ocr=enable_ocr,
+ enable_doctr=enable_doctr,
+ enable_pix2struct=enable_pix2struct,
+ enable_captions=enable_captions,
+ captions_model=captions_model,
+ model_loaders=model_loaders,
+
+ # json
+ jq_schema=jq_schema,
+
+ db_type=db_type,
+ selected_file_types=selected_file_types,
+ )
+ if n_jobs != 1 and len(globs_non_image_types) > 1:
+ # avoid nesting, e.g. upload 1 zip and then inside many files
+ # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
+ documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
+ delayed(path_to_doc1)(file, filei=filei, **kwargs) for filei, file in enumerate(globs_non_image_types)
+ )
+ else:
+ documents = [path_to_doc1(file, filei=filei, **kwargs) for filei, file in
+ enumerate(tqdm(globs_non_image_types))]
+
+ # do images separately since can't fork after cuda in parent, so can't be parallel
+ if n_jobs_image != 1 and len(globs_image_types) > 1:
+ # avoid nesting, e.g. upload 1 zip and then inside many files
+ # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
+ image_documents = ProgressParallel(n_jobs=n_jobs_image, verbose=10 if verbose else 0, backend='multiprocessing')(
+ delayed(path_to_doc1)(file, filei=filei, **kwargs) for filei, file in enumerate(globs_image_types)
+ )
+ else:
+ image_documents = [path_to_doc1(file, filei=filei, **kwargs) for filei, file in
+ enumerate(tqdm(globs_image_types))]
+
+ # unload loaders (image loaders, includes enable_pdf_doctr that uses same loader)
+ for name, loader in model_loaders.items():
+ loader0 = model_loaders0[name]
+ real_model_initial = loader0 is not None and not isinstance(loader0, (str, bool))
+ real_model_final = model_loaders[name] is not None and not isinstance(model_loaders[name], (str, bool))
+ if not real_model_initial and real_model_final:
+ # clear off GPU newly added model
+ model_loaders[name].unload_model()
+
+ # add image docs in
+ documents += image_documents
+
+ if return_file:
+ # then documents really are files
+ files = documents.copy()
+ documents = []
+ for fil in files:
+ with open(fil, 'rb') as f:
+ documents.extend(pickle.load(f))
+ # remove temp pickle
+ remove(fil)
+ else:
+ documents = reduce(concat, documents)
+
+ if verbose:
+ print("END consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True)
+ return documents
+
+
+def prep_langchain(persist_directory,
+ load_db_if_exists,
+ db_type, use_openai_embedding,
+ langchain_mode, langchain_mode_paths, langchain_mode_types,
+ hf_embedding_model,
+ migrate_embedding_model,
+ auto_migrate_db,
+ n_jobs=-1, kwargs_make_db={},
+ verbose=False):
+ """
+ do prep first time, involving downloads
+ # FIXME: Add github caching then add here
+ :return:
+ """
+ if os.getenv("HARD_ASSERTS"):
+ assert langchain_mode not in ['MyData'], "Should not prep scratch/personal data"
+
+ if langchain_mode in langchain_modes_intrinsic:
+ return None
+
+ db_dir_exists = os.path.isdir(persist_directory)
+ user_path = langchain_mode_paths.get(langchain_mode)
+
+ if db_dir_exists and user_path is None:
+ if verbose:
+ print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
+ db, use_openai_embedding, hf_embedding_model = \
+ get_existing_db(None, persist_directory, load_db_if_exists,
+ db_type, use_openai_embedding,
+ langchain_mode, langchain_mode_paths, langchain_mode_types,
+ hf_embedding_model, migrate_embedding_model, auto_migrate_db,
+ n_jobs=n_jobs)
+ else:
+ if db_dir_exists and user_path is not None:
+ if verbose:
+ print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
+ persist_directory, user_path), flush=True)
+ elif not db_dir_exists:
+ if verbose:
+ print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
+ db = None
+ if langchain_mode in ['DriverlessAI docs']:
+ # FIXME: Could also just use dai_docs.pickle directly and upload that
+ get_dai_docs(from_hf=True)
+
+ if langchain_mode in ['wiki']:
+ get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])
+
+ langchain_kwargs = kwargs_make_db.copy()
+ langchain_kwargs.update(locals())
+ db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs)
+
+ return db
+
+
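+ # posthog telemetry (used by chromadb) is disabled here; its background consumer is replaced with a no-op so nothing is sent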
+import posthog
+
+posthog.disabled = True
+
+
+class FakeConsumer(object):
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def run(self):
+ pass
+
+ def pause(self):
+ pass
+
+ def upload(self):
+ pass
+
+ def next(self):
+ pass
+
+ def request(self, batch):
+ pass
+
+
+posthog.Consumer = FakeConsumer
+
+
+def check_update_chroma_embedding(db,
+ db_type,
+ use_openai_embedding,
+ hf_embedding_model, migrate_embedding_model, auto_migrate_db,
+ langchain_mode, langchain_mode_paths, langchain_mode_types,
+ n_jobs=-1):
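+ # if the embedding recorded in the db differs from the requested embedding, re-embed all documents into a fresh chroma collection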
+ changed_db = False
+ embed_tuple = load_embed(db=db)
+ if embed_tuple not in [(True, use_openai_embedding, hf_embedding_model),
+ (False, use_openai_embedding, hf_embedding_model)]:
+ print("Detected new embedding %s vs. %s %s, updating db: %s" % (
+ use_openai_embedding, hf_embedding_model, embed_tuple, langchain_mode), flush=True)
+ # handle embedding changes
+ db_get = get_documents(db)
+ sources = [Document(page_content=result[0], metadata=result[1] or {})
+ for result in zip(db_get['documents'], db_get['metadatas'])]
+ # move the old index aside (to a .bak directory); it must be rebuilt with the new embedding
+ persist_directory = db._persist_directory
+ shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak")
+ assert db_type in ['chroma', 'chroma_old']
+ load_db_if_exists = False
+ db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
+ persist_directory=persist_directory, load_db_if_exists=load_db_if_exists,
+ langchain_mode=langchain_mode,
+ langchain_mode_paths=langchain_mode_paths,
+ langchain_mode_types=langchain_mode_types,
+ collection_name=None,
+ hf_embedding_model=hf_embedding_model,
+ migrate_embedding_model=migrate_embedding_model,
+ auto_migrate_db=auto_migrate_db,
+ n_jobs=n_jobs,
+ )
+ changed_db = True
+ print("Done updating db for new embedding: %s" % langchain_mode, flush=True)
+
+ return db, changed_db
+
+
+def migrate_meta_func(db, langchain_mode):
+ changed_db = False
+ db_get = get_documents(db)
+ # just check one doc
+ if len(db_get['metadatas']) > 0 and 'chunk_id' not in db_get['metadatas'][0]:
+ print("Detected old metadata, adding additional information", flush=True)
+ t0 = time.time()
+ # handle meta changes
+ [x.update(dict(chunk_id=x.get('chunk_id', 0))) for x in db_get['metadatas']]
+ client_collection = db._client.get_collection(name=db._collection.name,
+ embedding_function=db._collection._embedding_function)
+ client_collection.update(ids=db_get['ids'], metadatas=db_get['metadatas'])
+ # check
+ db_get = get_documents(db)
+ assert 'chunk_id' in db_get['metadatas'][0], "Failed to add meta"
+ changed_db = True
+ print("Done updating db for new meta: %s in %s seconds" % (langchain_mode, time.time() - t0), flush=True)
+
+ return db, changed_db
+
+
+def get_existing_db(db, persist_directory,
+ load_db_if_exists, db_type, use_openai_embedding,
+ langchain_mode, langchain_mode_paths, langchain_mode_types,
+ hf_embedding_model,
+ migrate_embedding_model,
+ auto_migrate_db=False,
+ verbose=False, check_embedding=True, migrate_meta=True,
+ n_jobs=-1):
+ if load_db_if_exists and db_type in ['chroma', 'chroma_old'] and os.path.isdir(persist_directory):
+ if os.path.isfile(os.path.join(persist_directory, 'chroma.sqlite3')):
+ must_migrate = False
+ elif os.path.isdir(os.path.join(persist_directory, 'index')):
+ must_migrate = True
+ else:
+ return db, use_openai_embedding, hf_embedding_model
+ chroma_settings = dict(is_persistent=True)
+ use_chromamigdb = False
+ if must_migrate:
+ if auto_migrate_db:
+ print("Detected chromadb<0.4 database, require migration, doing now....", flush=True)
+ from chroma_migrate.import_duckdb import migrate_from_duckdb
+ import chromadb
+ api = chromadb.PersistentClient(path=persist_directory)
+ did_migration = migrate_from_duckdb(api, persist_directory)
+ assert did_migration, "Failed to migrate chroma collection at %s, see https://docs.trychroma.com/migration for CLI tool" % persist_directory
+ elif have_chromamigdb:
+ print(
+ "Detected chroma<0.4 database with --auto_migrate_db=False, but detected the chromamigdb package, so using the old database that still requires duckdb",
+ flush=True)
+ chroma_settings = dict(chroma_db_impl="duckdb+parquet")
+ use_chromamigdb = True
+ else:
+ raise ValueError(
+ "Detected chromadb<0.4 database, require migration, but did not detect chromamigdb package or did not choose auto_migrate_db=False (see FAQ.md)")
+
+ if db is None:
+ if verbose:
+ print("DO Loading db: %s" % langchain_mode, flush=True)
+ got_embedding, use_openai_embedding0, hf_embedding_model0 = load_embed(persist_directory=persist_directory)
+ if got_embedding:
+ use_openai_embedding, hf_embedding_model = use_openai_embedding0, hf_embedding_model0
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
+ import logging
+ logging.getLogger("chromadb").setLevel(logging.ERROR)
+ if use_chromamigdb:
+ from chromamigdb.config import Settings
+ chroma_class = ChromaMig
+ else:
+ from chromadb.config import Settings
+ chroma_class = Chroma
+ client_settings = Settings(anonymized_telemetry=False,
+ **chroma_settings,
+ persist_directory=persist_directory)
+ db = chroma_class(persist_directory=persist_directory, embedding_function=embedding,
+ collection_name=langchain_mode.replace(' ', '_'),
+ client_settings=client_settings)
+ try:
+ db.similarity_search('')
+ except BaseException as e:
+ # migration when no embed_info
+ if 'Dimensionality of (768) does not match index dimensionality (384)' in str(e) or \
+ 'Embedding dimension 768 does not match collection dimensionality 384' in str(e):
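+ # a 384-dim index implies the db was built with all-MiniLM-L6-v2, so fall back to that embedding model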
+ hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
+ db = chroma_class(persist_directory=persist_directory, embedding_function=embedding,
+ collection_name=langchain_mode.replace(' ', '_'),
+ client_settings=client_settings)
+ # should work now, let fail if not
+ db.similarity_search('')
+ save_embed(db, use_openai_embedding, hf_embedding_model)
+ else:
+ raise
+
+ if verbose:
+ print("DONE Loading db: %s" % langchain_mode, flush=True)
+ else:
+ if not migrate_embedding_model:
+ # OVERRIDE embedding choices if could load embedding info when not migrating
+ got_embedding, use_openai_embedding, hf_embedding_model = load_embed(db=db)
+ if verbose:
+ print("USING already-loaded db: %s" % langchain_mode, flush=True)
+ if check_embedding:
+ db_trial, changed_db = check_update_chroma_embedding(db,
+ db_type,
+ use_openai_embedding,
+ hf_embedding_model,
+ migrate_embedding_model,
+ auto_migrate_db,
+ langchain_mode,
+ langchain_mode_paths,
+ langchain_mode_types,
+ n_jobs=n_jobs)
+ if changed_db:
+ db = db_trial
+ # only call persist if really changed db, else takes too long for large db
+ if db is not None:
+ db.persist()
+ clear_embedding(db)
+ save_embed(db, use_openai_embedding, hf_embedding_model)
+ if migrate_meta and db is not None:
+ db_trial, changed_db = migrate_meta_func(db, langchain_mode)
+ if changed_db:
+ db = db_trial
+ return db, use_openai_embedding, hf_embedding_model
+ return db, use_openai_embedding, hf_embedding_model
+
+
+def clear_embedding(db):
+ if db is None:
+ return
+ # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed
+ try:
+ if hasattr(db._embedding_function, 'client') and hasattr(db._embedding_function.client, 'cpu'):
+ # only push back to CPU if each db/user has own embedding model, else if shared share on GPU
+ if hasattr(db._embedding_function.client, 'preload') and not db._embedding_function.client.preload:
+ db._embedding_function.client.cpu()
+ clear_torch_cache()
+ except RuntimeError as e:
+ print("clear_embedding error: %s" % ''.join(traceback.format_tb(e.__traceback__)), flush=True)
+
+
+def make_db(**langchain_kwargs):
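+ # fill any kwargs missing for _make_db from run_qa_db's defaults, then pass through only what _make_db accepts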
+ func_names = list(inspect.signature(_make_db).parameters)
+ missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
+ defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
+ for k in missing_kwargs:
+ if k in defaults_db:
+ langchain_kwargs[k] = defaults_db[k]
+ # final check for missing
+ missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
+ assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs
+ # only keep actual used
+ langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
+ return _make_db(**langchain_kwargs)
+
+
+embed_lock_name = 'embed.lock'
+
+
+def get_embed_lock_file(db, persist_directory=None):
+ if hasattr(db, '_persist_directory') or persist_directory:
+ if persist_directory is None:
+ persist_directory = db._persist_directory
+ check_persist_directory(persist_directory)
+ base_path = os.path.join('locks', persist_directory)
+ base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
+ lock_file = os.path.join(base_path, embed_lock_name)
+ makedirs(os.path.dirname(lock_file))
+ return lock_file
+ return None
+
+
+def save_embed(db, use_openai_embedding, hf_embedding_model):
+ if hasattr(db, '_persist_directory'):
+ persist_directory = db._persist_directory
+ lock_file = get_embed_lock_file(db)
+ with filelock.FileLock(lock_file):
+ embed_info_file = os.path.join(persist_directory, 'embed_info')
+ with open(embed_info_file, 'wb') as f:
+ if isinstance(hf_embedding_model, str):
+ hf_embedding_model_save = hf_embedding_model
+ elif hasattr(hf_embedding_model, 'model_name'):
+ hf_embedding_model_save = hf_embedding_model.model_name
+ elif isinstance(hf_embedding_model, dict) and 'name' in hf_embedding_model:
+ hf_embedding_model_save = hf_embedding_model['name']
+ else:
+ if os.getenv('HARD_ASSERTS'):
+ # unexpected in testing or normally
+ raise RuntimeError("HERE")
+ hf_embedding_model_save = 'hkunlp/instructor-large'
+ pickle.dump((use_openai_embedding, hf_embedding_model_save), f)
+ return use_openai_embedding, hf_embedding_model
+
+
+def load_embed(db=None, persist_directory=None):
+ if hasattr(db, 'embeddings') and hasattr(db.embeddings, 'model_name'):
+ hf_embedding_model = db.embeddings.model_name if 'openai' not in db.embeddings.model_name.lower() else None
+ use_openai_embedding = hf_embedding_model is None
+ save_embed(db, use_openai_embedding, hf_embedding_model)
+ return True, use_openai_embedding, hf_embedding_model
+ if persist_directory is None:
+ persist_directory = db._persist_directory
+ embed_info_file = os.path.join(persist_directory, 'embed_info')
+ if os.path.isfile(embed_info_file):
+ lock_file = get_embed_lock_file(db, persist_directory=persist_directory)
+ with filelock.FileLock(lock_file):
+ with open(embed_info_file, 'rb') as f:
+ try:
+ use_openai_embedding, hf_embedding_model = pickle.load(f)
+ if not isinstance(hf_embedding_model, str):
+ # work-around bug introduced here: https://github.com/h2oai/h2ogpt/commit/54c4414f1ce3b5b7c938def651c0f6af081c66de
+ hf_embedding_model = 'hkunlp/instructor-large'
+ # fix file
+ save_embed(db, use_openai_embedding, hf_embedding_model)
+ got_embedding = True
+ except EOFError:
+ use_openai_embedding, hf_embedding_model = False, 'hkunlp/instructor-large'
+ got_embedding = False
+ if os.getenv('HARD_ASSERTS'):
+ # unexpected in testing or normally
+ raise
+ else:
+ # migration, assume defaults
+ use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2"
+ got_embedding = False
+ assert isinstance(hf_embedding_model, str)
+ return got_embedding, use_openai_embedding, hf_embedding_model
+
+
+def get_persist_directory(langchain_mode, langchain_type=None, db1s=None, dbs=None):
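+ # resolve where the vector db for this collection lives: per-user under USERS_BASE_DIR for personal collections, top-level db_dir_<mode> for shared ones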
+ if langchain_mode in [LangChainMode.DISABLED.value, LangChainMode.LLM.value]:
+ # not None so join works but will fail to find db
+ return '', langchain_type
+
+ userid = get_userid_direct(db1s)
+ username = get_username_direct(db1s)
+
+ # sanity for bad code
+ assert userid != 'None'
+ assert username != 'None'
+
+ dirid = username or userid
+ if langchain_type == LangChainTypes.SHARED.value and not dirid:
+ dirid = './' # just to avoid error
+ if langchain_type == LangChainTypes.PERSONAL.value and not dirid:
+ # e.g. from client when doing transient calls with MyData
+ if db1s is None:
+ # just trick to get filled locally
+ db1s = {LangChainMode.MY_DATA.value: [None, None, None]}
+ set_userid_direct(db1s, str(uuid.uuid4()), str(uuid.uuid4()))
+ userid = get_userid_direct(db1s)
+ username = get_username_direct(db1s)
+ dirid = username or userid
+ langchain_type = LangChainTypes.PERSONAL.value
+
+ # deal with existing locations
+ user_base_dir = os.getenv('USERS_BASE_DIR', 'users')
+ persist_directory = os.path.join(user_base_dir, dirid, 'db_dir_%s' % langchain_mode)
+ if userid and \
+ (os.path.isdir(persist_directory) or
+ db1s is not None and langchain_mode in db1s or
+ langchain_type == LangChainTypes.PERSONAL.value):
+ langchain_type = LangChainTypes.PERSONAL.value
+ persist_directory = makedirs(persist_directory, use_base=True)
+ check_persist_directory(persist_directory)
+ return persist_directory, langchain_type
+
+ persist_directory = 'db_dir_%s' % langchain_mode
+ if (os.path.isdir(persist_directory) or
+ dbs is not None and langchain_mode in dbs or
+ langchain_type == LangChainTypes.SHARED.value):
+ # ensure consistent
+ langchain_type = LangChainTypes.SHARED.value
+ persist_directory = makedirs(persist_directory, use_base=True)
+ check_persist_directory(persist_directory)
+ return persist_directory, langchain_type
+
+ # dummy return for prep_langchain() or full personal space
+ base_others = 'db_nonusers'
+ persist_directory = os.path.join(base_others, 'db_dir_%s' % str(uuid.uuid4()))
+ persist_directory = makedirs(persist_directory, use_base=True)
+ langchain_type = LangChainTypes.PERSONAL.value
+
+ check_persist_directory(persist_directory)
+ return persist_directory, langchain_type
+
+
+def check_persist_directory(persist_directory):
+ # deal with some cases when see intrinsic names being used as shared
+ for langchain_mode in langchain_modes_intrinsic:
+ if persist_directory == 'db_dir_%s' % langchain_mode:
+ raise RuntimeError("Illegal access to %s" % persist_directory)
+
+
+def _make_db(use_openai_embedding=False,
+ hf_embedding_model=None,
+ migrate_embedding_model=False,
+ auto_migrate_db=False,
+ first_para=False, text_limit=None,
+ chunk=True, chunk_size=512,
+
+ # urls
+ use_unstructured=True,
+ use_playwright=False,
+ use_selenium=False,
+
+ # pdfs
+ use_pymupdf='auto',
+ use_unstructured_pdf='auto',
+ use_pypdf='auto',
+ enable_pdf_ocr='auto',
+ enable_pdf_doctr='auto',
+ try_pdf_as_html='auto',
+
+ # images
+ enable_ocr=False,
+ enable_doctr=False,
+ enable_pix2struct=False,
+ enable_captions=True,
+ captions_model=None,
+ caption_loader=None,
+ doctr_loader=None,
+ pix2struct_loader=None,
+
+ # json
+ jq_schema='.[]',
+
+ langchain_mode=None,
+ langchain_mode_paths=None,
+ langchain_mode_types=None,
+ db_type='faiss',
+ load_db_if_exists=True,
+ db=None,
+ n_jobs=-1,
+ verbose=False):
+ assert hf_embedding_model is not None
+ user_path = langchain_mode_paths.get(langchain_mode)
+ langchain_type = langchain_mode_types.get(langchain_mode, LangChainTypes.EITHER.value)
+ persist_directory, langchain_type = get_persist_directory(langchain_mode, langchain_type=langchain_type)
+ langchain_mode_types[langchain_mode] = langchain_type
+ # see if can get persistent chroma db
+ db_trial, use_openai_embedding, hf_embedding_model = \
+ get_existing_db(db, persist_directory, load_db_if_exists, db_type,
+ use_openai_embedding,
+ langchain_mode, langchain_mode_paths, langchain_mode_types,
+ hf_embedding_model, migrate_embedding_model, auto_migrate_db, verbose=verbose,
+ n_jobs=n_jobs)
+ if db_trial is not None:
+ db = db_trial
+
+ sources = []
+ if not db:
+ chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type)
+ if langchain_mode in ['wiki_full']:
+ from read_wiki_full import get_all_documents
+ small_test = None
+ print("Generating new wiki", flush=True)
+ sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
+ print("Got new wiki", flush=True)
+ sources1 = chunk_sources(sources1, chunk=chunk)
+ print("Chunked new wiki", flush=True)
+ sources.extend(sources1)
+ elif langchain_mode in ['wiki']:
+ sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
+ sources1 = chunk_sources(sources1, chunk=chunk)
+ sources.extend(sources1)
+ elif langchain_mode in ['github h2oGPT']:
+ # sources = get_github_docs("dagster-io", "dagster")
+ sources1 = get_github_docs("h2oai", "h2ogpt")
+ # FIXME: always chunk for now
+ sources1 = chunk_sources(sources1)
+ sources.extend(sources1)
+ elif langchain_mode in ['DriverlessAI docs']:
+ sources1 = get_dai_docs(from_hf=True)
+ # FIXME: DAI docs are already chunked well, should only chunk more if over limit
+ sources1 = chunk_sources(sources1, chunk=False)
+ sources.extend(sources1)
+ if user_path:
+ # UserData or custom, which has to be from user's disk
+ if db is not None:
+ # NOTE: Ignore file names for now, only go by hash ids
+ # existing_files = get_existing_files(db)
+ existing_files = []
+ existing_hash_ids = get_existing_hash_ids(db)
+ else:
+ # pretend no existing files so won't filter
+ existing_files = []
+ existing_hash_ids = []
+ # chunk internally for speed over multiple docs
+ # FIXME: If first had old Hash=None and switch embeddings,
+ # then re-embed, and then hit here and reload so have hash, and then re-embed.
+ sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
+ # urls
+ use_unstructured=use_unstructured,
+ use_playwright=use_playwright,
+ use_selenium=use_selenium,
+
+ # pdfs
+ use_pymupdf=use_pymupdf,
+ use_unstructured_pdf=use_unstructured_pdf,
+ use_pypdf=use_pypdf,
+ enable_pdf_ocr=enable_pdf_ocr,
+ enable_pdf_doctr=enable_pdf_doctr,
+ try_pdf_as_html=try_pdf_as_html,
+
+ # images
+ enable_ocr=enable_ocr,
+ enable_doctr=enable_doctr,
+ enable_pix2struct=enable_pix2struct,
+ enable_captions=enable_captions,
+ captions_model=captions_model,
+ caption_loader=caption_loader,
+ doctr_loader=doctr_loader,
+ pix2struct_loader=pix2struct_loader,
+
+ # json
+ jq_schema=jq_schema,
+
+ existing_files=existing_files, existing_hash_ids=existing_hash_ids,
+ db_type=db_type)
+ new_metadata_sources = set([x.metadata['source'] for x in sources1])
+ if new_metadata_sources:
+ if os.getenv('NO_NEW_FILES') is not None:
+ raise RuntimeError("Expected no new files! %s" % new_metadata_sources)
+ print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode),
+ flush=True)
+ if verbose:
+ print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
+ sources.extend(sources1)
+ if len(sources) > 0 and os.getenv('NO_NEW_FILES') is not None:
+ raise RuntimeError("Expected no new files! %s" % langchain_mode)
+ if len(sources) == 0 and os.getenv('SHOULD_NEW_FILES') is not None:
+ raise RuntimeError("Expected new files! %s" % langchain_mode)
+ print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True)
+
+ # see if got sources
+ if not sources:
+ if verbose:
+ if db is not None:
+ print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True)
+ else:
+ print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True)
+ return db, 0, []
+ if verbose:
+ if db is not None:
+ print("Adding to db", flush=True)
+ else:
+ print("Generating db", flush=True)
+ if not db:
+ if sources:
+ db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
+ persist_directory=persist_directory,
+ langchain_mode=langchain_mode,
+ langchain_mode_paths=langchain_mode_paths,
+ langchain_mode_types=langchain_mode_types,
+ hf_embedding_model=hf_embedding_model,
+ migrate_embedding_model=migrate_embedding_model,
+ auto_migrate_db=auto_migrate_db,
+ n_jobs=n_jobs)
+ if verbose:
+ print("Generated db", flush=True)
+ elif langchain_mode not in langchain_modes_intrinsic:
+ print("Did not generate db for %s since no sources" % langchain_mode, flush=True)
+ new_sources_metadata = [x.metadata for x in sources]
+ elif user_path is not None:
+ print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
+ use_openai_embedding=use_openai_embedding,
+ hf_embedding_model=hf_embedding_model)
+ print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
+ else:
+ new_sources_metadata = [x.metadata for x in sources]
+
+ return db, len(new_sources_metadata), new_sources_metadata
+
+
+def get_metadatas(db):
+ metadatas = []
+ from langchain.vectorstores import FAISS
+ if isinstance(db, FAISS):
+ metadatas = [v.metadata for k, v in db.docstore._dict.items()]
+ elif isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db):
+ metadatas = get_documents(db)['metadatas']
+ elif db is not None:
+ # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
+ # seems no way to get all metadata, so need to avoid this approach for weaviate
+ metadatas = [x.metadata for x in db.similarity_search("", k=10000)]
+ return metadatas
+
+
+def get_db_lock_file(db, lock_type='getdb'):
+ if hasattr(db, '_persist_directory'):
+ persist_directory = db._persist_directory
+ check_persist_directory(persist_directory)
+ base_path = os.path.join('locks', persist_directory)
+ base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
+ lock_file = os.path.join(base_path, "%s.lock" % lock_type)
+ makedirs(os.path.dirname(lock_file)) # ensure made
+ return lock_file
+ return None
+
+
+def get_documents(db):
+ if hasattr(db, '_persist_directory'):
+ lock_file = get_db_lock_file(db)
+ with filelock.FileLock(lock_file):
+ # get segfaults and other errors when multiple threads access this
+ return _get_documents(db)
+ else:
+ return _get_documents(db)
+
+
+def _get_documents(db):
+ from langchain.vectorstores import FAISS
+ if isinstance(db, FAISS):
+ documents = [v for k, v in db.docstore._dict.items()]
+ documents = dict(documents=documents)
+ elif isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db):
+ documents = db.get()
+ else:
+ # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
+ # seems no way to get all metadata, so need to avoid this approach for weaviate
+ documents = [x for x in db.similarity_search("", k=10000)]
+ documents = dict(documents=documents)
+ return documents
+
+
+def get_docs_and_meta(db, top_k_docs, filter_kwargs={}, text_context_list=None):
+ if hasattr(db, '_persist_directory'):
+ lock_file = get_db_lock_file(db)
+ with filelock.FileLock(lock_file):
+ return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, text_context_list=text_context_list)
+ else:
+ return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, text_context_list=text_context_list)
+
+
+def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}, text_context_list=None):
+ db_documents = []
+ db_metadatas = []
+
+ if text_context_list:
+ db_documents += [x.page_content if hasattr(x, 'page_content') else x for x in text_context_list]
+ db_metadatas += [x.metadata if hasattr(x, 'metadata') else {} for x in text_context_list]
+
+ from langchain.vectorstores import FAISS
+ if isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db):
+ db_get = db._collection.get(where=filter_kwargs.get('filter'))
+ db_metadatas += db_get['metadatas']
+ db_documents += db_get['documents']
+ elif isinstance(db, FAISS):
+ import itertools
+ db_metadatas += get_metadatas(db)
+ # FIXME: FAISS has no filter
+ if top_k_docs == -1:
+ db_documents += list(db.docstore._dict.values())
+ else:
+ # slice dict first
+ db_documents += list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values())
+ elif db is not None:
+ db_metadatas += get_metadatas(db)
+ db_documents += get_documents(db)['documents']
+
+ return db_documents, db_metadatas
+
+
+def get_existing_files(db):
+ metadatas = get_metadatas(db)
+ metadata_sources = set([x['source'] for x in metadatas])
+ return metadata_sources
+
+
+def get_existing_hash_ids(db):
+ metadatas = get_metadatas(db)
+ # assume consistency: any prior hashed source was a single hashed file at the time, shared across all of its chunks
+ metadata_hash_ids = {os.path.normpath(x['source']): x.get('hashid') for x in metadatas}
+ return metadata_hash_ids
+
+
+def run_qa_db(**kwargs):
+ func_names = list(inspect.signature(_run_qa_db).parameters)
+ # hard-coded defaults
+ kwargs['answer_with_sources'] = kwargs.get('answer_with_sources', True)
+ kwargs['show_rank'] = kwargs.get('show_rank', False)
+ kwargs['show_accordions'] = kwargs.get('show_accordions', True)
+ kwargs['show_link_in_sources'] = kwargs.get('show_link_in_sources', True)
+ kwargs['top_k_docs_max_show'] = kwargs.get('top_k_docs_max_show', 10)
+ kwargs['llamacpp_dict'] = {} # shouldn't be required unless from test using _run_qa_db
+ missing_kwargs = [x for x in func_names if x not in kwargs]
+ assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs
+ # only keep actual used
+ kwargs = {k: v for k, v in kwargs.items() if k in func_names}
+ try:
+ return _run_qa_db(**kwargs)
+ finally:
+ clear_torch_cache()
+
+
+def _run_qa_db(query=None,
+ iinput=None,
+ context=None,
+ use_openai_model=False, use_openai_embedding=False,
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
+
+ # urls
+ use_unstructured=True,
+ use_playwright=False,
+ use_selenium=False,
+
+ # pdfs
+ use_pymupdf='auto',
+ use_unstructured_pdf='auto',
+ use_pypdf='auto',
+ enable_pdf_ocr='auto',
+ enable_pdf_doctr='auto',
+ try_pdf_as_html='auto',
+
+ # images
+ enable_ocr=False,
+ enable_doctr=False,
+ enable_pix2struct=False,
+ enable_captions=True,
+ captions_model=None,
+ caption_loader=None,
+ doctr_loader=None,
+ pix2struct_loader=None,
+
+ # json
+ jq_schema='.[]',
+
+ langchain_mode_paths={},
+ langchain_mode_types={},
+ detect_user_path_changes_every_query=False,
+ db_type=None,
+ model_name=None, model=None, tokenizer=None, inference_server=None,
+ langchain_only_model=False,
+ hf_embedding_model=None,
+ migrate_embedding_model=False,
+ auto_migrate_db=False,
+ stream_output=False,
+ async_output=True,
+ num_async=3,
+ prompter=None,
+ prompt_type=None,
+ prompt_dict=None,
+ answer_with_sources=True,
+ append_sources_to_answer=True,
+ cut_distance=1.64,
+ add_chat_history_to_context=True,
+ add_search_to_context=False,
+ keep_sources_in_context=False,
+ memory_restriction_level=0,
+ system_prompt='',
+ sanitize_bot_response=False,
+ show_rank=False,
+ show_accordions=True,
+ show_link_in_sources=True,
+ top_k_docs_max_show=10,
+ use_llm_if_no_docs=True,
+ load_db_if_exists=False,
+ db=None,
+ do_sample=False,
+ temperature=0.1,
+ top_k=40,
+ top_p=0.7,
+ num_beams=1,
+ max_new_tokens=512,
+ min_new_tokens=1,
+ early_stopping=False,
+ max_time=180,
+ repetition_penalty=1.0,
+ num_return_sequences=1,
+ langchain_mode=None,
+ langchain_action=None,
+ langchain_agents=None,
+ document_subset=DocumentSubset.Relevant.name,
+ document_choice=[DocumentChoice.ALL.value],
+ pre_prompt_query=None,
+ prompt_query=None,
+ pre_prompt_summary=None,
+ prompt_summary=None,
+ text_context_list=None,
+ chat_conversation=None,
+ visible_models=None,
+ h2ogpt_key=None,
+ docs_ordering_type='reverse_ucurve_sort',
+ min_max_new_tokens=256,
+
+ n_jobs=-1,
+ llamacpp_dict=None,
+ verbose=False,
+ cli=False,
+ lora_weights='',
+ auto_reduce_chunks=True,
+ max_chunks=100,
+ total_tokens_for_docs=None,
+ headsize=50,
+ ):
+ """
+
+ :param query:
+ :param use_openai_model:
+ :param use_openai_embedding:
+ :param first_para:
+ :param text_limit:
+ :param top_k_docs:
+ :param chunk:
+ :param chunk_size:
+ :param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from
+ :param db_type: 'faiss' for in-memory
+ 'chroma' (for chroma >= 0.4)
+ 'chroma_old' (for chroma < 0.4)
+ 'weaviate' for persisted on disk
+ :param model_name: model name, used to switch behaviors
+ :param model: pre-initialized model, else will make new one
+ :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
+ :param answer_with_sources:
+ :return:
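+
+ Example (sketch only; assumes a prepared db plus model/tokenizer, with the remaining required kwargs elided):
+ for chunk in run_qa_db(query="What is Driverless AI?", db=db, db_type='chroma', ...):
+ print(chunk['response'])
+ Each yielded chunk is a dict with 'prompt', 'response', 'sources', and 'num_prompt_tokens'.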
+ """
+ t_run = time.time()
+ if stream_output:
+ # threads and asyncio don't mix
+ async_output = False
+ if langchain_action in [LangChainAction.QUERY.value]:
+ # async output only supported for summarization, not query
+ async_output = False
+
+ # in case None, e.g. lazy client, then set based upon actual model
+ pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary = \
+ get_langchain_prompts(pre_prompt_query, prompt_query,
+ pre_prompt_summary, prompt_summary,
+ model_name, inference_server,
+ llamacpp_dict.get('model_path_llama'))
+
+ assert db_type is not None
+ assert hf_embedding_model is not None
+ assert langchain_mode_paths is not None
+ assert langchain_mode_types is not None
+ if model is not None:
+ assert model_name is not None # require so can make decisions
+ assert query is not None
+ assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
+ if prompter is not None:
+ prompt_type = prompter.prompt_type
+ prompt_dict = prompter.prompt_dict
+ if model is not None:
+ assert prompt_type is not None
+ if prompt_type == PromptType.custom.name:
+ assert prompt_dict is not None # should at least be {} or ''
+ else:
+ prompt_dict = ''
+
+ if LangChainAgent.SEARCH.value in langchain_agents and 'llama' in model_name.lower():
+ system_prompt = """You are a zero shot react agent.
+Consider the prompt of Question, which was the original query from the user.
+Respond to prompt of Thought with a thought that may lead to a reasonable new action choice.
+Respond to prompt of Action with an action to take out of the tools given, giving exactly one word for the tool name.
+Respond to prompt of Action Input with an input to give the tool.
+Consider the prompt of Observation, which was the response from the tool.
+Repeat this Thought, Action, Action Input, Observation, Thought sequence several times with new and different thoughts and actions each time, do not repeat.
+Once satisfied that the thoughts, responses are sufficient to answer the question, then respond to prompt of Thought with: I now know the final answer
+Respond to prompt of Final Answer with your final high-quality bullet list answer to the original query.
+"""
+ prompter.system_prompt = system_prompt
+
+ assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
+ # pass in context to LLM directly, since already has prompt_type structure
+ # can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638
+ llm, model_name, streamer, prompt_type_out, async_output, only_new_text = \
+ get_llm(use_openai_model=use_openai_model, model_name=model_name,
+ model=model,
+ tokenizer=tokenizer,
+ inference_server=inference_server,
+ langchain_only_model=langchain_only_model,
+ stream_output=stream_output,
+ async_output=async_output,
+ num_async=num_async,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ repetition_penalty=repetition_penalty,
+ num_return_sequences=num_return_sequences,
+ prompt_type=prompt_type,
+ prompt_dict=prompt_dict,
+ prompter=prompter,
+ context=context,
+ iinput=iinput,
+ sanitize_bot_response=sanitize_bot_response,
+ system_prompt=system_prompt,
+ visible_models=visible_models,
+ h2ogpt_key=h2ogpt_key,
+ min_max_new_tokens=min_max_new_tokens,
+ n_jobs=n_jobs,
+ llamacpp_dict=llamacpp_dict,
+ cli=cli,
+ verbose=verbose,
+ )
+ # in case change, override original prompter
+ if hasattr(llm, 'prompter'):
+ prompter = llm.prompter
+ if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'prompter'):
+ prompter = llm.pipeline.prompter
+
+ if prompter is None:
+ if prompt_type is None:
+ prompt_type = prompt_type_out
+ # get prompter
+ chat = True # FIXME?
+ prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output,
+ system_prompt=system_prompt)
+
+ use_docs_planned = False
+ scores = []
+ chain = None
+
+ # basic version of prompt without docs etc.
+ data_point = dict(context=context, instruction=query, input=iinput)
+ prompt_basic = prompter.generate_prompt(data_point)
+
+ if isinstance(document_choice, str):
+ # support string as well
+ document_choice = [document_choice]
+
+ func_names = list(inspect.signature(get_chain).parameters)
+ sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
+ missing_kwargs = [x for x in func_names if x not in sim_kwargs]
+ assert not missing_kwargs, "Missing: %s" % missing_kwargs
+ docs, chain, scores, \
+ use_docs_planned, num_docs_before_cut, \
+ use_llm_if_no_docs, llm_mode, top_k_docs_max_show = \
+ get_chain(**sim_kwargs)
+ if document_subset in non_query_commands:
+ formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
+ if not formatted_doc_chunks and not use_llm_if_no_docs:
+ yield dict(prompt=prompt_basic, response="No sources", sources='', num_prompt_tokens=0)
+ return
+ # if no sources, outside gpt_langchain, LLM will be used with '' input
+ scores = [1] * len(docs)
+ get_answer_args = tuple([query, docs, formatted_doc_chunks, scores, show_rank,
+ answer_with_sources,
+ append_sources_to_answer])
+ get_answer_kwargs = dict(show_accordions=show_accordions,
+ show_link_in_sources=show_link_in_sources,
+ top_k_docs_max_show=top_k_docs_max_show,
+ docs_ordering_type=docs_ordering_type,
+ num_docs_before_cut=num_docs_before_cut,
+ verbose=verbose)
+ ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
+ yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
+ return
+ if not use_llm_if_no_docs:
+ if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
+ LangChainAction.SUMMARIZE_ALL.value,
+ LangChainAction.SUMMARIZE_REFINE.value]:
+ ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.'
+ extra = ''
+ yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
+ return
+ if not docs and not llm_mode:
+ ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).'
+ extra = ''
+ yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
+ return
+
+ if chain is None and not langchain_only_model:
+ # here if no docs at all and not HF type
+ # can only return if HF type
+ return
+
+ # context stuff similar to used in evaluate()
+ import torch
+ device, torch_dtype, context_class = get_device_dtype()
+ conditional_type = hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'model') and hasattr(llm.pipeline.model,
+ 'conditional_type') and llm.pipeline.model.conditional_type
+ with torch.no_grad():
+ have_lora_weights = lora_weights not in [no_lora_str, '', None]
+ context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
+ if conditional_type:
+ # issues when casting to float16, can mess up t5 model, e.g. only when not streaming, or other odd behaviors
+ context_class_cast = NullContext
+ with context_class_cast(device):
+ if stream_output and streamer:
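+ # run the chain in a background thread (EThread) and consume tokens from the streamer as they arrive, yielding partial responses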
+ answer = None
+ import queue
+ bucket = queue.Queue()
+ thread = EThread(target=chain, streamer=streamer, bucket=bucket)
+ thread.start()
+ outputs = ""
+ try:
+ for new_text in streamer:
+ # print("new_text: %s" % new_text, flush=True)
+ if bucket.qsize() > 0 or thread.exc:
+ thread.join()
+ outputs += new_text
+ if prompter: # and False: # FIXME: pipeline can already use prompter
+ if conditional_type:
+ if prompter.botstr:
+ prompt = prompter.botstr
+ output_with_prompt = prompt + outputs
+ only_new_text = False
+ else:
+ prompt = None
+ output_with_prompt = outputs
+ only_new_text = True
+ else:
+ prompt = None # FIXME
+ output_with_prompt = outputs
+ # don't specify only_new_text here, use get_llm() value
+ output1 = prompter.get_response(output_with_prompt, prompt=prompt,
+ only_new_text=only_new_text,
+ sanitize_bot_response=sanitize_bot_response)
+ yield dict(prompt=prompt, response=output1, sources='', num_prompt_tokens=0)
+ else:
+ yield dict(prompt=prompt, response=outputs, sources='', num_prompt_tokens=0)
+ except BaseException:
+ # if any exception, raise that exception if was from thread, first
+ if thread.exc:
+ raise thread.exc
+ raise
+ finally:
+ # in case no exception and didn't join with thread yet, then join
+ if not thread.exc:
+ answer = thread.join()
+ if isinstance(answer, dict):
+ if 'output_text' in answer:
+ answer = answer['output_text']
+ elif 'output' in answer:
+ answer = answer['output']
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
+ if thread.exc:
+ raise thread.exc
+ else:
+ if async_output:
+ import asyncio
+ answer = asyncio.run(chain())
+ else:
+ answer = chain()
+ if isinstance(answer, dict):
+ if 'output_text' in answer:
+ answer = answer['output_text']
+ elif 'output' in answer:
+ answer = answer['output']
+
+ get_answer_args = tuple([query, docs, answer, scores, show_rank,
+ answer_with_sources,
+ append_sources_to_answer])
+ get_answer_kwargs = dict(show_accordions=show_accordions,
+ show_link_in_sources=show_link_in_sources,
+ top_k_docs_max_show=top_k_docs_max_show,
+ docs_ordering_type=docs_ordering_type,
+ num_docs_before_cut=num_docs_before_cut,
+ verbose=verbose,
+ t_run=t_run,
+ count_input_tokens=llm.count_input_tokens
+ if hasattr(llm, 'count_input_tokens') else None,
+ count_output_tokens=llm.count_output_tokens
+ if hasattr(llm, 'count_output_tokens') else None)
+
+ t_run = time.time() - t_run
+
+ # for final yield, get real prompt used
+ if hasattr(llm, 'prompter') and llm.prompter.prompt is not None:
+ prompt = llm.prompter.prompt
+ else:
+ prompt = prompt_basic
+ num_prompt_tokens = get_token_count(prompt, tokenizer)
+
+ if not use_docs_planned:
+ ret = answer
+ extra = ''
+ yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
+ elif answer is not None:
+ ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
+ yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
+ return
+
+
+def get_docs_with_score(query, k_db, filter_kwargs, db, db_type, text_context_list=None, verbose=False):
+ docs_with_score = []
+ got_db_docs = False
+
+ if text_context_list:
+ docs_with_score += [(x, x.metadata.get('score', 1.0)) for x in text_context_list]
+
+ # deal with bug in chroma where if (say) 234 doc chunks and ask for 233+ then fails due to reduction misbehavior
+ if hasattr(db, '_embedding_function') and isinstance(db._embedding_function, FakeEmbeddings):
+ top_k_docs = -1
+ # don't add text_context_list twice
+ db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
+ text_context_list=None)
+ # sort by order given to parser (file_id) and any chunk_id if chunked
+ doc_file_ids = [x.get('file_id', 0) for x in db_metadatas]
+ doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
+ docs_with_score_fake = [(Document(page_content=result[0], metadata=result[1] or {}), 1.0)
+ for result in zip(db_documents, db_metadatas)]
+ docs_with_score_fake = [x for fx, cx, x in
+ sorted(zip(doc_file_ids, doc_chunk_ids, docs_with_score_fake),
+ key=lambda x: (x[0], x[1]))
+ ]
+ got_db_docs |= len(docs_with_score_fake) > 0
+ docs_with_score += docs_with_score_fake
+ elif db is not None and db_type in ['chroma', 'chroma_old']:
+ while True:
+ try:
+ docs_with_score_chroma = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)
+ break
+ except (RuntimeError, AttributeError) as e:
+ # AttributeError is for people with wrong version of langchain
+ if verbose:
+ print("chroma bug: %s" % str(e), flush=True)
+ if k_db == 1:
+ raise
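+ # otherwise back off k_db in progressively smaller steps and retry the similarity search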
+ if k_db > 500:
+ k_db -= 200
+ elif k_db > 100:
+ k_db -= 50
+ elif k_db > 10:
+ k_db -= 5
+ else:
+ k_db -= 1
+ k_db = max(1, k_db)
+ got_db_docs |= len(docs_with_score_chroma) > 0
+ docs_with_score += docs_with_score_chroma
+ elif db is not None:
+ docs_with_score_other = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)
+ got_db_docs |= len(docs_with_score_other) > 0
+ docs_with_score += docs_with_score_other
+
+ # set in metadata original order of docs
+ [x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)]
+
+ return docs_with_score, got_db_docs
+
+
+def get_chain(query=None,
+ iinput=None,
+ context=None, # FIXME: https://github.com/hwchase17/langchain/issues/6638
+ use_openai_model=False, use_openai_embedding=False,
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
+
+ # urls
+ use_unstructured=True,
+ use_playwright=False,
+ use_selenium=False,
+
+ # pdfs
+ use_pymupdf='auto',
+ use_unstructured_pdf='auto',
+ use_pypdf='auto',
+ enable_pdf_ocr='auto',
+ enable_pdf_doctr='auto',
+ try_pdf_as_html='auto',
+
+ # images
+ enable_ocr=False,
+ enable_doctr=False,
+ enable_pix2struct=False,
+ enable_captions=True,
+ captions_model=None,
+ caption_loader=None,
+ doctr_loader=None,
+ pix2struct_loader=None,
+
+ # json
+ jq_schema='.[]',
+
+ langchain_mode_paths=None,
+ langchain_mode_types=None,
+ detect_user_path_changes_every_query=False,
+ db_type='faiss',
+ model_name=None,
+ inference_server='',
+ max_new_tokens=None,
+ langchain_only_model=False,
+ hf_embedding_model=None,
+ migrate_embedding_model=False,
+ auto_migrate_db=False,
+ prompter=None,
+ prompt_type=None,
+ prompt_dict=None,
+ system_prompt=None,
+ cut_distance=1.1,
+ add_chat_history_to_context=True, # FIXME: https://github.com/hwchase17/langchain/issues/6638
+ add_search_to_context=False,
+ keep_sources_in_context=False,
+ memory_restriction_level=0,
+ top_k_docs_max_show=10,
+
+ load_db_if_exists=False,
+ db=None,
+ langchain_mode=None,
+ langchain_action=None,
+ langchain_agents=None,
+ document_subset=DocumentSubset.Relevant.name,
+ document_choice=[DocumentChoice.ALL.value],
+ pre_prompt_query=None,
+ prompt_query=None,
+ pre_prompt_summary=None,
+ prompt_summary=None,
+ text_context_list=None,
+ chat_conversation=None,
+
+ n_jobs=-1,
+ # beyond run_db_query:
+ llm=None,
+ tokenizer=None,
+ verbose=False,
+ docs_ordering_type='reverse_ucurve_sort',
+ min_max_new_tokens=256,
+ stream_output=True,
+ async_output=True,
+
+ # local
+ auto_reduce_chunks=True,
+ max_chunks=100,
+ total_tokens_for_docs=None,
+ use_llm_if_no_docs=None,
+ headsize=50,
+ ):
+ if inference_server is None:
+ inference_server = ''
+ assert hf_embedding_model is not None
+ assert langchain_agents is not None # should be at least []
+ if text_context_list is None:
+ text_context_list = []
+
+ # default value:
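+ # llm_mode: plain chat with the LLM (no document retrieval) when the collection is Disabled/LLM and no text context was passed in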
+ llm_mode = langchain_mode in ['Disabled', 'LLM'] and len(text_context_list) == 0
+ query_action = langchain_action == LangChainAction.QUERY.value
+ summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
+ LangChainAction.SUMMARIZE_ALL.value,
+ LangChainAction.SUMMARIZE_REFINE.value]
+
+ if len(text_context_list) > 0:
+ # turn into documents to make easy to manage and add meta
+ # try to account for summarization vs. query
+ chunk_id = 0 if query_action else -1
+ text_context_list = [
+ Document(page_content=x, metadata=dict(source='text_context_list', score=1.0, chunk_id=chunk_id)) for x
+ in text_context_list]
+
+ if add_search_to_context:
+ params = {
+ "engine": "duckduckgo",
+ "gl": "us",
+ "hl": "en",
+ }
+ search = H2OSerpAPIWrapper(params=params)
+ # if doing search, allow more docs
+ docs_search, top_k_docs = search.get_search_documents(query,
+ query_action=query_action,
+ chunk=chunk, chunk_size=chunk_size,
+ db_type=db_type,
+ headsize=headsize,
+ top_k_docs=top_k_docs)
+ text_context_list = docs_search + text_context_list
+ add_search_to_context &= len(docs_search) > 0
+ top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search))
+
+ if len(text_context_list) > 0:
+ llm_mode = False
+ use_llm_if_no_docs = True
+
+ from src.output_parser import H2OMRKLOutputParser
+ from langchain.agents import AgentType, load_tools, initialize_agent, create_vectorstore_agent, \
+ create_pandas_dataframe_agent, create_json_agent, create_csv_agent
+ from langchain.agents.agent_toolkits import VectorStoreInfo, VectorStoreToolkit, create_python_agent, JsonToolkit
+ if LangChainAgent.SEARCH.value in langchain_agents:
+ output_parser = H2OMRKLOutputParser()
+ tools = load_tools(["serpapi"], llm=llm, serpapi_api_key=os.environ.get('SERPAPI_API_KEY'))
+ if inference_server.startswith('openai'):
+ agent_type = AgentType.OPENAI_FUNCTIONS
+ agent_executor_kwargs = {"handle_parsing_errors": True, 'output_parser': output_parser}
+ else:
+ agent_type = AgentType.ZERO_SHOT_REACT_DESCRIPTION
+ agent_executor_kwargs = {'output_parser': output_parser}
+ chain = initialize_agent(tools, llm, agent=agent_type,
+ agent_executor_kwargs=agent_executor_kwargs,
+ agent_kwargs=dict(output_parser=output_parser,
+ format_instructions=output_parser.get_format_instructions()),
+ output_parser=output_parser,
+ max_iterations=10,
+ verbose=True)
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if LangChainAgent.COLLECTION.value in langchain_agents:
+ output_parser = H2OMRKLOutputParser()
+ vectorstore_info = VectorStoreInfo(
+ name=langchain_mode,
+ description="DataBase of text from PDFs, Image Captions, or web URL content",
+ vectorstore=db,
+ )
+ toolkit = VectorStoreToolkit(vectorstore_info=vectorstore_info)
+ chain = create_vectorstore_agent(llm=llm, toolkit=toolkit,
+ agent_executor_kwargs=dict(output_parser=output_parser),
+ verbose=True)
+
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
+ chain = create_python_agent(
+ llm=llm,
+ tool=PythonREPLTool(),
+ verbose=True,
+ agent_type=AgentType.OPENAI_FUNCTIONS,
+ agent_executor_kwargs={"handle_parsing_errors": True},
+ )
+
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
+ # FIXME: DATA
+ df = pd.DataFrame(None)
+ chain = create_pandas_dataframe_agent(
+ llm,
+ df,
+ verbose=True,
+ agent_type=AgentType.OPENAI_FUNCTIONS,
+ )
+
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if isinstance(document_choice, str):
+ document_choice = [document_choice]
+ if document_choice and document_choice[0] == DocumentChoice.ALL.value:
+ document_choice_agent = document_choice[1:]
+ else:
+ document_choice_agent = document_choice
+ document_choice_agent = [x for x in document_choice_agent if x.endswith('.json')]
+ if LangChainAgent.JSON.value in \
+ langchain_agents and \
+ inference_server.startswith('openai_chat') and \
+ len(document_choice_agent) == 1 and \
+ document_choice_agent[0].endswith('.json'):
+ # with open('src/openai.yaml') as f:
+ # data = yaml.load(f, Loader=yaml.FullLoader)
+ with open(document_choice[0], 'rt') as f:
+ data = json.loads(f.read())
+ json_spec = JsonSpec(dict_=data, max_value_length=4000)
+ json_toolkit = JsonToolkit(spec=json_spec)
+
+ chain = create_json_agent(
+ llm=llm, toolkit=json_toolkit, verbose=True
+ )
+
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if isinstance(document_choice, str):
+ document_choice = [document_choice]
+ if document_choice and document_choice[0] == DocumentChoice.ALL.value:
+ document_choice_agent = document_choice[1:]
+ else:
+ document_choice_agent = document_choice
+ document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
+ if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[
+ 0].endswith(
+ '.csv'):
+ data_file = document_choice[0]
+ if inference_server.startswith('openai_chat'):
+ chain = create_csv_agent(
+ llm,
+ data_file,
+ verbose=True,
+ agent_type=AgentType.OPENAI_FUNCTIONS,
+ )
+ else:
+ chain = create_csv_agent(
+ llm,
+ data_file,
+ verbose=True,
+ agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+ )
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ # determine whether use of context out of docs is planned
+ if not use_openai_model and prompt_type not in ['plain'] or langchain_only_model:
+ if llm_mode:
+ use_docs_planned = False
+ else:
+ use_docs_planned = True
+ else:
+ use_docs_planned = True
+
+ # https://github.com/hwchase17/langchain/issues/1946
+ # FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
+ # Chroma collection MyData contains fewer than 4 elements.
+ # type logger error
+ if top_k_docs == -1:
+ k_db = 1000 if db_type in ['chroma', 'chroma_old'] else 100
+ else:
+ # top_k_docs=100 works ok too
+ k_db = 1000 if db_type in ['chroma', 'chroma_old'] else top_k_docs
+
+ # FIXME: For All just go over all dbs instead of a separate db for All
+ if not detect_user_path_changes_every_query and db is not None:
+ # avoid looking at user_path during similarity search db handling,
+ # if already have db and not updating from user_path every query
+ # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
+ if langchain_mode_paths is None:
+ langchain_mode_paths = {}
+ langchain_mode_paths = langchain_mode_paths.copy()
+ langchain_mode_paths[langchain_mode] = None
+ # once use_openai_embedding, hf_embedding_model passed in, possibly changed,
+ # but that's ok as not used below or in calling functions
+ db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
+ hf_embedding_model=hf_embedding_model,
+ migrate_embedding_model=migrate_embedding_model,
+ auto_migrate_db=auto_migrate_db,
+ first_para=first_para, text_limit=text_limit,
+ chunk=chunk, chunk_size=chunk_size,
+
+ # urls
+ use_unstructured=use_unstructured,
+ use_playwright=use_playwright,
+ use_selenium=use_selenium,
+
+ # pdfs
+ use_pymupdf=use_pymupdf,
+ use_unstructured_pdf=use_unstructured_pdf,
+ use_pypdf=use_pypdf,
+ enable_pdf_ocr=enable_pdf_ocr,
+ enable_pdf_doctr=enable_pdf_doctr,
+ try_pdf_as_html=try_pdf_as_html,
+
+ # images
+ enable_ocr=enable_ocr,
+ enable_doctr=enable_doctr,
+ enable_pix2struct=enable_pix2struct,
+ enable_captions=enable_captions,
+ captions_model=captions_model,
+ caption_loader=caption_loader,
+ doctr_loader=doctr_loader,
+ pix2struct_loader=pix2struct_loader,
+
+ # json
+ jq_schema=jq_schema,
+
+ langchain_mode=langchain_mode,
+ langchain_mode_paths=langchain_mode_paths,
+ langchain_mode_types=langchain_mode_types,
+ db_type=db_type,
+ load_db_if_exists=load_db_if_exists,
+ db=db,
+ n_jobs=n_jobs,
+ verbose=verbose)
+ num_docs_before_cut = 0
+ use_template = not use_openai_model and prompt_type not in ['plain'] or langchain_only_model
+ got_db_docs = False # not yet at least
+ template, template_if_no_docs, auto_reduce_chunks, query = \
+ get_template(query, iinput,
+ pre_prompt_query, prompt_query,
+ pre_prompt_summary, prompt_summary,
+ langchain_action,
+ llm_mode,
+ use_docs_planned,
+ auto_reduce_chunks,
+ got_db_docs,
+ add_search_to_context)
+
+ max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
+ model_name=model_name, max_new_tokens=max_new_tokens)
+
+ if (db or text_context_list) and use_docs_planned:
+ if hasattr(db, '_persist_directory'):
+ lock_file = get_db_lock_file(db, lock_type='sim')
+ else:
+ base_path = 'locks'
+ base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
+ name_path = "sim.lock"
+ lock_file = os.path.join(base_path, name_path)
+
+ if not (isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db)):
+ # only chroma supports filtering
+ filter_kwargs = {}
+ filter_kwargs_backup = {}
+ else:
+ import logging
+ logging.getLogger("chromadb").setLevel(logging.ERROR)
+ assert document_choice is not None, "Document choice was None"
+ if isinstance(db, Chroma):
+ filter_kwargs_backup = {} # shouldn't ever need backup
+ # chroma >= 0.4
+ if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
+ 0] == DocumentChoice.ALL.value:
+ filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
+ {"filter": {"chunk_id": {"$eq": -1}}}
+ else:
+ if document_choice[0] == DocumentChoice.ALL.value:
+ document_choice = document_choice[1:]
+ if len(document_choice) == 0:
+ filter_kwargs = {}
+ elif len(document_choice) > 1:
+ or_filter = [
+ {"$and": [dict(source={"$eq": x}), dict(chunk_id={"$gte": 0})]} if query_action else {
+ "$and": [dict(source={"$eq": x}), dict(chunk_id={"$eq": -1})]}
+ for x in document_choice]
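+                            # illustrative example (hypothetical filenames): document_choice=['a.pdf', 'b.pdf'] with a query action
+                            # yields {"filter": {"$or": [{"$and": [{"source": {"$eq": "a.pdf"}}, {"chunk_id": {"$gte": 0}}]},
+                            #                            {"$and": [{"source": {"$eq": "b.pdf"}}, {"chunk_id": {"$gte": 0}}]}]}}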
+ filter_kwargs = dict(filter={"$or": or_filter})
+ else:
+ # still chromadb UX bug, have to do different thing for 1 vs. 2+ docs when doing filter
+ one_filter = \
+ [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {
+ "source": {"$eq": x},
+ "chunk_id": {
+ "$eq": -1}}
+ for x in document_choice][0]
+
+ filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']),
+ dict(chunk_id=one_filter['chunk_id'])]})
+ else:
+ # migration for chroma < 0.4
+ if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
+ 0] == DocumentChoice.ALL.value:
+ filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
+ {"filter": {"chunk_id": {"$eq": -1}}}
+ filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}}
+ elif len(document_choice) >= 2:
+ if document_choice[0] == DocumentChoice.ALL.value:
+ document_choice = document_choice[1:]
+ or_filter = [
+ {"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
+ "chunk_id": {
+ "$eq": -1}}
+ for x in document_choice]
+ filter_kwargs = dict(filter={"$or": or_filter})
+ or_filter_backup = [
+ {"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
+ for x in document_choice]
+ filter_kwargs_backup = dict(filter={"$or": or_filter_backup})
+ elif len(document_choice) == 1:
+ # degenerate UX bug in chroma
+ one_filter = \
+ [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
+ "chunk_id": {
+ "$eq": -1}}
+ for x in document_choice][0]
+ filter_kwargs = dict(filter=one_filter)
+ one_filter_backup = \
+ [{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
+ for x in document_choice][0]
+ filter_kwargs_backup = dict(filter=one_filter_backup)
+ else:
+ # shouldn't reach
+ filter_kwargs = {}
+ filter_kwargs_backup = {}
+
+ if llm_mode:
+ docs = []
+ scores = []
+ elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
+ db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
+ text_context_list=text_context_list)
+ if len(db_documents) == 0 and filter_kwargs_backup:
+ db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup,
+ text_context_list=text_context_list)
+
+ if top_k_docs == -1:
+ top_k_docs = len(db_documents)
+ # similar to langchain's chroma's _results_to_docs_and_scores
+ docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
+ for result in zip(db_documents, db_metadatas)]
+ # set in metadata original order of docs
+ [x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)]
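+        # orig_index records each doc's original retrieval position; get_sources_answer() includes it in the returned source metadata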
+
+ # order documents
+ doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
+ if query_action:
+ doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
+ docs_with_score2 = [x for hx, cx, x in
+ sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
+ if cx >= 0]
+ else:
+ assert summarize_action
+ doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas]
+ docs_with_score2 = [x for hx, cx, x in
+ sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
+ if cx == -1
+ ]
+ if len(docs_with_score2) == 0 and len(docs_with_score) > 0:
+ # old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive
+ # just do again and relax filter, let summarize operate on actual chunks if nothing else
+ docs_with_score2 = [x for hx, cx, x in
+ sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
+ key=lambda x: (x[0], x[1]))
+ ]
+ docs_with_score = docs_with_score2
+
+ docs_with_score = docs_with_score[:top_k_docs]
+ docs = [x[0] for x in docs_with_score]
+ scores = [x[1] for x in docs_with_score]
+ num_docs_before_cut = len(docs)
+ else:
+ with filelock.FileLock(lock_file):
+ docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type,
+ text_context_list=text_context_list,
+ verbose=verbose)
+ if len(docs_with_score) == 0 and filter_kwargs_backup:
+ docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db,
+ db_type,
+ text_context_list=text_context_list,
+ verbose=verbose)
+
+ tokenizer = get_tokenizer(db=db, llm=llm, tokenizer=tokenizer, inference_server=inference_server,
+ use_openai_model=use_openai_model,
+ db_type=db_type)
+ # NOTE: if map_reduce, then no need to auto reduce chunks
+ if query_action and (top_k_docs == -1 or auto_reduce_chunks):
+ top_k_docs_tokenize = 100
+ docs_with_score = docs_with_score[:top_k_docs_tokenize]
+
+ prompt_no_docs = template.format(context='', question=query)
+
+ model_max_length = tokenizer.model_max_length
+ chat = True # FIXME?
+
+ # first docs_with_score are most important with highest score
+ full_prompt, \
+ instruction, iinput, context, \
+ num_prompt_tokens, max_new_tokens, \
+ num_prompt_tokens0, num_prompt_tokens_actual, \
+ chat_index, top_k_docs_trial, one_doc_size = \
+ get_limited_prompt(prompt_no_docs,
+ iinput,
+ tokenizer,
+ prompter=prompter,
+ inference_server=inference_server,
+ prompt_type=prompt_type,
+ prompt_dict=prompt_dict,
+ chat=chat,
+ max_new_tokens=max_new_tokens,
+ system_prompt=system_prompt,
+ context=context,
+ chat_conversation=chat_conversation,
+ text_context_list=[x[0].page_content for x in docs_with_score],
+ keep_sources_in_context=keep_sources_in_context,
+ model_max_length=model_max_length,
+ memory_restriction_level=memory_restriction_level,
+ langchain_mode=langchain_mode,
+ add_chat_history_to_context=add_chat_history_to_context,
+ min_max_new_tokens=min_max_new_tokens,
+ )
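+            # top_k_docs_trial is how many docs get_limited_prompt() found to fit within the token budget;
+            # e.g. if top_k_docs=10 but only 3 fit, keep 3; if not even one whole doc fits, one_doc_size below is used to truncate the first doc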
+            # avoid craziness
+            if 0 < top_k_docs_trial < max_chunks:
+ if top_k_docs == -1:
+ top_k_docs = top_k_docs_trial
+ else:
+ top_k_docs = min(top_k_docs, top_k_docs_trial)
+ elif top_k_docs_trial >= max_chunks:
+ top_k_docs = max_chunks
+ if top_k_docs > 0:
+ docs_with_score = docs_with_score[:top_k_docs]
+ elif one_doc_size is not None:
+ docs_with_score = [docs_with_score[0][:one_doc_size]]
+ else:
+ docs_with_score = []
+ else:
+ if total_tokens_for_docs is not None:
+ # used to limit tokens for summarization, e.g. public instance
+ top_k_docs, one_doc_size, num_doc_tokens = \
+ get_docs_tokens(tokenizer,
+ text_context_list=[x[0].page_content for x in docs_with_score],
+ max_input_tokens=total_tokens_for_docs)
+
+ docs_with_score = docs_with_score[:top_k_docs]
+
+ # put most relevant chunks closest to question,
+ # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
+ # BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
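+    # 'best_first' keeps retrieval order, 'best_near_prompt'/'reverse_sort' reverse it so the top hit lands nearest the question,
+    # and the default 'reverse_ucurve_sort' applies the u-curve reordering from reverse_ucurve_list()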
+ if docs_ordering_type in ['best_first']:
+ pass
+ elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
+ docs_with_score.reverse()
+ elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
+ docs_with_score = reverse_ucurve_list(docs_with_score)
+ else:
+ raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
+
+ # cut off so no high distance docs/sources considered
+ num_docs_before_cut = len(docs_with_score)
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
+ if len(scores) > 0 and verbose:
+ print("Distance: min: %s max: %s mean: %s median: %s" %
+ (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
+ else:
+ docs = []
+ scores = []
+
+ if not docs and use_docs_planned and not langchain_only_model:
+ # if HF type and have no docs, can bail out
+ return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if document_subset in non_query_commands:
+ # no LLM use
+ return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ # FIXME: WIP
+ common_words_file = "data/NGSL_1.2_stats.csv.zip"
+ if False and os.path.isfile(common_words_file) and langchain_action == LangChainAction.QUERY.value:
+ df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
+ import string
+ reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
+ reduced_query_words = reduced_query.split(' ')
+ set_common = set(df['Lemma'].values.tolist())
+        num_common = sum([x.lower() in set_common for x in reduced_query_words])
+        frac_common = num_common / len(reduced_query_words) if reduced_query_words else 0
+ # FIXME: report to user bad query that uses too many common words
+ if verbose:
+ print("frac_common: %s" % frac_common, flush=True)
+
+ if len(docs) == 0:
+ # avoid context == in prompt then
+ use_docs_planned = False
+ template = template_if_no_docs
+
+ got_db_docs = got_db_docs and len(text_context_list) < len(docs)
+ # update template in case situation changed or did get docs
+ # then no new documents from database or not used, redo template
+ # got template earlier as estimate of template token size, here is final used version
+ template, template_if_no_docs, auto_reduce_chunks, query = \
+ get_template(query, iinput,
+ pre_prompt_query, prompt_query,
+ pre_prompt_summary, prompt_summary,
+ langchain_action,
+ llm_mode,
+ use_docs_planned,
+ auto_reduce_chunks,
+ got_db_docs,
+ add_search_to_context)
+
+ if langchain_action == LangChainAction.QUERY.value:
+ if use_template:
+ # instruct-like, rather than few-shot prompt_type='plain' as default
+ # but then sources confuse the model with how inserted among rest of text, so avoid
+ prompt = PromptTemplate(
+ # input_variables=["summaries", "question"],
+ input_variables=["context", "question"],
+ template=template,
+ )
+ chain = load_qa_chain(llm, prompt=prompt, verbose=verbose)
+ else:
+ # only if use_openai_model = True, unused normally except in testing
+ chain = load_qa_with_sources_chain(llm)
+ if not use_docs_planned:
+ chain_kwargs = dict(input_documents=[], question=query)
+ else:
+ chain_kwargs = dict(input_documents=docs, question=query)
+ target = wrapped_partial(chain, chain_kwargs)
+ elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
+                              LangChainAction.SUMMARIZE_REFINE.value,
+ LangChainAction.SUMMARIZE_ALL.value]:
+ if async_output:
+ return_intermediate_steps = False
+ else:
+ return_intermediate_steps = True
+ from langchain.chains.summarize import load_summarize_chain
+ if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
+ prompt = PromptTemplate(input_variables=["text"], template=template)
+ chain = load_summarize_chain(llm, chain_type="map_reduce",
+ map_prompt=prompt, combine_prompt=prompt,
+ return_intermediate_steps=return_intermediate_steps,
+ token_max=max_input_tokens, verbose=verbose)
+ if async_output:
+ chain_func = chain.arun
+ else:
+ chain_func = chain
+ target = wrapped_partial(chain_func, {"input_documents": docs}) # , return_only_outputs=True)
+ elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
+ assert use_template
+ prompt = PromptTemplate(input_variables=["text"], template=template)
+ chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt,
+ return_intermediate_steps=return_intermediate_steps, verbose=verbose)
+ if async_output:
+ chain_func = chain.arun
+ else:
+ chain_func = chain
+ target = wrapped_partial(chain_func)
+ elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
+ chain = load_summarize_chain(llm, chain_type="refine",
+ return_intermediate_steps=return_intermediate_steps, verbose=verbose)
+ if async_output:
+ chain_func = chain.arun
+ else:
+ chain_func = chain
+ target = wrapped_partial(chain_func)
+ else:
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
+ else:
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
+
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+
+def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None):
+ if hasattr(tokenizer, 'model_max_length'):
+ return tokenizer.model_max_length
+ elif inference_server in ['openai', 'openai_azure']:
+ return llm.modelname_to_contextsize(model_name)
+ elif inference_server in ['openai_chat', 'openai_azure_chat']:
+ return model_token_mapping[model_name]
+ elif isinstance(tokenizer, FakeTokenizer):
+ # GGML
+ return tokenizer.model_max_length
+ else:
+ return 2048
+
+
+def get_max_input_tokens(llm=None, tokenizer=None, inference_server=None, model_name=None, max_new_tokens=None):
+ model_max_length = get_max_model_length(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
+ model_name=model_name)
+
+ if any([inference_server.startswith(x) for x in
+ ['openai', 'openai_azure', 'openai_chat', 'openai_azure_chat', 'vllm']]):
+ # openai can't handle tokens + max_new_tokens > max_tokens even if never generate those tokens
+ # and vllm uses OpenAI API with same limits
+ max_input_tokens = model_max_length - max_new_tokens
+ elif isinstance(tokenizer, FakeTokenizer):
+ # don't trust that fake tokenizer (e.g. GGML) will make lots of tokens normally, allow more input
+ max_input_tokens = model_max_length - min(256, max_new_tokens)
+ else:
+ if 'falcon' in model_name or inference_server.startswith('http'):
+ # allow for more input for falcon, assume won't make as long outputs as default max_new_tokens
+ # Also allow if TGI or Gradio, because we tell it input may be same as output, even if model can't actually handle
+ max_input_tokens = model_max_length - min(256, max_new_tokens)
+ else:
+ # trust that maybe model will make so many tokens, so limit input
+ max_input_tokens = model_max_length - max_new_tokens
+
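+    # e.g. with model_max_length=4096 and max_new_tokens=512: OpenAI/vLLM-style servers get 4096 - 512 = 3584,
+    # while the fake-tokenizer (GGML) and falcon/http cases get 4096 - min(256, 512) = 3840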
+ return max_input_tokens
+
+
+def get_tokenizer(db=None, llm=None, tokenizer=None, inference_server=None, use_openai_model=False,
+ db_type='chroma'):
+ if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
+ # more accurate
+ return llm.pipeline.tokenizer
+ elif hasattr(llm, 'tokenizer'):
+ # e.g. TGI client mode etc.
+ return llm.tokenizer
+ elif inference_server in ['openai', 'openai_chat', 'openai_azure',
+ 'openai_azure_chat']:
+ return tokenizer
+ elif isinstance(tokenizer, FakeTokenizer):
+ return tokenizer
+ elif use_openai_model:
+ return FakeTokenizer()
+ elif (hasattr(db, '_embedding_function') and
+ hasattr(db._embedding_function, 'client') and
+ hasattr(db._embedding_function.client, 'tokenize')):
+ # in case model is not our pipeline with HF tokenizer
+ return db._embedding_function.client.tokenize
+ else:
+ # backup method
+ if os.getenv('HARD_ASSERTS'):
+ assert db_type in ['faiss', 'weaviate']
+ # use tiktoken for faiss since embedding called differently
+ return FakeTokenizer()
+
+
+def get_template(query, iinput,
+ pre_prompt_query, prompt_query,
+ pre_prompt_summary, prompt_summary,
+ langchain_action,
+ llm_mode,
+ use_docs_planned,
+ auto_reduce_chunks,
+ got_db_docs,
+ add_search_to_context):
+ if got_db_docs and add_search_to_context:
+ # modify prompts, assumes patterns like in predefined prompts. If user customizes, then they'd need to account for that.
+ prompt_query = prompt_query.replace('information in the document sources',
+ 'information in the document and web search sources (and their source dates and website source)')
+ prompt_summary = prompt_summary.replace('information in the document sources',
+ 'information in the document and web search sources (and their source dates and website source)')
+ elif got_db_docs and not add_search_to_context:
+ pass
+ elif not got_db_docs and add_search_to_context:
+ # modify prompts, assumes patterns like in predefined prompts. If user customizes, then they'd need to account for that.
+ prompt_query = prompt_query.replace('information in the document sources',
+ 'information in the web search sources (and their source dates and website source)')
+ prompt_summary = prompt_summary.replace('information in the document sources',
+ 'information in the web search sources (and their source dates and website source)')
+
+ if langchain_action == LangChainAction.QUERY.value:
+ if iinput:
+ query = "%s\n%s" % (query, iinput)
+ if llm_mode or not use_docs_planned:
+ template_if_no_docs = template = """{context}{question}"""
+ else:
+ template = """%s
+\"\"\"
+{context}
+\"\"\"
+%s{question}""" % (pre_prompt_query, prompt_query)
+ template_if_no_docs = """{context}{question}"""
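+        # with docs: pre_prompt_query, then the retrieved chunks wrapped in triple quotes, then prompt_query and the question;
+        # in LLM mode or with no docs planned, the template collapses to just the raw context and question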
+ elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]:
+ none = ['', '\n', None]
+
+ # modify prompt_summary if user passes query or iinput
+ if query not in none and iinput not in none:
+ prompt_summary = "Focusing on %s, %s, %s" % (query, iinput, prompt_summary)
+ elif query not in none:
+ prompt_summary = "Focusing on %s, %s" % (query, prompt_summary)
+ # don't auto reduce
+ auto_reduce_chunks = False
+ if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
+ fstring = '{text}'
+ else:
+ fstring = '{input_documents}'
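+        # the placeholder must match the chain's prompt variable: the SUMMARIZE_MAP prompt is later built with input_variables=["text"]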
+ template = """%s:
+\"\"\"
+%s
+\"\"\"\n%s""" % (pre_prompt_summary, fstring, prompt_summary)
+ template_if_no_docs = "Exactly only say: There are no documents to summarize."
+    elif langchain_action in [LangChainAction.SUMMARIZE_REFINE.value]:
+ template = '' # unused
+ template_if_no_docs = '' # unused
+ else:
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
+
+ return template, template_if_no_docs, auto_reduce_chunks, query
+
+
+def get_sources_answer(query, docs, answer, scores, show_rank,
+ answer_with_sources, append_sources_to_answer,
+ show_accordions=True,
+ show_link_in_sources=True,
+ top_k_docs_max_show=10,
+ docs_ordering_type='reverse_ucurve_sort',
+ num_docs_before_cut=0,
+ verbose=False,
+ t_run=None,
+ count_input_tokens=None, count_output_tokens=None):
+ if verbose:
+ print("query: %s" % query, flush=True)
+ print("answer: %s" % answer, flush=True)
+
+ if len(docs) == 0:
+ extra = ''
+ ret = answer + extra
+ return ret, extra
+
+ if answer_with_sources == -1:
+ extra = [dict(score=score, content=get_doc(x), source=get_source(x), orig_index=x.metadata.get('orig_index', 0))
+ for score, x in zip(scores, docs)][
+ :top_k_docs_max_show]
+ if append_sources_to_answer:
+ extra_str = [str(x) for x in extra]
+ ret = answer + '\n\n' + '\n'.join(extra_str)
+ else:
+ ret = answer
+ return ret, extra
+
+ # link
+ answer_sources = [(max(0.0, 1.5 - score) / 1.5,
+ get_url(doc, font_size=font_size),
+ get_accordion(doc, font_size=font_size, head_acc=head_acc)) for score, doc in
+ zip(scores, docs)]
+ if not show_accordions:
+ answer_sources_dict = defaultdict(list)
+ [answer_sources_dict[url].append(score) for score, url in answer_sources]
+ answers_dict = {}
+ for url, scores_url in answer_sources_dict.items():
+ answers_dict[url] = np.max(scores_url)
+ answer_sources = [(score, url) for url, score in answers_dict.items()]
+ answer_sources.sort(key=lambda x: x[0], reverse=True)
+ if show_rank:
+ # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
+            # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
+            answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
+            answer_sources = answer_sources[:top_k_docs_max_show]
+            sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
+ else:
+ if show_accordions:
+ if show_link_in_sources:
+ answer_sources = ['" + "".join(answer_sources)
+ else:
+ sorted_sources_urls = f"{source_prefix}
" + "
{title_overall}
{sorted_sources_urls}
+
+
+
+
')
+ text = '```'.join(ts)
+ return text
+
+
+def is_valid_key(enforce_h2ogpt_api_key, h2ogpt_api_keys, h2ogpt_key1, requests_state1=None):
+ valid_key = False
+ if not enforce_h2ogpt_api_key:
+ # no token barrier
+ valid_key = 'not enforced'
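+        # 'not enforced' is truthy, so callers that just bool-check the result still pass, while remaining distinguishable from True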
+ else:
+ if isinstance(h2ogpt_api_keys, list) and h2ogpt_key1 in h2ogpt_api_keys:
+ # passed token barrier
+ valid_key = True
+ elif isinstance(h2ogpt_api_keys, str) and os.path.isfile(h2ogpt_api_keys):
+ with filelock.FileLock(h2ogpt_api_keys + '.lock'):
+ with open(h2ogpt_api_keys, 'rt') as f:
+ h2ogpt_api_keys = json.load(f)
+ if h2ogpt_key1 in h2ogpt_api_keys:
+ valid_key = True
+ if isinstance(requests_state1, dict) and 'username' in requests_state1 and requests_state1['username']:
+ # no UI limit currently
+ valid_key = True
+ return valid_key
+
+
+def go_gradio(**kwargs):
+ allow_api = kwargs['allow_api']
+ is_public = kwargs['is_public']
+ is_hf = kwargs['is_hf']
+ memory_restriction_level = kwargs['memory_restriction_level']
+ n_gpus = kwargs['n_gpus']
+ admin_pass = kwargs['admin_pass']
+ model_states = kwargs['model_states']
+ dbs = kwargs['dbs']
+ db_type = kwargs['db_type']
+ visible_langchain_actions = kwargs['visible_langchain_actions']
+ visible_langchain_agents = kwargs['visible_langchain_agents']
+ allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
+ allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
+ enable_sources_list = kwargs['enable_sources_list']
+ enable_url_upload = kwargs['enable_url_upload']
+ enable_text_upload = kwargs['enable_text_upload']
+ use_openai_embedding = kwargs['use_openai_embedding']
+ hf_embedding_model = kwargs['hf_embedding_model']
+ load_db_if_exists = kwargs['load_db_if_exists']
+ migrate_embedding_model = kwargs['migrate_embedding_model']
+ auto_migrate_db = kwargs['auto_migrate_db']
+ captions_model = kwargs['captions_model']
+ caption_loader = kwargs['caption_loader']
+ doctr_loader = kwargs['doctr_loader']
+
+ n_jobs = kwargs['n_jobs']
+ verbose = kwargs['verbose']
+
+ # for dynamic state per user session in gradio
+ model_state0 = kwargs['model_state0']
+ score_model_state0 = kwargs['score_model_state0']
+ my_db_state0 = kwargs['my_db_state0']
+ selection_docs_state0 = kwargs['selection_docs_state0']
+ visible_models_state0 = kwargs['visible_models_state0']
+ # For Heap analytics
+ is_heap_analytics_enabled = kwargs['enable_heap_analytics']
+ heap_app_id = kwargs['heap_app_id']
+
+ # easy update of kwargs needed for evaluate() etc.
+ queue = True
+ allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
+ allow_upload_api = allow_api and allow_upload
+
+ kwargs.update(locals())
+
+ # import control
+ if kwargs['langchain_mode'] != 'Disabled':
+ from gpt_langchain import file_types, have_arxiv
+ else:
+ have_arxiv = False
+ file_types = []
+
+ if 'mbart-' in kwargs['model_lower']:
+ instruction_label_nochat = "Text to translate"
+ else:
+ instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
+ " use Enter for multiple input lines)"
+
+ title = 'h2oGPT'
+ if kwargs['visible_h2ogpt_header']:
+ description = """h2oGPT LLM Leaderboard LLM Studio
CodeLlama
🤗 Models"""
+ else:
+ description = None
+    description_bottom = "If this host is busy, try<br>[Multi-Model](https://gpt.h2o.ai)<br>[CodeLlama](https://codellama.h2o.ai)<br>[Llama2 70B](https://llama.h2o.ai)<br>[Falcon 40B](https://falcon.h2o.ai)<br>[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)<br>[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
+ if is_hf:
+ description_bottom += ''''''
+ task_info_md = ''
+ css_code = get_css(kwargs)
+
+ if kwargs['gradio_offline_level'] >= 0:
+ # avoid GoogleFont that pulls from internet
+ if kwargs['gradio_offline_level'] == 1:
+ # front end would still have to download fonts or have cached it at some point
+ base_font = 'Source Sans Pro'
+ else:
+ base_font = 'Helvetica'
+ theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'),
+ font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'))
+ else:
+ theme_kwargs = dict()
+ if kwargs['gradio_size'] == 'xsmall':
+ theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm))
+ elif kwargs['gradio_size'] in [None, 'small']:
+ theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm,
+ radius_size=gr.themes.sizes.spacing_sm))
+ elif kwargs['gradio_size'] == 'large':
+        theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_lg, text_size=gr.themes.sizes.text_lg,
+                                 radius_size=gr.themes.sizes.spacing_lg))
+ elif kwargs['gradio_size'] == 'medium':
+ theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_md, text_size=gr.themes.sizes.text_md,
+ radius_size=gr.themes.sizes.spacing_md))
+
+ theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs)
+ demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
+ callback = gr.CSVLogger()
+
+ model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
+ if kwargs['base_model'].strip() not in model_options0:
+ model_options0 = [kwargs['base_model'].strip()] + model_options0
+ lora_options = kwargs['extra_lora_options']
+ if kwargs['lora_weights'].strip() not in lora_options:
+ lora_options = [kwargs['lora_weights'].strip()] + lora_options
+ server_options = kwargs['extra_server_options']
+ if kwargs['inference_server'].strip() not in server_options:
+ server_options = [kwargs['inference_server'].strip()] + server_options
+ if os.getenv('OPENAI_API_KEY'):
+ if 'openai_chat' not in server_options:
+ server_options += ['openai_chat']
+ if 'openai' not in server_options:
+ server_options += ['openai']
+
+ # always add in no lora case
+ # add fake space so doesn't go away in gradio dropdown
+ model_options0 = [no_model_str] + sorted(model_options0)
+ lora_options = [no_lora_str] + sorted(lora_options)
+ server_options = [no_server_str] + sorted(server_options)
+ # always add in no model case so can free memory
+ # add fake space so doesn't go away in gradio dropdown
+
+ # transcribe, will be detranscribed before use by evaluate()
+ if not kwargs['base_model'].strip():
+ kwargs['base_model'] = no_model_str
+
+ if not kwargs['lora_weights'].strip():
+ kwargs['lora_weights'] = no_lora_str
+
+ if not kwargs['inference_server'].strip():
+ kwargs['inference_server'] = no_server_str
+
+ # transcribe for gradio
+ kwargs['gpu_id'] = str(kwargs['gpu_id'])
+
+ no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
+ output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
+ 'base_model') else no_model_msg
+ output_label0_model2 = no_model_msg
+
+ def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0):
+ if not prompt_type1 or which_model != 0:
+ # keep prompt_type and prompt_dict in sync if possible
+ prompt_type1 = kwargs.get('prompt_type', prompt_type1)
+ prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
+ # prefer model specific prompt type instead of global one
+ if not prompt_type1 or which_model != 0:
+ prompt_type1 = model_state1.get('prompt_type', prompt_type1)
+ prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
+
+ if not prompt_dict1 or which_model != 0:
+ # if still not defined, try to get
+ prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
+ if not prompt_dict1 or which_model != 0:
+ prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
+ return prompt_type1, prompt_dict1
+
+ def visible_models_to_model_choice(visible_models1):
+ if isinstance(visible_models1, list):
+ assert len(
+ visible_models1) >= 1, "Invalid visible_models1=%s, can only be single entry" % visible_models1
+ # just take first
+ model_active_choice1 = visible_models1[0]
+ elif isinstance(visible_models1, (str, int)):
+ model_active_choice1 = visible_models1
+ else:
+ assert isinstance(visible_models1, type(None)), "Invalid visible_models1=%s" % visible_models1
+ model_active_choice1 = visible_models1
+ if model_active_choice1 is not None:
+ if isinstance(model_active_choice1, str):
+ base_model_list = [x['base_model'] for x in model_states]
+ if model_active_choice1 in base_model_list:
+ # if dups, will just be first one
+ model_active_choice1 = base_model_list.index(model_active_choice1)
+ else:
+ # NOTE: Could raise, but sometimes raising in certain places fails too hard and requires UI restart
+ model_active_choice1 = 0
+ else:
+ model_active_choice1 = 0
+ return model_active_choice1
+
+ default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
+ # ensure prompt_type consistent with prep_bot(), so nochat API works same way
+ default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \
+ update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'],
+ model_state1=model_state0,
+ which_model=visible_models_to_model_choice(kwargs['visible_models']))
+ for k in no_default_param_names:
+ default_kwargs[k] = ''
+
+ def dummy_fun(x):
+ # need dummy function to block new input from being sent until output is done,
+ # else gets input_list at time of submit that is old, and shows up as truncated in chatbot
+ return x
+
+ def update_auth_selection(auth_user, selection_docs_state1, save=False):
+ # in-place update of both
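+        # save=True pushes the UI's selection_docs_state into the persisted auth_user record;
+        # save=False (the default) restores the UI state from what was previously saved for that user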
+ if 'selection_docs_state' not in auth_user:
+ auth_user['selection_docs_state'] = selection_docs_state0
+ for k, v in auth_user['selection_docs_state'].items():
+ if isinstance(selection_docs_state1[k], dict):
+ if save:
+ auth_user['selection_docs_state'][k].clear()
+ auth_user['selection_docs_state'][k].update(selection_docs_state1[k])
+ else:
+ selection_docs_state1[k].clear()
+ selection_docs_state1[k].update(auth_user['selection_docs_state'][k])
+ elif isinstance(selection_docs_state1[k], list):
+ if save:
+ auth_user['selection_docs_state'][k].clear()
+ auth_user['selection_docs_state'][k].extend(selection_docs_state1[k])
+ else:
+ selection_docs_state1[k].clear()
+ selection_docs_state1[k].extend(auth_user['selection_docs_state'][k])
+ else:
+ raise RuntimeError("Bad type: %s" % selection_docs_state1[k])
+
+ # BEGIN AUTH THINGS
+ def auth_func(username1, password1, auth_pairs=None, auth_filename=None,
+ auth_access=None,
+ auth_freeze=None,
+ guest_name=None,
+ selection_docs_state1=None,
+ selection_docs_state00=None,
+ **kwargs):
+ assert auth_freeze is not None
+ if selection_docs_state1 is None:
+ selection_docs_state1 = selection_docs_state00
+ assert selection_docs_state1 is not None
+ assert auth_filename and isinstance(auth_filename, str), "Auth file must be a non-empty string, got: %s" % str(
+ auth_filename)
+ if auth_access == 'open' and username1 == guest_name:
+ return True
+ if username1 == '':
+ # some issue with login
+ return False
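+        # remaining flow: load the JSON auth file under a lock, verify or create the user record
+        # (CLI auth_pairs are mirrored into the file), and auto-register new users when auth_access == 'open'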
+ with filelock.FileLock(auth_filename + '.lock'):
+ auth_dict = {}
+ if os.path.isfile(auth_filename):
+ try:
+ with open(auth_filename, 'rt') as f:
+ auth_dict = json.load(f)
+ except json.decoder.JSONDecodeError as e:
+ print("Auth exception: %s" % str(e), flush=True)
+ shutil.move(auth_filename, auth_filename + '.bak' + str(uuid.uuid4()))
+ auth_dict = {}
+ if username1 in auth_dict and username1 in auth_pairs:
+ if password1 == auth_dict[username1]['password'] and password1 == auth_pairs[username1]:
+ auth_user = auth_dict[username1]
+ update_auth_selection(auth_user, selection_docs_state1)
+ save_auth_dict(auth_dict, auth_filename)
+ return True
+ else:
+ return False
+ elif username1 in auth_dict:
+ if password1 == auth_dict[username1]['password']:
+ auth_user = auth_dict[username1]
+ update_auth_selection(auth_user, selection_docs_state1)
+ save_auth_dict(auth_dict, auth_filename)
+ return True
+ else:
+ return False
+ elif username1 in auth_pairs:
+ # copy over CLI auth to file so only one state to manage
+ auth_dict[username1] = dict(password=auth_pairs[username1], userid=str(uuid.uuid4()))
+ auth_user = auth_dict[username1]
+ update_auth_selection(auth_user, selection_docs_state1)
+ save_auth_dict(auth_dict, auth_filename)
+ return True
+ else:
+ if auth_access == 'closed':
+ return False
+ # open access
+ auth_dict[username1] = dict(password=password1, userid=str(uuid.uuid4()))
+ auth_user = auth_dict[username1]
+ update_auth_selection(auth_user, selection_docs_state1)
+ save_auth_dict(auth_dict, auth_filename)
+ if auth_access == 'open':
+ return True
+ else:
+ raise RuntimeError("Invalid auth_access: %s" % auth_access)
+
+ def auth_func_open(*args, **kwargs):
+ return True
+
+ def get_username(requests_state1):
+ username1 = None
+ if 'username' in requests_state1:
+ username1 = requests_state1['username']
+ return username1
+
+ def get_userid_auth_func(requests_state1, auth_filename=None, auth_access=None, guest_name=None, **kwargs):
+ if auth_filename and isinstance(auth_filename, str):
+ username1 = get_username(requests_state1)
+ if username1:
+ if username1 == guest_name:
+ return str(uuid.uuid4())
+ with filelock.FileLock(auth_filename + '.lock'):
+ if os.path.isfile(auth_filename):
+ with open(auth_filename, 'rt') as f:
+ auth_dict = json.load(f)
+ if username1 in auth_dict:
+ return auth_dict[username1]['userid']
+ # if here, then not persistently associated with username1,
+ # but should only be one-time asked if going to persist within a single session!
+ return str(uuid.uuid4())
+
+ get_userid_auth = functools.partial(get_userid_auth_func,
+ auth_filename=kwargs['auth_filename'],
+ auth_access=kwargs['auth_access'],
+ guest_name=kwargs['guest_name'],
+ )
+ if kwargs['auth_access'] == 'closed':
+ auth_message1 = "Closed access"
+ else:
+ auth_message1 = "WELCOME! Open access" \
+ " (%s/%s or any unique user/pass)" % (kwargs['guest_name'], kwargs['guest_name'])
+
+ if kwargs['auth_message'] is not None:
+ auth_message = kwargs['auth_message']
+ else:
+ auth_message = auth_message1
+
+ # always use same callable
+ auth_pairs0 = {}
+ if isinstance(kwargs['auth'], list):
+ for k, v in kwargs['auth']:
+ auth_pairs0[k] = v
+ authf = functools.partial(auth_func,
+ auth_pairs=auth_pairs0,
+ auth_filename=kwargs['auth_filename'],
+ auth_access=kwargs['auth_access'],
+ auth_freeze=kwargs['auth_freeze'],
+ guest_name=kwargs['guest_name'],
+ selection_docs_state00=copy.deepcopy(selection_docs_state0))
+
+ def get_request_state(requests_state1, request, db1s):
+ # if need to get state, do it now
+ if not requests_state1:
+ requests_state1 = requests_state0.copy()
+ if requests:
+ if not requests_state1.get('headers', '') and hasattr(request, 'headers'):
+ requests_state1.update(request.headers)
+ if not requests_state1.get('host', '') and hasattr(request, 'host'):
+ requests_state1.update(dict(host=request.host))
+ if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'):
+ requests_state1.update(dict(host2=request.client.host))
+ if not requests_state1.get('username', '') and hasattr(request, 'username'):
+ # use already-defined username instead of keep changing to new uuid
+ # should be same as in requests_state1
+ db_username = get_username_direct(db1s)
+ requests_state1.update(dict(username=request.username or db_username or str(uuid.uuid4())))
+ requests_state1 = {str(k): str(v) for k, v in requests_state1.items()}
+ return requests_state1
+
+ def user_state_setup(db1s, requests_state1, request: gr.Request, *args):
+ requests_state1 = get_request_state(requests_state1, request, db1s)
+ set_userid(db1s, requests_state1, get_userid_auth)
+ args_list = [db1s, requests_state1] + list(args)
+ return tuple(args_list)
+
+ # END AUTH THINGS
+
+ def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
+ allow = False
+ allow |= langchain_action1 not in LangChainAction.QUERY.value
+ allow |= document_subset1 in DocumentSubset.TopKSources.name
+ if langchain_mode1 in [LangChainMode.LLM.value]:
+ allow = False
+ return allow
+
+ image_loaders_options0, image_loaders_options, \
+ pdf_loaders_options0, pdf_loaders_options, \
+ url_loaders_options0, url_loaders_options = lg_to_gr(**kwargs)
+ jq_schema0 = '.[]'
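+    # jq schema '.[]' iterates over the elements of a top-level JSON array, the default used when ingesting JSON files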
+
+ with demo:
+ # avoid actual model/tokenizer here or anything that would be bad to deepcopy
+ # https://github.com/gradio-app/gradio/issues/3558
+ model_state = gr.State(
+ dict(model='model', tokenizer='tokenizer', device=kwargs['device'],
+ base_model=kwargs['base_model'],
+ tokenizer_base_model=kwargs['tokenizer_base_model'],
+ lora_weights=kwargs['lora_weights'],
+ inference_server=kwargs['inference_server'],
+ prompt_type=kwargs['prompt_type'],
+ prompt_dict=kwargs['prompt_dict'],
+ visible_models=kwargs['visible_models'],
+ h2ogpt_key=kwargs['h2ogpt_key'],
+ )
+ )
+
+ def update_langchain_mode_paths(selection_docs_state1):
+ dup = selection_docs_state1['langchain_mode_paths'].copy()
+ for k, v in dup.items():
+ if k not in selection_docs_state1['langchain_modes']:
+ selection_docs_state1['langchain_mode_paths'].pop(k)
+ for k in selection_docs_state1['langchain_modes']:
+ if k not in selection_docs_state1['langchain_mode_types']:
+ # if didn't specify shared, then assume scratch if didn't login or personal if logged in
+ selection_docs_state1['langchain_mode_types'][k] = LangChainTypes.PERSONAL.value
+ return selection_docs_state1
+
+ # Setup some gradio states for per-user dynamic state
+ model_state2 = gr.State(kwargs['model_state_none'].copy())
+ model_options_state = gr.State([model_options0])
+ lora_options_state = gr.State([lora_options])
+ server_options_state = gr.State([server_options])
+ my_db_state = gr.State(my_db_state0)
+ chat_state = gr.State({})
+ docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value]
+ docs_state0 = []
+ [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
+ docs_state = gr.State(docs_state0)
+ viewable_docs_state0 = []
+ viewable_docs_state = gr.State(viewable_docs_state0)
+ selection_docs_state0 = update_langchain_mode_paths(selection_docs_state0)
+ selection_docs_state = gr.State(selection_docs_state0)
+ requests_state0 = dict(headers='', host='', username='')
+ requests_state = gr.State(requests_state0)
+
+ if description is not None:
+ gr.Markdown(f"""
+ {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
+ """)
+
+ # go button visible if
+ base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
+ go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
+
+ nas = ' '.join(['NA'] * len(kwargs['model_states']))
+ res_value = "Response Score: NA" if not kwargs[
+ 'model_lock'] else "Response Scores: %s" % nas
+
+ user_can_do_sum = kwargs['langchain_mode'] != LangChainMode.DISABLED.value and \
+ (kwargs['visible_side_bar'] or kwargs['visible_system_tab'])
+ if user_can_do_sum:
+ extra_prompt_form = ". For summarization, no query required, just click submit"
+ else:
+ extra_prompt_form = ""
+ if kwargs['input_lines'] > 1:
+ instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
+ else:
+ instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
+
+ def get_langchain_choices(selection_docs_state1):
+ langchain_modes = selection_docs_state1['langchain_modes']
+
+ if is_hf:
+ # don't show 'wiki' since only usually useful for internal testing at moment
+ no_show_modes = ['Disabled', 'wiki']
+ else:
+ no_show_modes = ['Disabled']
+ allowed_modes = langchain_modes.copy()
+ # allowed_modes = [x for x in allowed_modes if x in dbs]
+ allowed_modes += ['LLM']
+ if allow_upload_to_my_data and 'MyData' not in allowed_modes:
+ allowed_modes += ['MyData']
+ if allow_upload_to_user_data and 'UserData' not in allowed_modes:
+ allowed_modes += ['UserData']
+ choices = [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes]
+ return choices
+
+ def get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=None):
+ langchain_choices1 = get_langchain_choices(selection_docs_state1)
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
+ langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if k in langchain_choices1}
+ if langchain_mode_paths:
+ langchain_mode_paths = langchain_mode_paths.copy()
+ for langchain_mode1 in langchain_modes_non_db:
+ langchain_mode_paths.pop(langchain_mode1, None)
+ df1 = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns')
+ df1.columns = ['Collection', 'Path']
+ df1 = df1.set_index('Collection')
+ else:
+ df1 = pd.DataFrame(None)
+ langchain_mode_types = selection_docs_state1['langchain_mode_types']
+ langchain_mode_types = {k: v for k, v in langchain_mode_types.items() if k in langchain_choices1}
+ if langchain_mode_types:
+ langchain_mode_types = langchain_mode_types.copy()
+ for langchain_mode1 in langchain_modes_non_db:
+ langchain_mode_types.pop(langchain_mode1, None)
+
+ df2 = pd.DataFrame.from_dict(langchain_mode_types.items(), orient='columns')
+ df2.columns = ['Collection', 'Type']
+ df2 = df2.set_index('Collection')
+
+ from src.gpt_langchain import get_persist_directory, load_embed
+ persist_directory_dict = {}
+ embed_dict = {}
+ chroma_version_dict = {}
+ for langchain_mode3 in langchain_mode_types:
+ langchain_type3 = langchain_mode_types.get(langchain_mode3, LangChainTypes.EITHER.value)
+ persist_directory3, langchain_type3 = get_persist_directory(langchain_mode3,
+ langchain_type=langchain_type3,
+ db1s=db1s, dbs=dbs1)
+ got_embedding3, use_openai_embedding3, hf_embedding_model3 = load_embed(
+ persist_directory=persist_directory3)
+ persist_directory_dict[langchain_mode3] = persist_directory3
+ embed_dict[langchain_mode3] = 'OpenAI' if not hf_embedding_model3 else hf_embedding_model3
+
+ if os.path.isfile(os.path.join(persist_directory3, 'chroma.sqlite3')):
+ chroma_version_dict[langchain_mode3] = 'ChromaDB>=0.4'
+ elif os.path.isdir(os.path.join(persist_directory3, 'index')):
+ chroma_version_dict[langchain_mode3] = 'ChromaDB<0.4'
+ elif not os.listdir(persist_directory3):
+ if db_type == 'chroma':
+ chroma_version_dict[langchain_mode3] = 'ChromaDB>=0.4' # will be
+ elif db_type == 'chroma_old':
+ chroma_version_dict[langchain_mode3] = 'ChromaDB<0.4' # will be
+ else:
+ chroma_version_dict[langchain_mode3] = 'Weaviate' # will be
+ if isinstance(hf_embedding_model, dict):
+ hf_embedding_model3 = hf_embedding_model['name']
+ else:
+ hf_embedding_model3 = hf_embedding_model
+ assert isinstance(hf_embedding_model3, str)
+ embed_dict[langchain_mode3] = hf_embedding_model3 # will be
+ else:
+ chroma_version_dict[langchain_mode3] = 'Weaviate'
+
+ df3 = pd.DataFrame.from_dict(persist_directory_dict.items(), orient='columns')
+ df3.columns = ['Collection', 'Directory']
+ df3 = df3.set_index('Collection')
+
+ df4 = pd.DataFrame.from_dict(embed_dict.items(), orient='columns')
+ df4.columns = ['Collection', 'Embedding']
+ df4 = df4.set_index('Collection')
+
+ df5 = pd.DataFrame.from_dict(chroma_version_dict.items(), orient='columns')
+ df5.columns = ['Collection', 'DB']
+ df5 = df5.set_index('Collection')
+ else:
+ df2 = pd.DataFrame(None)
+ df3 = pd.DataFrame(None)
+ df4 = pd.DataFrame(None)
+ df5 = pd.DataFrame(None)
+ df_list = [df2, df1, df3, df4, df5]
+ df_list = [x for x in df_list if x.shape[1] > 0]
+ if len(df_list) > 1:
+ df = df_list[0].join(df_list[1:]).replace(np.nan, '').reset_index()
+        elif len(df_list) == 1:
+ df = df_list[0].replace(np.nan, '').reset_index()
+ else:
+ df = pd.DataFrame(None)
+ return df
+
+ normal_block = gr.Row(visible=not base_wanted, equal_height=False, elem_id="col_container")
+ with normal_block:
+ side_bar = gr.Column(elem_id="sidebar", scale=1, min_width=100, visible=kwargs['visible_side_bar'])
+ with side_bar:
+ with gr.Accordion("Chats", open=False, visible=True):
+ radio_chats = gr.Radio(value=None, label="Saved Chats", show_label=False,
+ visible=True, interactive=True,
+ type='value')
+ upload_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload
+ with gr.Accordion("Upload", open=False, visible=upload_visible):
+ with gr.Column():
+ with gr.Row(equal_height=False):
+ fileup_output = gr.File(show_label=False,
+ file_types=['.' + x for x in file_types],
+ # file_types=['*', '*.*'], # for iPhone etc. needs to be unconstrained else doesn't work with extension-based restrictions
+ file_count="multiple",
+ scale=1,
+ min_width=0,
+ elem_id="warning", elem_classes="feedback",
+ )
+ fileup_output_text = gr.Textbox(visible=False)
+ max_quality = gr.Checkbox(label="Maximum Ingest Quality", value=kwargs['max_quality'],
+ visible=not is_public)
+ url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
+ url_label = 'URL/ArXiv' if have_arxiv else 'URL'
+ url_text = gr.Textbox(label=url_label,
+ # placeholder="Enter Submits",
+ max_lines=1,
+ interactive=True)
+ text_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload
+ user_text_text = gr.Textbox(label='Paste Text',
+ # placeholder="Enter Submits",
+ interactive=True,
+ visible=text_visible)
+ github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
+ database_visible = kwargs['langchain_mode'] != 'Disabled'
+ with gr.Accordion("Resources", open=False, visible=database_visible):
+ langchain_choices0 = get_langchain_choices(selection_docs_state0)
+ langchain_mode = gr.Radio(
+ langchain_choices0,
+ value=kwargs['langchain_mode'],
+ label="Collections",
+ show_label=True,
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ min_width=100)
+ add_chat_history_to_context = gr.Checkbox(label="Chat History",
+ value=kwargs['add_chat_history_to_context'])
+ add_search_to_context = gr.Checkbox(label="Web Search",
+ value=kwargs['add_search_to_context'],
+ visible=os.environ.get('SERPAPI_API_KEY') is not None \
+ and have_serpapi)
+ document_subset = gr.Radio([x.name for x in DocumentSubset],
+ label="Subset",
+ value=DocumentSubset.Relevant.name,
+ interactive=True,
+ )
+ allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
+ langchain_action = gr.Radio(
+ allowed_actions,
+ value=allowed_actions[0] if len(allowed_actions) > 0 else None,
+ label="Action",
+ visible=True)
+ allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents]
+ if os.getenv('OPENAI_API_KEY') is None and LangChainAgent.JSON.value in allowed_agents:
+ allowed_agents.remove(LangChainAgent.JSON.value)
+ if os.getenv('OPENAI_API_KEY') is None and LangChainAgent.PYTHON.value in allowed_agents:
+ allowed_agents.remove(LangChainAgent.PYTHON.value)
+ if LangChainAgent.PANDAS.value in allowed_agents:
+ allowed_agents.remove(LangChainAgent.PANDAS.value)
+ langchain_agents = gr.Dropdown(
+ allowed_agents,
+ value=None,
+ label="Agents",
+ multiselect=True,
+ interactive=True,
+ visible=True,
+ elem_id="langchain_agents",
+ filterable=False)
+ visible_doc_track = upload_visible and kwargs['visible_doc_track'] and not kwargs[
+ 'large_file_count_mode']
+ row_doc_track = gr.Row(visible=visible_doc_track)
+ with row_doc_track:
+ if kwargs['langchain_mode'] in langchain_modes_non_db:
+ doc_counts_str = "Pure LLM Mode"
+ else:
+ doc_counts_str = "Name: %s\nDocs: Unset\nChunks: Unset" % kwargs['langchain_mode']
+ text_doc_count = gr.Textbox(lines=3, label="Doc Counts", value=doc_counts_str,
+ visible=visible_doc_track)
+ text_file_last = gr.Textbox(lines=1, label="Newest Doc", value=None, visible=visible_doc_track)
+ text_viewable_doc_count = gr.Textbox(lines=2, label=None, visible=False)
+ col_tabs = gr.Column(elem_id="col-tabs", scale=10)
+ with col_tabs, gr.Tabs():
+ if kwargs['chat_tables']:
+ chat_tab = gr.Row(visible=True)
+ else:
+ chat_tab = gr.TabItem("Chat") \
+ if kwargs['visible_chat_tab'] else gr.Row(visible=False)
+ with chat_tab:
+ if kwargs['langchain_mode'] == 'Disabled':
+ text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True,
+ visible=not kwargs['chat'])
+ else:
+ # text looks a bit worse, but HTML links work
+ text_output_nochat = gr.HTML(label=output_label0, visible=not kwargs['chat'])
+ with gr.Row():
+ # NOCHAT
+ instruction_nochat = gr.Textbox(
+ lines=kwargs['input_lines'],
+ label=instruction_label_nochat,
+ placeholder=kwargs['placeholder_instruction'],
+ visible=not kwargs['chat'],
+ )
+ iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
+ placeholder=kwargs['placeholder_input'],
+ value=kwargs['iinput'],
+ visible=not kwargs['chat'])
+ submit_nochat = gr.Button("Submit", size='sm', visible=not kwargs['chat'])
+ flag_btn_nochat = gr.Button("Flag", size='sm', visible=not kwargs['chat'])
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False,
+ visible=not kwargs['chat'])
+ submit_nochat_api = gr.Button("Submit nochat API", visible=False)
+ submit_nochat_api_plain = gr.Button("Submit nochat API Plain", visible=False)
+ inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False)
+ text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
+ show_copy_button=True)
+
+ visible_upload = (allow_upload_to_user_data or
+ allow_upload_to_my_data) and \
+ kwargs['langchain_mode'] != 'Disabled'
+ # CHAT
+ col_chat = gr.Column(visible=kwargs['chat'])
+ with col_chat:
+ with gr.Row():
+ with gr.Column(scale=50):
+ with gr.Row(elem_id="prompt-form-row"):
+ label_instruction = 'Ask anything'
+ instruction = gr.Textbox(
+ lines=kwargs['input_lines'],
+ label=label_instruction,
+ placeholder=instruction_label,
+ info=None,
+ elem_id='prompt-form',
+ container=True,
+ )
+ attach_button = gr.UploadButton(
+ elem_id="attach-button" if visible_upload else None,
+ value="",
+ label="Upload File(s)",
+ size="sm",
+ min_width=24,
+ file_types=['.' + x for x in file_types],
+ file_count="multiple",
+ visible=visible_upload)
+
+ submit_buttons = gr.Row(equal_height=False, visible=kwargs['visible_submit_buttons'])
+ with submit_buttons:
+ mw1 = 50
+ mw2 = 50
+ with gr.Column(min_width=mw1):
+ submit = gr.Button(value='Submit', variant='primary', size='sm',
+ min_width=mw1)
+ stop_btn = gr.Button(value="Stop", variant='secondary', size='sm',
+ min_width=mw1)
+ save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
+ with gr.Column(min_width=mw2):
+ retry_btn = gr.Button("Redo", size='sm', min_width=mw2)
+ undo = gr.Button("Undo", size='sm', min_width=mw2)
+ clear_chat_btn = gr.Button(value="Clear", size='sm', min_width=mw2)
+
+ visible_model_choice = bool(kwargs['model_lock']) and \
+ len(model_states) > 1 and \
+ kwargs['visible_visible_models']
+ with gr.Row(visible=visible_model_choice):
+ visible_models = gr.Dropdown(kwargs['all_models'],
+ label="Visible Models",
+ value=visible_models_state0,
+ interactive=True,
+ multiselect=True,
+ visible=visible_model_choice,
+ elem_id="visible-models",
+ filterable=False,
+ )
+
+ text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
+ **kwargs)
+
+ with gr.Row():
+ with gr.Column(visible=kwargs['score_model']):
+ score_text = gr.Textbox(res_value,
+ show_label=False,
+ visible=True)
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False,
+ visible=False and not kwargs['model_lock'])
+
+ doc_selection_tab = gr.TabItem("Document Selection") \
+ if kwargs['visible_doc_selection_tab'] else gr.Row(visible=False)
+ with doc_selection_tab:
+ if kwargs['langchain_mode'] in langchain_modes_non_db:
+ dlabel1 = 'Choose Resources->Collections and Pick Collection'
+ active_collection = gr.Markdown(value="#### Not Chatting with Any Collection\n%s" % dlabel1)
+ else:
+ dlabel1 = 'Select Subset of Document(s) for Chat with Collection: %s' % kwargs['langchain_mode']
+ active_collection = gr.Markdown(
+ value="#### Chatting with Collection: %s" % kwargs['langchain_mode'])
+ document_choice = gr.Dropdown(docs_state0,
+ label=dlabel1,
+ value=[DocumentChoice.ALL.value],
+ interactive=True,
+ multiselect=True,
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ )
+ sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
+ with gr.Row():
+ with gr.Column(scale=1):
+ get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
+ visible=sources_visible and kwargs['large_file_count_mode'])
+ # handle API get sources
+ get_sources_api_btn = gr.Button(visible=False)
+ get_sources_api_text = gr.Textbox(visible=False)
+
+ get_document_api_btn = gr.Button(visible=False)
+ get_document_api_text = gr.Textbox(visible=False)
+
+ show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
+ visible=sources_visible and kwargs['large_file_count_mode'])
+ delete_sources_btn = gr.Button(value="Delete Selected Sources from DB", scale=0, size='sm',
+ visible=sources_visible)
+ refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
+ size='sm',
+ visible=sources_visible and allow_upload_to_user_data)
+ with gr.Column(scale=4):
+ pass
+ visible_add_remove_collection = visible_upload
+ with gr.Row():
+ with gr.Column(scale=1):
+ add_placeholder = "e.g. UserData2, shared, user_path2" \
+ if not is_public else "e.g. MyData2, personal (optional)"
+ remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
+ new_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection,
+ label='Add Collection',
+ placeholder=add_placeholder,
+ interactive=True)
+ remove_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection,
+ label='Remove Collection from UI',
+ placeholder=remove_placeholder,
+ interactive=True)
+ purge_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection,
+ label='Purge Collection (UI, DB, & source files)',
+ placeholder=remove_placeholder,
+ interactive=True)
+ sync_sources_btn = gr.Button(
+ value="Synchronize DB and UI [only required if did not login and have shared docs]",
+ scale=0, size='sm',
+ visible=sources_visible and allow_upload_to_user_data and not kwargs[
+ 'large_file_count_mode'])
+ load_langchain = gr.Button(
+                        value="Load Collections State [only required if logged in another user]", scale=0,
+ size='sm',
+ visible=False and allow_upload_to_user_data and
+ kwargs['langchain_mode'] != 'Disabled')
+ with gr.Column(scale=5):
+ if kwargs['langchain_mode'] != 'Disabled' and visible_add_remove_collection:
+ df0 = get_df_langchain_mode_paths(selection_docs_state0, None, dbs1=dbs)
+ else:
+ df0 = pd.DataFrame(None)
+ langchain_mode_path_text = gr.Dataframe(value=df0,
+ visible=visible_add_remove_collection,
+ label='LangChain Mode-Path',
+ show_label=False,
+ interactive=False)
+
+ sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
+ equal_height=False)
+ with sources_row:
+ with gr.Column(scale=1):
+ file_source = gr.File(interactive=False,
+ label="Download File w/Sources")
+ with gr.Column(scale=2):
+ sources_text = gr.HTML(label='Sources Added', interactive=False)
+
+ doc_exception_text = gr.Textbox(value="", label='Document Exceptions',
+ interactive=False,
+ visible=kwargs['langchain_mode'] != 'Disabled')
+ file_types_str = ' '.join(file_types) + ' URL ArXiv TEXT'
+ gr.Textbox(value=file_types_str, label='Document Types Supported',
+ lines=2,
+ interactive=False,
+ visible=kwargs['langchain_mode'] != 'Disabled')
+
+ doc_view_tab = gr.TabItem("Document Viewer") \
+ if kwargs['visible_doc_view_tab'] else gr.Row(visible=False)
+ with doc_view_tab:
+ with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled'):
+ with gr.Column(scale=2):
+ get_viewable_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0,
+ size='sm',
+ visible=sources_visible and kwargs[
+ 'large_file_count_mode'])
+ view_document_choice = gr.Dropdown(viewable_docs_state0,
+ label="Select Single Document to View",
+ value=None,
+ interactive=True,
+ multiselect=False,
+ visible=True,
+ )
+ info_view_raw = "Raw text shown if render of original doc fails"
+ if is_public:
+ info_view_raw += " (Up to %s chunks in public portal)" % kwargs['max_raw_chunks']
+ view_raw_text_checkbox = gr.Checkbox(label="View Database Text", value=False,
+ info=info_view_raw,
+ visible=kwargs['db_type'] in ['chroma', 'chroma_old'])
+ with gr.Column(scale=4):
+ pass
+ doc_view = gr.HTML(visible=False)
+ doc_view2 = gr.Dataframe(visible=False)
+ doc_view3 = gr.JSON(visible=False)
+ doc_view4 = gr.Markdown(visible=False)
+ doc_view5 = gr.HTML(visible=False)
+
+ chat_tab = gr.TabItem("Chat History") \
+ if kwargs['visible_chat_history_tab'] else gr.Row(visible=False)
+ with chat_tab:
+ with gr.Row():
+ with gr.Column(scale=1):
+ remove_chat_btn = gr.Button(value="Remove Selected Saved Chats", visible=True, size='sm')
+ flag_btn = gr.Button("Flag Current Chat", size='sm')
+ export_chats_btn = gr.Button(value="Export Chats to Download", size='sm')
+ with gr.Column(scale=4):
+ pass
+ with gr.Row():
+ chats_file = gr.File(interactive=False, label="Download Exported Chats")
+ chatsup_output = gr.File(label="Upload Chat File(s)",
+ file_types=['.json'],
+ file_count='multiple',
+ elem_id="warning", elem_classes="feedback")
+ with gr.Row():
+ if 'mbart-' in kwargs['model_lower']:
+ src_lang = gr.Dropdown(list(languages_covered().keys()),
+ value=kwargs['src_lang'],
+ label="Input Language")
+ tgt_lang = gr.Dropdown(list(languages_covered().keys()),
+ value=kwargs['tgt_lang'],
+ label="Output Language")
+
+ chat_exception_text = gr.Textbox(value="", visible=True, label='Chat Exceptions',
+ interactive=False)
+ expert_tab = gr.TabItem("Expert") \
+ if kwargs['visible_expert_tab'] else gr.Row(visible=False)
+ with expert_tab:
+ with gr.Row():
+ with gr.Column():
+ prompt_type = gr.Dropdown(prompt_types_strings,
+ value=kwargs['prompt_type'], label="Prompt Type",
+ visible=not kwargs['model_lock'],
+ interactive=not is_public,
+ )
+ prompt_type2 = gr.Dropdown(prompt_types_strings,
+ value=kwargs['prompt_type'], label="Prompt Type Model 2",
+ visible=False and not kwargs['model_lock'],
+ interactive=not is_public)
+ system_prompt = gr.Textbox(label="System Prompt",
+ info="If 'auto', then uses model's system prompt,"
+ " else use this message."
+ " If empty, no system message is used",
+ value=kwargs['system_prompt'])
+ context = gr.Textbox(lines=2, label="System Pre-Context",
+ info="Directly pre-appended without prompt processing (before Pre-Conversation)",
+ value=kwargs['context'])
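+ # illustrative Pre-Conversation value (assumed example matching the info text below):
+ # [('Hi', 'Hello, how can I help you?')]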
+ chat_conversation = gr.Textbox(lines=2, label="Pre-Conversation",
+ info="Pre-append conversation for instruct/chat models as List of tuple of (human, bot)",
+ value=kwargs['chat_conversation'])
+ text_context_list = gr.Textbox(lines=2, label="Text Doc Q/A",
+ info="List of strings, for document Q/A, for bypassing database (i.e. also works in LLM Mode)",
+ value=kwargs['text_context_list'],
+ visible=not is_public, # primarily meant for API
+ )
+ iinput = gr.Textbox(lines=2, label="Input for Instruct prompt types",
+ info="If given for document query, added after query",
+ value=kwargs['iinput'],
+ placeholder=kwargs['placeholder_input'],
+ interactive=not is_public)
+ with gr.Column():
+ pre_prompt_query = gr.Textbox(label="Query Pre-Prompt",
+ info="Added before documents",
+ value=kwargs['pre_prompt_query'] or '')
+ prompt_query = gr.Textbox(label="Query Prompt",
+ info="Added after documents",
+ value=kwargs['prompt_query'] or '')
+ pre_prompt_summary = gr.Textbox(label="Summary Pre-Prompt",
+ info="Added before documents",
+ value=kwargs['pre_prompt_summary'] or '')
+ prompt_summary = gr.Textbox(label="Summary Prompt",
+ info="Added after documents (if query given, 'Focusing on {query}, ' is pre-appended)",
+ value=kwargs['prompt_summary'] or '')
+ with gr.Row(visible=not is_public):
+ image_loaders = gr.CheckboxGroup(image_loaders_options,
+ label="Force Image Reader",
+ value=image_loaders_options0)
+ pdf_loaders = gr.CheckboxGroup(pdf_loaders_options,
+ label="Force PDF Reader",
+ value=pdf_loaders_options0)
+ url_loaders = gr.CheckboxGroup(url_loaders_options,
+ label="Force URL Reader", value=url_loaders_options0)
+ jq_schema = gr.Textbox(label="JSON jq_schema", value=jq_schema0)
+
+ min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public)
+ top_k_docs = gr.Slider(minimum=min_top_k_docs, maximum=max_top_k_docs, step=1,
+ value=kwargs['top_k_docs'],
+ label=label_top_k_docs,
+ # info="For LangChain",
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public)
+ chunk_size = gr.Number(value=kwargs['chunk_size'],
+ label="Chunk size for document chunking",
+ info="For LangChain (ignored if chunk=False)",
+ minimum=128,
+ maximum=2048,
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public,
+ precision=0)
+ docs_ordering_type = gr.Radio(
+ docs_ordering_types,
+ value=kwargs['docs_ordering_type'],
+ label="Document Sorting in LLM Context",
+ visible=True)
+ chunk = gr.components.Checkbox(value=kwargs['chunk'],
+ label="Whether to chunk documents",
+ info="For LangChain",
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public)
+ embed = gr.components.Checkbox(value=True,
+ label="Whether to embed text",
+ info="For LangChain",
+ visible=False)
+ with gr.Row():
+ stream_output = gr.components.Checkbox(label="Stream output",
+ value=kwargs['stream_output'])
+ do_sample = gr.Checkbox(label="Sample",
+ info="Enable sampler (required for use of temperature, top_p, top_k)",
+ value=kwargs['do_sample'])
+ max_time = gr.Slider(minimum=0, maximum=kwargs['max_max_time'], step=1,
+ value=min(kwargs['max_max_time'],
+ kwargs['max_time']), label="Max. time",
+ info="Max. time to search optimal output.")
+ temperature = gr.Slider(minimum=0.01, maximum=2,
+ value=kwargs['temperature'],
+ label="Temperature",
+ info="Lower is deterministic, higher more creative")
+ top_p = gr.Slider(minimum=1e-3, maximum=1.0 - 1e-3,
+ value=kwargs['top_p'], label="Top p",
+ info="Cumulative probability of tokens to sample from")
+ top_k = gr.Slider(
+ minimum=1, maximum=100, step=1,
+ value=kwargs['top_k'], label="Top k",
+ info='Num. tokens to sample from'
+ )
+ # FIXME: https://github.com/h2oai/h2ogpt/issues/106
+ if os.getenv('TESTINGFAIL'):
+ max_beams = 8 if not (memory_restriction_level or is_public) else 1
+ else:
+ max_beams = 1
+ num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
+ value=min(max_beams, kwargs['num_beams']), label="Beams",
+ info="Number of searches for optimal overall probability. "
+ "Uses more GPU memory/compute",
+ interactive=False, visible=max_beams > 1)
+ max_max_new_tokens = get_max_max_new_tokens(model_state0, **kwargs)
+ max_new_tokens = gr.Slider(
+ minimum=1, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
+ )
+ min_new_tokens = gr.Slider(
+ minimum=0, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
+ )
+ max_new_tokens2 = gr.Slider(
+ minimum=1, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length 2",
+ visible=False and not kwargs['model_lock'],
+ )
+ min_new_tokens2 = gr.Slider(
+ minimum=0, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length 2",
+ visible=False and not kwargs['model_lock'],
+ )
+ min_max_new_tokens = gr.Slider(
+ minimum=1, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['min_max_new_tokens']),
+ label="Min. of Max output length",
+ )
+ early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
+ value=kwargs['early_stopping'], visible=max_beams > 1)
+ repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
+ value=kwargs['repetition_penalty'],
+ label="Repetition Penalty")
+ num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
+ value=kwargs['num_return_sequences'],
+ label="Number Returns", info="Must be <= num_beams",
+ interactive=not is_public, visible=max_beams > 1)
+ chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
+ visible=False, # no longer support nochat in UI
+ interactive=not is_public,
+ )
+ with gr.Row():
+ count_chat_tokens_btn = gr.Button(value="Count Chat Tokens",
+ visible=not is_public and not kwargs['model_lock'],
+ interactive=not is_public, size='sm')
+ chat_token_count = gr.Textbox(label="Chat Token Count Result", value=None,
+ visible=not is_public and not kwargs['model_lock'],
+ interactive=False)
+
+ models_tab = gr.TabItem("Models") \
+ if kwargs['visible_models_tab'] and not bool(kwargs['model_lock']) else gr.Row(visible=False)
+ with models_tab:
+ load_msg = "Download/Load Model" if not is_public \
+ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
+ if kwargs['base_model'] not in ['', None, no_model_str]:
+ load_msg += ' [WARNING: Avoid --base_model on CLI for memory efficient Load-Unload]'
+ load_msg2 = load_msg + " (Model 2)"
+ variant_load_msg = 'primary' if not is_public else 'secondary'
+ with gr.Row():
+ n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
+ with gr.Column():
+ with gr.Row():
+ with gr.Column(scale=20, visible=not kwargs['model_lock']):
+ load_model_button = gr.Button(load_msg, variant=variant_load_msg, scale=0,
+ size='sm', interactive=not is_public)
+ model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Base Model",
+ value=kwargs['base_model'])
+ lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
+ value=kwargs['lora_weights'], visible=kwargs['show_lora'])
+ server_choice = gr.Dropdown(server_options_state.value[0], label="Choose Server",
+ value=kwargs['inference_server'], visible=not is_public)
+ max_seq_len = gr.Number(value=kwargs['max_seq_len'] or 2048,
+ minimum=128,
+ maximum=2 ** 18,
+ info="If standard LLaMa-2, choose up to 4096",
+ label="max_seq_len")
+ rope_scaling = gr.Textbox(value=str(kwargs['rope_scaling'] or {}),
+ label="rope_scaling")
+ row_llama = gr.Row(visible=kwargs['show_llama'] and kwargs['base_model'] == 'llama')
+ with row_llama:
+ model_path_llama = gr.Textbox(value=kwargs['llamacpp_dict']['model_path_llama'],
+ lines=4,
+ label="Choose LLaMa.cpp Model Path/URL (for Base Model: llama)",
+ visible=kwargs['show_llama'])
+ n_gpu_layers = gr.Number(value=kwargs['llamacpp_dict']['n_gpu_layers'],
+ minimum=0, maximum=100,
+ label="LLaMa.cpp Num. GPU Layers Offloaded",
+ visible=kwargs['show_llama'])
+ n_batch = gr.Number(value=kwargs['llamacpp_dict']['n_batch'],
+ minimum=0, maximum=2048,
+ label="LLaMa.cpp Batch Size",
+ visible=kwargs['show_llama'])
+ n_gqa = gr.Number(value=kwargs['llamacpp_dict']['n_gqa'],
+ minimum=0, maximum=32,
+ label="LLaMa.cpp Num. Group Query Attention (8 for 70B LLaMa2)",
+ visible=kwargs['show_llama'])
+ llamacpp_dict_more = gr.Textbox(value="{}",
+ lines=4,
+ label="Dict for other LLaMa.cpp/GPT4All options",
+ visible=kwargs['show_llama'])
+ row_gpt4all = gr.Row(
+ visible=kwargs['show_gpt4all'] and kwargs['base_model'] in ['gptj',
+ 'gpt4all_llama'])
+ with row_gpt4all:
+ model_name_gptj = gr.Textbox(value=kwargs['llamacpp_dict']['model_name_gptj'],
+ label="Choose GPT4All GPTJ Model Path/URL (for Base Model: gptj)",
+ visible=kwargs['show_gpt4all'])
+ model_name_gpt4all_llama = gr.Textbox(
+ value=kwargs['llamacpp_dict']['model_name_gpt4all_llama'],
+ label="Choose GPT4All LLaMa Model Path/URL (for Base Model: gpt4all_llama)",
+ visible=kwargs['show_gpt4all'])
+ with gr.Column(scale=1, visible=not kwargs['model_lock']):
+ model_load8bit_checkbox = gr.components.Checkbox(
+ label="Load 8-bit [requires support]",
+ value=kwargs['load_8bit'], interactive=not is_public)
+ model_load4bit_checkbox = gr.components.Checkbox(
+ label="Load 4-bit [requires support]",
+ value=kwargs['load_4bit'], interactive=not is_public)
+ model_low_bit_mode = gr.Slider(value=kwargs['low_bit_mode'],
+ minimum=0, maximum=4, step=1,
+ label="low_bit_mode")
+ model_load_gptq = gr.Textbox(label="gptq", value=kwargs['load_gptq'],
+ interactive=not is_public)
+ model_load_exllama_checkbox = gr.components.Checkbox(
+ label="Load load_exllama [requires support]",
+ value=kwargs['load_exllama'], interactive=not is_public)
+ model_safetensors_checkbox = gr.components.Checkbox(
+ label="Safetensors [requires support]",
+ value=kwargs['use_safetensors'], interactive=not is_public)
+ model_revision = gr.Textbox(label="revision", value=kwargs['revision'],
+ interactive=not is_public)
+ model_use_gpu_id_checkbox = gr.components.Checkbox(
+ label="Choose Devices [If not Checked, use all GPUs]",
+ value=kwargs['use_gpu_id'], interactive=not is_public,
+ visible=n_gpus != 0)
+ model_gpu = gr.Dropdown(n_gpus_list,
+ label="GPU ID [-1 = all GPUs, if Choose is enabled]",
+ value=kwargs['gpu_id'], interactive=not is_public,
+ visible=n_gpus != 0)
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
+ interactive=False)
+ lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
+ visible=kwargs['show_lora'], interactive=False)
+ server_used = gr.Textbox(label="Current Server",
+ value=kwargs['inference_server'],
+ visible=bool(kwargs['inference_server']) and not is_public,
+ interactive=False)
+ prompt_dict = gr.Textbox(label="Prompt (or Custom)",
+ value=pprint.pformat(kwargs['prompt_dict'], indent=4),
+ interactive=not is_public, lines=4)
+ col_model2 = gr.Column(visible=False)
+ with col_model2:
+ with gr.Row():
+ with gr.Column(scale=20, visible=not kwargs['model_lock']):
+ load_model_button2 = gr.Button(load_msg2, variant=variant_load_msg, scale=0,
+ size='sm', interactive=not is_public)
+ model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
+ value=no_model_str)
+ lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
+ value=no_lora_str,
+ visible=kwargs['show_lora'])
+ server_choice2 = gr.Dropdown(server_options_state.value[0], label="Choose Server 2",
+ value=no_server_str,
+ visible=not is_public)
+ max_seq_len2 = gr.Number(value=kwargs['max_seq_len'] or 2048,
+ minimum=128,
+ maximum=2 ** 18,
+ info="If standard LLaMa-2, choose up to 4096",
+ label="max_seq_len Model 2")
+ rope_scaling2 = gr.Textbox(value=str(kwargs['rope_scaling'] or {}),
+ label="rope_scaling Model 2")
+
+ row_llama2 = gr.Row(
+ visible=kwargs['show_llama'] and kwargs['base_model'] == 'llama')
+ with row_llama2:
+ model_path_llama2 = gr.Textbox(
+ value=kwargs['llamacpp_dict']['model_path_llama'],
+ label="Choose LLaMa.cpp Model 2 Path/URL (for Base Model: llama)",
+ lines=4,
+ visible=kwargs['show_llama'])
+ n_gpu_layers2 = gr.Number(value=kwargs['llamacpp_dict']['n_gpu_layers'],
+ minimum=0, maximum=100,
+ label="LLaMa.cpp Num. GPU 2 Layers Offloaded",
+ visible=kwargs['show_llama'])
+ n_batch2 = gr.Number(value=kwargs['llamacpp_dict']['n_batch'],
+ minimum=0, maximum=2048,
+ label="LLaMa.cpp Model 2 Batch Size",
+ visible=kwargs['show_llama'])
+ n_gqa2 = gr.Number(value=kwargs['llamacpp_dict']['n_gqa'],
+ minimum=0, maximum=32,
+ label="LLaMa.cpp Model 2 Num. Group Query Attention (8 for 70B LLaMa2)",
+ visible=kwargs['show_llama'])
+ llamacpp_dict_more2 = gr.Textbox(value="{}",
+ lines=4,
+ label="Model 2 Dict for other LLaMa.cpp/GPT4All options",
+ visible=kwargs['show_llama'])
+ row_gpt4all2 = gr.Row(
+ visible=kwargs['show_gpt4all'] and kwargs['base_model'] in ['gptj',
+ 'gpt4all_llama'])
+ with row_gpt4all2:
+ model_name_gptj2 = gr.Textbox(value=kwargs['llamacpp_dict']['model_name_gptj'],
+ label="Choose GPT4All GPTJ Model 2 Path/URL (for Base Model: gptj)",
+ visible=kwargs['show_gpt4all'])
+ model_name_gpt4all_llama2 = gr.Textbox(
+ value=kwargs['llamacpp_dict']['model_name_gpt4all_llama'],
+ label="Choose GPT4All LLaMa Model 2 Path/URL (for Base Model: gpt4all_llama)",
+ visible=kwargs['show_gpt4all'])
+
+ with gr.Column(scale=1, visible=not kwargs['model_lock']):
+ model_load8bit_checkbox2 = gr.components.Checkbox(
+ label="Load 8-bit (Model 2) [requires support]",
+ value=kwargs['load_8bit'], interactive=not is_public)
+ model_load4bit_checkbox2 = gr.components.Checkbox(
+ label="Load 4-bit (Model 2) [requires support]",
+ value=kwargs['load_4bit'], interactive=not is_public)
+ model_low_bit_mode2 = gr.Slider(value=kwargs['low_bit_mode'],
+ # ok that same as Model 1
+ minimum=0, maximum=4, step=1,
+ label="low_bit_mode (Model 2)")
+ model_load_gptq2 = gr.Textbox(label="gptq (Model 2)", value='',
+ interactive=not is_public)
+ model_load_exllama_checkbox2 = gr.components.Checkbox(
+ label="Load load_exllama (Model 2) [requires support]",
+ value=False, interactive=not is_public)
+ model_safetensors_checkbox2 = gr.components.Checkbox(
+ label="Safetensors (Model 2) [requires support]",
+ value=False, interactive=not is_public)
+ model_revision2 = gr.Textbox(label="revision (Model 2)", value='',
+ interactive=not is_public)
+ model_use_gpu_id_checkbox2 = gr.components.Checkbox(
+ label="Choose Devices (Model 2) [If not Checked, use all GPUs]",
+ value=kwargs[
+ 'use_gpu_id'], interactive=not is_public)
+ model_gpu2 = gr.Dropdown(n_gpus_list,
+ label="GPU ID (Model 2) [-1 = all GPUs, if choose is enabled]",
+ value=kwargs['gpu_id'], interactive=not is_public)
+ # no model/lora loaded ever in model2 by default
+ model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str,
+ interactive=False)
+ lora_used2 = gr.Textbox(label="Current LORA (Model 2)", value=no_lora_str,
+ visible=kwargs['show_lora'], interactive=False)
+ server_used2 = gr.Textbox(label="Current Server (Model 2)", value=no_server_str,
+ interactive=False,
+ visible=not is_public)
+ prompt_dict2 = gr.Textbox(label="Prompt (or Custom) (Model 2)",
+ value=pprint.pformat(kwargs['prompt_dict'], indent=4),
+ interactive=not is_public, lines=4)
+ compare_checkbox = gr.components.Checkbox(label="Compare Two Models",
+ value=kwargs['model_lock'],
+ visible=not is_public and not kwargs['model_lock'])
+ with gr.Row(visible=not kwargs['model_lock']):
+ with gr.Column(scale=50):
+ new_model = gr.Textbox(label="New Model name/path/URL", interactive=not is_public)
+ with gr.Column(scale=50):
+ new_lora = gr.Textbox(label="New LORA name/path/URL", visible=kwargs['show_lora'],
+ interactive=not is_public)
+ with gr.Column(scale=50):
+ new_server = gr.Textbox(label="New Server url:port", interactive=not is_public)
+ with gr.Row():
+ add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0,
+ variant=variant_load_msg,
+ size='sm', interactive=not is_public)
+ system_tab = gr.TabItem("System") \
+ if kwargs['visible_system_tab'] else gr.Row(visible=False)
+ with system_tab:
+ with gr.Row():
+ with gr.Column(scale=1):
+ side_bar_text = gr.Textbox('on' if kwargs['visible_side_bar'] else 'off',
+ visible=False, interactive=False)
+ doc_count_text = gr.Textbox('on' if kwargs['visible_doc_track'] else 'off',
+ visible=False, interactive=False)
+ submit_buttons_text = gr.Textbox('on' if kwargs['visible_submit_buttons'] else 'off',
+ visible=False, interactive=False)
+ visible_models_text = gr.Textbox('on' if kwargs['visible_visible_models'] else 'off',
+ visible=False, interactive=False)
+
+ side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
+ doc_count_btn = gr.Button("Toggle SideBar Document Count/Show Newest", variant="secondary",
+ size="sm")
+ submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
+ visible_model_btn = gr.Button("Toggle Visible Models", variant="secondary", size="sm")
+ col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
+ text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400,
+ step=50, label='Chat Height')
+ dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
+ with gr.Column(scale=4):
+ pass
+ system_visible0 = not is_public and not admin_pass
+ admin_row = gr.Row()
+ with admin_row:
+ with gr.Column(scale=1):
+ admin_pass_textbox = gr.Textbox(label="Admin Password",
+ type='password',
+ visible=not system_visible0)
+ with gr.Column(scale=4):
+ pass
+ system_row = gr.Row(visible=system_visible0)
+ with system_row:
+ with gr.Column():
+ with gr.Row():
+ system_btn = gr.Button(value='Get System Info', size='sm')
+ system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True)
+ with gr.Row():
+ system_input = gr.Textbox(label='System Info Dict Password', interactive=True,
+ visible=not is_public)
+ system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public, size='sm')
+ system_text2 = gr.Textbox(label='System Info Dict', interactive=False,
+ visible=not is_public, show_copy_button=True)
+ with gr.Row():
+ system_btn3 = gr.Button(value='Get Hash', visible=not is_public, size='sm')
+ system_text3 = gr.Textbox(label='Hash', interactive=False,
+ visible=not is_public, show_copy_button=True)
+ system_btn4 = gr.Button(value='Get Model Names', visible=not is_public, size='sm')
+ system_text4 = gr.Textbox(label='Model Names', interactive=False,
+ visible=not is_public, show_copy_button=True)
+
+ with gr.Row():
+ zip_btn = gr.Button("Zip", size='sm')
+ zip_text = gr.Textbox(label="Zip file name", interactive=False)
+ file_output = gr.File(interactive=False, label="Zip file to Download")
+ with gr.Row():
+ s3up_btn = gr.Button("S3UP", size='sm')
+ s3up_text = gr.Textbox(label='S3UP result', interactive=False)
+
+ tos_tab = gr.TabItem("Terms of Service") \
+ if kwargs['visible_tos_tab'] else gr.Row(visible=False)
+ with tos_tab:
+ description = ""
+ description += """
etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation + is_same = True + # length of conversation has to be same + if len(x) != len(y): + return False + if len(x) != len(y): + return False + for stepx, stepy in zip(x, y): + if len(stepx) != len(stepy): + # something off with a conversation + return False + for stepxx, stepyy in zip(stepx, stepy): + if len(stepxx) != len(stepyy): + # something off with a conversation + return False + if len(stepxx) != 2: + # something off + return False + if len(stepyy) != 2: + # something off + return False + questionx = stepxx[0].replace('
', '').replace('
', '') if stepxx[0] is not None else None + answerx = stepxx[1].replace('', '').replace('
', '') if stepxx[1] is not None else None + + questiony = stepyy[0].replace('', '').replace('
', '') if stepyy[0] is not None else None + answery = stepyy[1].replace('', '').replace('
', '') if stepyy[1] is not None else None + + if questionx != questiony or answerx != answery: + return False + return is_same + + def save_chat(*args, chat_is_list=False, auth_filename=None, auth_freeze=None): + args_list = list(args) + db1s = args_list[0] + requests_state1 = args_list[1] + args_list = args_list[2:] + if not chat_is_list: + # list of chatbot histories, + # can't pass in list with list of chatbot histories and state due to gradio limits + chat_list = args_list[:-1] + else: + assert len(args_list) == 2 + chat_list = args_list[0] + # if old chat file with single chatbot, get into shape + if isinstance(chat_list, list) and len(chat_list) > 0 and isinstance(chat_list[0], list) and len( + chat_list[0]) == 2 and isinstance(chat_list[0][0], str) and isinstance(chat_list[0][1], str): + chat_list = [chat_list] + # remove None histories + chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None] + chat_list_none = [x for x in chat_list if x not in chat_list_not_none] + if len(chat_list_none) > 0 and len(chat_list_not_none) == 0: + raise ValueError("Invalid chat file") + # dict with keys of short chat names, values of list of list of chatbot histories + chat_state1 = args_list[-1] + short_chats = list(chat_state1.keys()) + if len(chat_list_not_none) > 0: + # make short_chat key from only first history, based upon question that is same anyways + chat_first = chat_list_not_none[0] + short_chat = get_short_chat(chat_first, short_chats) + if short_chat: + old_chat_lists = list(chat_state1.values()) + already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists]) + if not already_exists: + chat_state1[short_chat] = chat_list.copy() + + # reverse so newest at top + choices = list(chat_state1.keys()).copy() + choices.reverse() + + # save saved chats and chatbots to auth file + text_output1 = chat_list[0] + text_output21 = chat_list[1] + text_outputs1 = chat_list[2:] + save_auth(requests_state1, auth_filename, auth_freeze, chat_state1=chat_state1, + text_output1=text_output1, text_output21=text_output21, text_outputs1=text_outputs1) + + return chat_state1, gr.update(choices=choices, value=None) + + def switch_chat(chat_key, chat_state1, num_model_lock=0): + chosen_chat = chat_state1[chat_key] + # deal with possible different size of chat list vs. 
current list + ret_chat = [None] * (2 + num_model_lock) + for chati in range(0, 2 + num_model_lock): + ret_chat[chati % len(ret_chat)] = chosen_chat[chati % len(chosen_chat)] + return tuple(ret_chat) + + def clear_texts(*args): + return tuple([gr.Textbox.update(value='')] * len(args)) + + def clear_scores(): + return gr.Textbox.update(value=res_value), \ + gr.Textbox.update(value='Response Score: NA'), \ + gr.Textbox.update(value='Response Score: NA') + + switch_chat_fun = functools.partial(switch_chat, num_model_lock=len(text_outputs)) + radio_chats.input(switch_chat_fun, + inputs=[radio_chats, chat_state], + outputs=[text_output, text_output2] + text_outputs) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + def remove_chat(chat_key, chat_state1): + if isinstance(chat_key, str): + chat_state1.pop(chat_key, None) + return gr.update(choices=list(chat_state1.keys()), value=None), chat_state1 + + remove_chat_event = remove_chat_btn.click(remove_chat, + inputs=[radio_chats, chat_state], + outputs=[radio_chats, chat_state], + queue=False, api_name='remove_chat') + + def get_chats1(chat_state1): + base = 'chats' + base = makedirs(base, exist_ok=True, tmp_ok=True, use_base=True) + filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4())) + with open(filename, "wt") as f: + f.write(json.dumps(chat_state1, indent=2)) + return filename + + export_chat_event = export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False, + api_name='export_chats' if allow_api else None) + + def add_chats_from_file(db1s, requests_state1, file, chat_state1, radio_chats1, chat_exception_text1, + auth_filename=None, auth_freeze=None): + if not file: + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + if isinstance(file, str): + files = [file] + else: + files = file + if not files: + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + chat_exception_list = [] + for file1 in files: + try: + if hasattr(file1, 'name'): + file1 = file1.name + with open(file1, "rt") as f: + new_chats = json.loads(f.read()) + for chat1_k, chat1_v in new_chats.items(): + # ignore chat1_k, regenerate and de-dup to avoid loss + chat_state1, _ = save_chat(db1s, requests_state1, chat1_v, chat_state1, chat_is_list=True) + except BaseException as e: + t, v, tb = sys.exc_info() + ex = ''.join(traceback.format_exception(t, v, tb)) + ex_str = "File %s exception: %s" % (file1, str(e)) + print(ex_str, flush=True) + chat_exception_list.append(ex_str) + chat_exception_text1 = '\n'.join(chat_exception_list) + # save chat to auth file + save_auth(requests_state1, auth_filename, auth_freeze, chat_state1=chat_state1) + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + + # note for update_user_db_func output is ignored for db + chatup_change_eventa = chatsup_output.change(user_state_setup, + inputs=[my_db_state, requests_state, langchain_mode], + outputs=[my_db_state, requests_state, langchain_mode], + show_progress='minimal') + add_chats_from_file_func = functools.partial(add_chats_from_file, + auth_filename=kwargs['auth_filename'], + auth_freeze=kwargs['auth_freeze'], + ) + chatup_change_event = chatup_change_eventa.then(add_chats_from_file_func, + inputs=[my_db_state, requests_state] + + [chatsup_output, chat_state, radio_chats, + chat_exception_text], + outputs=[chatsup_output, chat_state, radio_chats, + chat_exception_text], + 
queue=False, + api_name='add_to_chats' if allow_api else None) + + clear_chat_event = clear_chat_btn.click(fn=clear_texts, + inputs=[text_output, text_output2] + text_outputs, + outputs=[text_output, text_output2] + text_outputs, + queue=False, api_name='clear' if allow_api else None) \ + .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + clear_eventa = save_chat_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, langchain_mode], + outputs=[my_db_state, requests_state, langchain_mode], + show_progress='minimal') + save_chat_func = functools.partial(save_chat, + auth_filename=kwargs['auth_filename'], + auth_freeze=kwargs['auth_freeze'], + ) + clear_event = clear_eventa.then(save_chat_func, + inputs=[my_db_state, requests_state] + + [text_output, text_output2] + text_outputs + + [chat_state], + outputs=[chat_state, radio_chats], + api_name='save_chat' if allow_api else None) + if kwargs['score_model']: + clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + # NOTE: clear of instruction/iinput for nochat has to come after score, + # because score for nochat consumes actual textbox, while chat consumes chat history filled by user() + no_chat_args = dict(fn=fun, + inputs=[model_state, my_db_state, selection_docs_state, requests_state] + inputs_list, + outputs=text_output_nochat, + queue=queue, + ) + submit_event_nochat = submit_nochat.click(**no_chat_args, api_name='submit_nochat' if allow_api else None) \ + .then(clear_torch_cache) \ + .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \ + .then(clear_instruct, None, instruction_nochat) \ + .then(clear_instruct, None, iinput_nochat) \ + .then(clear_torch_cache) + # copy of above with text box submission + submit_event_nochat2 = instruction_nochat.submit(**no_chat_args) \ + .then(clear_torch_cache) \ + .then(**score_args_nochat, queue=queue) \ + .then(clear_instruct, None, instruction_nochat) \ + .then(clear_instruct, None, iinput_nochat) \ + .then(clear_torch_cache) + + submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str, + inputs=[model_state, my_db_state, selection_docs_state, + requests_state, + inputs_dict_str], + outputs=text_output_nochat_api, + queue=True, # required for generator + api_name='submit_nochat_api' if allow_api else None) + + submit_event_nochat_api_plain = submit_nochat_api_plain.click(fun_with_dict_str_plain, + inputs=inputs_dict_str, + outputs=text_output_nochat_api, + queue=False, + api_name='submit_nochat_plain_api' if allow_api else None) + + def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, + load_8bit, load_4bit, low_bit_mode, + load_gptq, load_exllama, use_safetensors, revision, + use_gpu_id, gpu_id, max_seq_len1, rope_scaling1, + model_path_llama1, model_name_gptj1, model_name_gpt4all_llama1, + n_gpu_layers1, n_batch1, n_gqa1, llamacpp_dict_more1, + system_prompt1): + try: + llamacpp_dict = ast.literal_eval(llamacpp_dict_more1) + except: + print("Failed to use user input for llamacpp_dict_more1 dict", flush=True) + llamacpp_dict = {} + llamacpp_dict.update(dict(model_path_llama=model_path_llama1, + model_name_gptj=model_name_gptj1, + model_name_gpt4all_llama=model_name_gpt4all_llama1, + n_gpu_layers=n_gpu_layers1, + n_batch=n_batch1, + n_gqa=n_gqa1, + )) + + # ensure no API calls reach here + if is_public: + raise RuntimeError("Illegal 
access for %s" % model_name) + # ensure old model removed from GPU memory + if kwargs['debug']: + print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True) + + model0 = model_state0['model'] + if isinstance(model_state_old['model'], str) and \ + model0 is not None and \ + hasattr(model0, 'cpu'): + # best can do, move model loaded at first to CPU + model0.cpu() + + if model_state_old['model'] is not None and \ + not isinstance(model_state_old['model'], str): + if hasattr(model_state_old['model'], 'cpu'): + try: + model_state_old['model'].cpu() + except Exception as e: + # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data! + print("Unable to put model on CPU: %s" % str(e), flush=True) + del model_state_old['model'] + model_state_old['model'] = None + + if model_state_old['tokenizer'] is not None and not isinstance(model_state_old['tokenizer'], str): + del model_state_old['tokenizer'] + model_state_old['tokenizer'] = None + + clear_torch_cache() + if kwargs['debug']: + print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True) + if not model_name: + model_name = no_model_str + if model_name == no_model_str: + # no-op if no model, just free memory + # no detranscribe needed for model, never go into evaluate + lora_weights = no_lora_str + server_name = no_server_str + return kwargs['model_state_none'].copy(), \ + model_name, lora_weights, server_name, prompt_type_old, \ + gr.Slider.update(maximum=256), \ + gr.Slider.update(maximum=256) + + # don't deepcopy, can contain model itself + all_kwargs1 = all_kwargs.copy() + all_kwargs1['base_model'] = model_name.strip() + all_kwargs1['load_8bit'] = load_8bit + all_kwargs1['load_4bit'] = load_4bit + all_kwargs1['low_bit_mode'] = low_bit_mode + all_kwargs1['load_gptq'] = load_gptq + all_kwargs1['load_exllama'] = load_exllama + all_kwargs1['use_safetensors'] = use_safetensors + all_kwargs1['revision'] = None if not revision else revision # transcribe, don't pass '' + all_kwargs1['use_gpu_id'] = use_gpu_id + all_kwargs1['gpu_id'] = int(gpu_id) if gpu_id not in [None, 'None'] else None # detranscribe + all_kwargs1['llamacpp_dict'] = llamacpp_dict + all_kwargs1['max_seq_len'] = max_seq_len1 + try: + all_kwargs1['rope_scaling'] = str_to_dict(rope_scaling1) # transcribe + except: + print("Failed to use user input for rope_scaling dict", flush=True) + all_kwargs1['rope_scaling'] = {} + model_lower = model_name.strip().lower() + if model_lower in inv_prompt_type_to_model_lower: + prompt_type1 = inv_prompt_type_to_model_lower[model_lower] + else: + prompt_type1 = prompt_type_old + + # detranscribe + if lora_weights == no_lora_str: + lora_weights = '' + all_kwargs1['lora_weights'] = lora_weights.strip() + if server_name == no_server_str: + server_name = '' + all_kwargs1['inference_server'] = server_name.strip() + + model1, tokenizer1, device1 = get_model(reward_type=False, + **get_kwargs(get_model, exclude_names=['reward_type'], + **all_kwargs1)) + clear_torch_cache() + + tokenizer_base_model = model_name + prompt_dict1, error0 = get_prompt(prompt_type1, '', + chat=False, context='', reduced=False, making_context=False, + return_dict=True, system_prompt=system_prompt1) + model_state_new = dict(model=model1, tokenizer=tokenizer1, device=device1, + base_model=model_name, tokenizer_base_model=tokenizer_base_model, + lora_weights=lora_weights, inference_server=server_name, + prompt_type=prompt_type1, prompt_dict=prompt_dict1, + # FIXME: not typically required, unless want to expose adding h2ogpt 
endpoint in UI + visible_models=None, h2ogpt_key=None, + ) + + max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs) + + if kwargs['debug']: + print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True) + return model_state_new, model_name, lora_weights, server_name, prompt_type1, \ + gr.Slider.update(maximum=max_max_new_tokens1), \ + gr.Slider.update(maximum=max_max_new_tokens1) + + def get_prompt_str(prompt_type1, prompt_dict1, system_prompt1, which=0): + if prompt_type1 in ['', None]: + print("Got prompt_type %s: %s" % (which, prompt_type1), flush=True) + return str({}) + prompt_dict1, prompt_dict_error = get_prompt(prompt_type1, prompt_dict1, chat=False, context='', + reduced=False, making_context=False, return_dict=True, + system_prompt=system_prompt1) + if prompt_dict_error: + return str(prompt_dict_error) + else: + # return so user can manipulate if want and use as custom + return str(prompt_dict1) + + get_prompt_str_func1 = functools.partial(get_prompt_str, which=1) + get_prompt_str_func2 = functools.partial(get_prompt_str, which=2) + prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict, system_prompt], + outputs=prompt_dict, queue=False) + prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2, system_prompt], + outputs=prompt_dict2, + queue=False) + + def dropdown_prompt_type_list(x): + return gr.Dropdown.update(value=x) + + def chatbot_list(x, model_used_in): + return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]') + + load_model_args = dict(fn=load_model, + inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type, + model_load8bit_checkbox, model_load4bit_checkbox, model_low_bit_mode, + model_load_gptq, model_load_exllama_checkbox, + model_safetensors_checkbox, model_revision, + model_use_gpu_id_checkbox, model_gpu, + max_seq_len, rope_scaling, + model_path_llama, model_name_gptj, model_name_gpt4all_llama, + n_gpu_layers, n_batch, n_gqa, llamacpp_dict_more, + system_prompt], + outputs=[model_state, model_used, lora_used, server_used, + # if prompt_type changes, prompt_dict will change via change rule + prompt_type, max_new_tokens, min_new_tokens, + ]) + prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type) + chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output) + nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat) + load_model_event = load_model_button.click(**load_model_args, + api_name='load_model' if allow_api and not is_public else None) \ + .then(**prompt_update_args) \ + .then(**chatbot_update_args) \ + .then(**nochat_update_args) \ + .then(clear_torch_cache) + + load_model_args2 = dict(fn=load_model, + inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2, + model_load8bit_checkbox2, model_load4bit_checkbox2, model_low_bit_mode2, + model_load_gptq2, model_load_exllama_checkbox2, + model_safetensors_checkbox2, model_revision2, + model_use_gpu_id_checkbox2, model_gpu2, + max_seq_len2, rope_scaling2, + model_path_llama2, model_name_gptj2, model_name_gpt4all_llama2, + n_gpu_layers2, n_batch2, n_gqa2, llamacpp_dict_more2, + system_prompt], + outputs=[model_state2, model_used2, lora_used2, server_used2, + # if prompt_type2 changes, prompt_dict2 will change via change rule + prompt_type2, max_new_tokens2, min_new_tokens2 + ]) + prompt_update_args2 = dict(fn=dropdown_prompt_type_list, 
inputs=prompt_type2, outputs=prompt_type2) + chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2) + load_model_event2 = load_model_button2.click(**load_model_args2, + api_name='load_model2' if allow_api and not is_public else None) \ + .then(**prompt_update_args2) \ + .then(**chatbot_update_args2) \ + .then(clear_torch_cache) + + def dropdown_model_lora_server_list(model_list0, model_x, + lora_list0, lora_x, + server_list0, server_x, + model_used1, lora_used1, server_used1, + model_used2, lora_used2, server_used2, + ): + model_new_state = [model_list0[0] + [model_x]] + model_new_options = [*model_new_state[0]] + if no_model_str in model_new_options: + model_new_options.remove(no_model_str) + model_new_options = [no_model_str] + sorted(model_new_options) + x1 = model_x if model_used1 == no_model_str else model_used1 + x2 = model_x if model_used2 == no_model_str else model_used2 + ret1 = [gr.Dropdown.update(value=x1, choices=model_new_options), + gr.Dropdown.update(value=x2, choices=model_new_options), + '', model_new_state] + + lora_new_state = [lora_list0[0] + [lora_x]] + lora_new_options = [*lora_new_state[0]] + if no_lora_str in lora_new_options: + lora_new_options.remove(no_lora_str) + lora_new_options = [no_lora_str] + sorted(lora_new_options) + # don't switch drop-down to added lora if already have model loaded + x1 = lora_x if model_used1 == no_model_str else lora_used1 + x2 = lora_x if model_used2 == no_model_str else lora_used2 + ret2 = [gr.Dropdown.update(value=x1, choices=lora_new_options), + gr.Dropdown.update(value=x2, choices=lora_new_options), + '', lora_new_state] + + server_new_state = [server_list0[0] + [server_x]] + server_new_options = [*server_new_state[0]] + if no_server_str in server_new_options: + server_new_options.remove(no_server_str) + server_new_options = [no_server_str] + sorted(server_new_options) + # don't switch drop-down to added server if already have model loaded + x1 = server_x if model_used1 == no_model_str else server_used1 + x2 = server_x if model_used2 == no_model_str else server_used2 + ret3 = [gr.Dropdown.update(value=x1, choices=server_new_options), + gr.Dropdown.update(value=x2, choices=server_new_options), + '', server_new_state] + + return tuple(ret1 + ret2 + ret3) + + add_model_lora_server_event = \ + add_model_lora_server_button.click(fn=dropdown_model_lora_server_list, + inputs=[model_options_state, new_model] + + [lora_options_state, new_lora] + + [server_options_state, new_server] + + [model_used, lora_used, server_used] + + [model_used2, lora_used2, server_used2], + outputs=[model_choice, model_choice2, new_model, model_options_state] + + [lora_choice, lora_choice2, new_lora, lora_options_state] + + [server_choice, server_choice2, new_server, + server_options_state], + queue=False) + + go_event = go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, + queue=False) \ + .then(lambda: gr.update(visible=True), None, normal_block, queue=False) \ + .then(**load_model_args, queue=False).then(**prompt_update_args, queue=False) + + def compare_textbox_fun(x): + return gr.Textbox.update(visible=x) + + def compare_column_fun(x): + return gr.Column.update(visible=x) + + def compare_prompt_fun(x): + return gr.Dropdown.update(visible=x) + + def slider_fun(x): + return gr.Slider.update(visible=x) + + compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, + api_name="compare_checkbox" if allow_api else None) \ + .then(compare_column_fun, 
compare_checkbox, col_model2) \ + .then(compare_prompt_fun, compare_checkbox, prompt_type2) \ + .then(compare_textbox_fun, compare_checkbox, score_text2) \ + .then(slider_fun, compare_checkbox, max_new_tokens2) \ + .then(slider_fun, compare_checkbox, min_new_tokens2) + # FIXME: add score_res2 in condition, but do better + + # callback for logging flagged input/output + callback.setup(inputs_list + [text_output, text_output2] + text_outputs, "flagged_data_points") + flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2] + text_outputs, + None, + preprocess=False, + api_name='flag' if allow_api else None, queue=False) + flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None, + preprocess=False, + api_name='flag_nochat' if allow_api else None, queue=False) + + def get_system_info(): + if is_public: + time.sleep(10) # delay to avoid spam since queue=False + return gr.Textbox.update(value=system_info_print()) + + system_event = system_btn.click(get_system_info, outputs=system_text, + api_name='system_info' if allow_api else None, queue=False) + + def get_system_info_dict(system_input1, **kwargs1): + if system_input1 != os.getenv("ADMIN_PASS", ""): + return json.dumps({}) + exclude_list = ['admin_pass', 'examples'] + sys_dict = {k: v for k, v in kwargs1.items() if + isinstance(v, (str, int, bool, float)) and k not in exclude_list} + try: + sys_dict.update(system_info()) + except Exception as e: + # protection + print("Exception: %s" % str(e), flush=True) + return json.dumps(sys_dict) + + system_kwargs = all_kwargs.copy() + system_kwargs.update(dict(command=str(' '.join(sys.argv)))) + get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs) + + system_dict_event = system_btn2.click(get_system_info_dict_func, + inputs=system_input, + outputs=system_text2, + api_name='system_info_dict' if allow_api else None, + queue=False, # queue to avoid spam + ) + + def get_hash(): + return kwargs['git_hash'] + + system_event = system_btn3.click(get_hash, + outputs=system_text3, + api_name='system_hash' if allow_api else None, + queue=False, + ) + + def get_model_names(): + key_list = ['base_model', 'prompt_type', 'prompt_dict'] + list(kwargs['other_model_state_defaults'].keys()) + # don't want to expose backend inference server IP etc. 
+ # key_list += ['inference_server'] + return [{k: x[k] for k in key_list if k in x} for x in model_states] + + models_list_event = system_btn4.click(get_model_names, + outputs=system_text4, + api_name='model_names' if allow_api else None, + queue=False, + ) + + def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1, + system_prompt1, chat_conversation1, + memory_restriction_level1=0, + keep_sources_in_context1=False, + ): + if model_state1 and not isinstance(model_state1['tokenizer'], str): + tokenizer = model_state1['tokenizer'] + elif model_state0 and not isinstance(model_state0['tokenizer'], str): + tokenizer = model_state0['tokenizer'] + else: + tokenizer = None + if tokenizer is not None: + langchain_mode1 = 'LLM' + add_chat_history_to_context1 = True + # fake user message to mimic bot() + chat1 = copy.deepcopy(chat1) + chat1 = chat1 + [['user_message1', None]] + model_max_length1 = tokenizer.model_max_length + context1 = history_to_context(chat1, + langchain_mode=langchain_mode1, + add_chat_history_to_context=add_chat_history_to_context1, + prompt_type=prompt_type1, + prompt_dict=prompt_dict1, + chat=True, + model_max_length=model_max_length1, + memory_restriction_level=memory_restriction_level1, + keep_sources_in_context=keep_sources_in_context1, + system_prompt=system_prompt1, + chat_conversation=chat_conversation1) + tokens = tokenizer(context1, return_tensors="pt")['input_ids'] + if len(tokens.shape) == 1: + return str(tokens.shape[0]) + elif len(tokens.shape) == 2: + return str(tokens.shape[1]) + else: + return "N/A" + else: + return "N/A" + + count_chat_tokens_func = functools.partial(count_chat_tokens, + memory_restriction_level1=memory_restriction_level, + keep_sources_in_context1=kwargs['keep_sources_in_context']) + count_tokens_event = count_chat_tokens_btn.click(fn=count_chat_tokens_func, + inputs=[model_state, text_output, prompt_type, prompt_dict, + system_prompt, chat_conversation], + outputs=chat_token_count, + api_name='count_tokens' if allow_api else None) + + # don't pass text_output, don't want to clear output, just stop it + # cancel only stops outer generation, not inner generation or non-generation + stop_btn.click(lambda: None, None, None, + cancels=submits1 + submits2 + submits3 + submits4 + + [submit_event_nochat, submit_event_nochat2] + + [eventdb1, eventdb2, eventdb3] + + [eventdb7a, eventdb7, eventdb8a, eventdb8, eventdb9a, eventdb9, eventdb12a, eventdb12] + + db_events + + [eventdbloadla, eventdbloadlb] + + [clear_event] + + [submit_event_nochat_api, submit_event_nochat] + + [load_model_event, load_model_event2] + + [count_tokens_event] + , + queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False) + + if kwargs['auth'] is not None: + auth = authf + load_func = user_state_setup + load_inputs = [my_db_state, requests_state, login_btn, login_btn] + load_outputs = [my_db_state, requests_state, login_btn] + else: + auth = None + load_func, load_inputs, load_outputs = None, None, None + + app_js = wrap_js_to_lambda( + len(load_inputs) if load_inputs else 0, + get_dark_js() if kwargs['dark'] else None, + get_heap_js(heap_app_id) if is_heap_analytics_enabled else None) + + load_event = demo.load(fn=load_func, inputs=load_inputs, outputs=load_outputs, _js=app_js) + + if load_func: + load_event2 = load_event.then(load_login_func, + inputs=login_inputs, + outputs=login_outputs) + if not kwargs['large_file_count_mode']: + load_event3 = load_event2.then(**get_sources_kwargs) + load_event4 = 
load_event3.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + load_event5 = load_event4.then(**show_sources_kwargs) + load_event6 = load_event5.then(**get_viewable_sources_args) + load_event7 = load_event6.then(**viewable_kwargs) + + demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open']) + favicon_file = "h2o-logo.svg" + favicon_path = favicon_file + if not os.path.isfile(favicon_file): + print("favicon_path1=%s not found" % favicon_file, flush=True) + alt_path = os.path.dirname(os.path.abspath(__file__)) + favicon_path = os.path.join(alt_path, favicon_file) + if not os.path.isfile(favicon_path): + print("favicon_path2: %s not found in %s" % (favicon_file, alt_path), flush=True) + alt_path = os.path.dirname(alt_path) + favicon_path = os.path.join(alt_path, favicon_file) + if not os.path.isfile(favicon_path): + print("favicon_path3: %s not found in %s" % (favicon_file, alt_path), flush=True) + favicon_path = None + + if kwargs['prepare_offline_level'] > 0: + from src.prepare_offline import go_prepare_offline + go_prepare_offline(**locals()) + return + + scheduler = BackgroundScheduler() + scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20) + if is_public and \ + kwargs['base_model'] not in non_hf_types: + # FIXME: disable for gptj, langchain or gpt4all modify print itself + # FIXME: and any multi-threaded/async print will enter model output! + scheduler.add_job(func=ping, trigger="interval", seconds=60) + if is_public or os.getenv('PING_GPU'): + scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10) + scheduler.start() + + # import control + if kwargs['langchain_mode'] == 'Disabled' and \ + os.environ.get("TEST_LANGCHAIN_IMPORT") and \ + kwargs['base_model'] not in non_hf_types: + assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" + assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" + + # set port in case GRADIO_SERVER_PORT was already set in prior main() call, + # gradio does not listen if change after import + # Keep None if not set so can find an open port above used ports + server_port = os.getenv('GRADIO_SERVER_PORT') + if server_port is not None: + server_port = int(server_port) + + demo.launch(share=kwargs['share'], + server_name=kwargs['server_name'], + show_error=True, + server_port=server_port, + favicon_path=favicon_path, + prevent_thread_lock=True, + auth=auth, + auth_message=auth_message, + root_path=kwargs['root_path']) + if kwargs['verbose'] or not (kwargs['base_model'] in ['gptj', 'gpt4all_llama']): + print("Started Gradio Server and/or GUI: server_name: %s port: %s" % (kwargs['server_name'], server_port), + flush=True) + if kwargs['block_gradio_exit']: + demo.block_thread() + + +def show_doc(db1s, selection_docs_state1, requests_state1, + langchain_mode1, + single_document_choice1, + view_raw_text_checkbox1, + text_context_list1, + dbs1=None, + load_db_if_exists1=None, + db_type1=None, + use_openai_embedding1=None, + hf_embedding_model1=None, + migrate_embedding_model_or_db1=None, + auto_migrate_db1=None, + verbose1=False, + get_userid_auth1=None, + max_raw_chunks=1000000, + api=False, + n_jobs=-1): + file = single_document_choice1 + document_choice1 = [single_document_choice1] + content = None + db_documents = [] + db_metadatas = [] + if db_type1 in ['chroma', 'chroma_old']: + assert langchain_mode1 is not None + langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] + 
langchain_mode_types = selection_docs_state1['langchain_mode_types'] + from src.gpt_langchain import set_userid, get_any_db, get_docs_and_meta + set_userid(db1s, requests_state1, get_userid_auth1) + top_k_docs = -1 + db = get_any_db(db1s, langchain_mode1, langchain_mode_paths, langchain_mode_types, + dbs=dbs1, + load_db_if_exists=load_db_if_exists1, + db_type=db_type1, + use_openai_embedding=use_openai_embedding1, + hf_embedding_model=hf_embedding_model1, + migrate_embedding_model=migrate_embedding_model_or_db1, + auto_migrate_db=auto_migrate_db1, + for_sources_list=True, + verbose=verbose1, + n_jobs=n_jobs, + ) + query_action = False # long chunks like would be used for summarize + # the below is as or filter, so will show doc or by chunk, unrestricted + from langchain.vectorstores import Chroma + if isinstance(db, Chroma): + # chroma >= 0.4 + if view_raw_text_checkbox1: + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$gte": -1}} + for x in document_choice1][0] + else: + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice1][0] + filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']), + dict(chunk_id=one_filter['chunk_id'])]}) + else: + # migration for chroma < 0.4 + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice1][0] + if view_raw_text_checkbox1: + # like or, full raw all chunk types + filter_kwargs = dict(filter=one_filter) + else: + filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']), + dict(chunk_id=one_filter['chunk_id'])]}) + db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, + text_context_list=text_context_list1) + # order documents + from langchain.docstore.document import Document + docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0) + for result in zip(db_documents, db_metadatas)] + doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas] + doc_page_ids = [x.get('page', 0) for x in db_metadatas] + doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas] + docs_with_score = [x for hx, px, cx, x in + sorted(zip(doc_hashes, doc_page_ids, doc_chunk_ids, docs_with_score), + key=lambda x: (x[0], x[1], x[2])) + # if cx == -1 + ] + db_metadatas = [x[0].metadata for x in docs_with_score][:max_raw_chunks] + db_documents = [x[0].page_content for x in docs_with_score][:max_raw_chunks] + # done reordering + if view_raw_text_checkbox1: + content = [dict_to_html(x) + '\n' + text_to_html(y) for x, y in zip(db_metadatas, db_documents)] + else: + content = [text_to_html(y) for x, y in zip(db_metadatas, db_documents)] + content = '\n'.join(content) + content = f""" + + +", ' ').replace("\r", ' ')
+ content = x.page_content
+ return f"""{title}
{content}
+%s
+
+""" % x
+
+
+def lg_to_gr(
+ **kwargs,
+):
+ # translate loader availability and CLI settings into Gradio choice lists:
+ import torch
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+ n_gpus, _ = cuda_vis_check(n_gpus)
+
+ image_loaders_options = ['Caption']
+ if n_gpus != 0:
+ image_loaders_options.extend(['CaptionBlip2', 'Pix2Struct'])
+ if have_tesseract:
+ image_loaders_options.append('OCR')
+ if have_doctr:
+ image_loaders_options.append('DocTR')
+
+ image_loaders_options0 = []
+ if have_tesseract and kwargs['enable_ocr']:
+ image_loaders_options0.append('OCR')
+ if have_doctr and kwargs['enable_doctr']:
+ image_loaders_options0.append('DocTR')
+ if kwargs['enable_captions']:
+ if kwargs['max_quality'] and n_gpus > 0:
+ # BLIP2 only on GPU
+ image_loaders_options0.append('CaptionBlip2')
+ else:
+ image_loaders_options0.append('Caption')
+
+ pdf_loaders_options = ['PyMuPDF', 'Unstructured', 'PyPDF', 'TryHTML']
+ if have_tesseract:
+ pdf_loaders_options.append('OCR')
+ if have_doctr:
+ pdf_loaders_options.append('DocTR')
+
+ pdf_loaders_options0 = []
+ if kwargs['use_pymupdf'] in [True, 'auto', 'on']:
+ pdf_loaders_options0.append('PyMuPDF')
+ if kwargs['enable_pdf_ocr'] in [True, 'on']:
+ pdf_loaders_options0.append('OCR')
+ if have_doctr and kwargs['enable_pdf_doctr'] in [True, 'on']:
+ pdf_loaders_options0.append('DocTR')
+
+ url_loaders_options = []
+ if only_unstructured_urls:
+ url_loaders_options.append('Unstructured')
+ elif have_selenium and only_selenium:
+ url_loaders_options.append('Selenium')
+ elif have_playwright and only_playwright:
+ url_loaders_options.append('PlayWright')
+ else:
+ url_loaders_options.append('Unstructured')
+ if have_selenium:
+ url_loaders_options.append('Selenium')
+ if have_playwright:
+ url_loaders_options.append('PlayWright')
+ url_loaders_options0 = [url_loaders_options[0]]
+
+ assert set(image_loaders_options0).issubset(image_loaders_options)
+ assert set(pdf_loaders_options0).issubset(pdf_loaders_options)
+ assert set(url_loaders_options0).issubset(url_loaders_options)
+
+ return image_loaders_options0, image_loaders_options, \
+ pdf_loaders_options0, pdf_loaders_options, \
+ url_loaders_options0, url_loaders_options
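+
+# Usage sketch (illustrative): the six returned lists feed the loader
+# CheckboxGroup widgets above, e.g.
+# image_loaders_options0, image_loaders_options, \
+#     pdf_loaders_options0, pdf_loaders_options, \
+#     url_loaders_options0, url_loaders_options = lg_to_gr(**kwargs)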
+
+
+def fix_json(s):
+
+ # Attempt to parse the string as-is.
+ try:
+ return json.loads(s)
+ except json.JSONDecodeError:
+ pass
+
+ # Initialize variables.
+ new_s = ""
+ stack = []
+ is_inside_string = False
+ escaped = False
+
+ # Process each character in the string one at a time.
+ for char in s:
+ if is_inside_string:
+ if char == '"' and not escaped:
+ is_inside_string = False
+ elif char == '\n' and not escaped:
+ char = '\\n' # Replace the newline character with the escape sequence.
+ elif char == '\\':
+ escaped = not escaped
+ else:
+ escaped = False
+ else:
+ if char == '"':
+ is_inside_string = True
+ escaped = False
+ elif char == '{':
+ stack.append('}')
+ elif char == '[':
+ stack.append(']')
+ elif char == '}' or char == ']':
+ if stack and stack[-1] == char:
+ stack.pop()
+ else:
+ # Mismatched closing character; the input is malformed.
+ return None
+
+ # Append the processed character to the new string.
+ new_s += char
+
+ # If we're still inside a string at the end of processing, we need to close the string.
+ if is_inside_string:
+ new_s += '"'
+
+ # Close any remaining open structures in the reverse order that they were opened.
+ for closing_char in reversed(stack):
+ new_s += closing_char
+
+ # Attempt to parse the modified string as JSON.
+ try:
+ return json.loads(new_s)
+ except json.JSONDecodeError:
+ # If we still can't parse the string as JSON, return None to indicate failure.
+ return None
+
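+# Illustrative examples (not part of the original change) of fix_json() on truncated JSON:
+#   fix_json('{"a": [1, 2')   -> {'a': [1, 2]}   (closes the open list and object)
+#   fix_json('{"msg": "hel')  -> {'msg': 'hel'}  (closes the open string first)
+#   fix_json('{"a": 1]')      -> None            (mismatched closer, unrecoverable)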
+
+def wrap_in_try_except(code):
+ # Add import traceback
+ code = "import traceback\n" + code
+
+ # Parse the input code into an AST
+ parsed_code = ast.parse(code)
+
+ # Wrap the entire code's AST in a single try-except block
+ try_except = ast.Try(
+ body=parsed_code.body,
+ handlers=[
+ ast.ExceptHandler(
+ type=ast.Name(id="Exception", ctx=ast.Load()),
+ name=None,
+ body=[
+ ast.Expr(
+ value=ast.Call(
+ func=ast.Attribute(value=ast.Name(id="traceback", ctx=ast.Load()), attr="print_exc", ctx=ast.Load()),
+ args=[],
+ keywords=[]
+ )
+ ),
+ ]
+ )
+ ],
+ orelse=[],
+ finalbody=[]
+ )
+
+ # Assign the try-except block as the new body
+ parsed_code.body = [try_except]
+
+ # Convert the modified AST back to source code
+ return ast.unparse(parsed_code)
+
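+# Illustrative example (not part of the original change); ast.unparse requires Python 3.9+.
+# wrap_in_try_except("x = 1 / 0\nprint(x)") returns roughly:
+#   try:
+#       import traceback
+#       x = 1 / 0
+#       print(x)
+#   except Exception:
+#       traceback.print_exc()
+# (the prepended import is parsed along with the user code, so it lands inside the try block)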
+
+def enqueue_output(file, queue):
+ for line in iter(file.readline, ''):
+ queue.put(line)
+ file.close()
+
+
+def read_popen_pipes(p):
+
+ with ThreadPoolExecutor(2) as pool:
+ q_stdout, q_stderr = Queue(), Queue()
+
+ pool.submit(enqueue_output, p.stdout, q_stdout)
+ pool.submit(enqueue_output, p.stderr, q_stderr)
+
+ while True:
+
+ if p.poll() is not None and q_stdout.empty() and q_stderr.empty():
+ break
+
+ out_line = err_line = ''
+
+ try:
+ out_line = q_stdout.get_nowait()
+ except Empty:
+ pass
+ try:
+ err_line = q_stderr.get_nowait()
+ except Empty:
+ pass
+
+ yield out_line, err_line
+
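+# Illustrative usage (not part of the original change); assumes the process was opened with
+# text-mode pipes (text=True) so readline() returns '' at EOF, matching enqueue_output's sentinel.
+# 'some_script.py' is a hypothetical target:
+#   p = subprocess.Popen([sys.executable, 'some_script.py'],
+#                        stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+#   for out_line, err_line in read_popen_pipes(p):
+#       if out_line:
+#           print(out_line, end='')
+#       if err_line:
+#           print(err_line, end='')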
+
+def start_process(cmd):
+    # run cmd under the current interpreter in interactive, quiet, unbuffered mode
+    start_cmd = [sys.executable, "-i", "-q", "-u"]
+    cmd = start_cmd + [cmd]
+
+    process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
+    # stdout is a byte stream, so decode before echoing to sys.stdout
+    for c in iter(lambda: process.stdout.read(1), b''):
+        sys.stdout.write(c.decode('utf-8', errors='replace'))
+
+
+def str_to_list(x, allow_none=False):
+ if isinstance(x, str):
+ if len(x.strip()) > 0:
+ if x.strip().startswith('['):
+ x = ast.literal_eval(x.strip())
+ else:
+ raise ValueError("Invalid str_to_list for %s" % x)
+ else:
+ x = []
+ elif x is None and not allow_none:
+ x = []
+ if allow_none:
+ assert isinstance(x, (type(None), list))
+ else:
+ assert isinstance(x, list)
+ return x
+
+
+def str_to_dict(x):
+ if isinstance(x, str):
+ if len(x.strip()) > 0:
+ if x.strip().startswith('{'):
+ x = ast.literal_eval(x.strip())
+ else:
+ raise ValueError("Invalid str_to_dict for %s" % x)
+ else:
+ x = {}
+ elif x is None:
+ x = {}
+ assert isinstance(x, dict)
+ return x
+
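+# Illustrative examples (not part of the original change); both helpers use ast.literal_eval,
+# so Python literal syntax is accepted:
+#   str_to_list("['a', 'b']")  -> ['a', 'b']
+#   str_to_list('')            -> []
+#   str_to_dict("{'k': 1}")    -> {'k': 1}
+#   str_to_dict(None)          -> {}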
+
+def get_token_count(x, tokenizer, token_count_fun=None):
+ # NOTE: Somewhat duplicates H2OTextGenerationPipeline.get_token_count()
+    # handle ambiguity in whether the tokenizer returns a dict/BatchEncoding or a plain list of ids
+    if tokenizer:
+        if hasattr(tokenizer, 'encode'):
+            template_tokens = tokenizer.encode(x)
+        else:
+            template_tokens = tokenizer(x)
+        if hasattr(template_tokens, 'keys') and 'input_ids' in template_tokens:
+            n_tokens = len(template_tokens['input_ids'])
+        else:
+            n_tokens = len(template_tokens)
+ elif token_count_fun is not None:
+ assert callable(token_count_fun)
+ n_tokens = token_count_fun(x)
+ else:
+ tokenizer = FakeTokenizer()
+ n_tokens = tokenizer.num_tokens_from_string(x)
+ return n_tokens
+
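+# Illustrative behavior (not part of the original change): with a Hugging Face tokenizer,
+# get_token_count("hello world", tokenizer) is len(tokenizer.encode("hello world"));
+# with tokenizer=None and no token_count_fun, it falls back to FakeTokenizer's approximate count.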
+
+def reverse_ucurve_list(lst):
+ if not lst:
+ return []
+ if len(lst) == 1:
+ return lst
+ if len(lst) == 2:
+ return [lst[1], lst[0]]
+
+ front_list = []
+ end_list = []
+
+ for i, item in enumerate(lst):
+ if i % 2 == 0:
+ end_list.append(item)
+ else:
+ front_list.append(item)
+
+ return front_list + end_list[::-1]
+
+
+def undo_reverse_ucurve_list(lst):
+ if not lst:
+ return []
+ if len(lst) == 1:
+ return lst
+ if len(lst) == 2:
+ return [lst[1], lst[0]]
+
+ # Split the list into two halves: the first half and the second half (reversed)
+ mid = len(lst) // 2
+ first_half = lst[:mid]
+ second_half = lst[mid:][::-1]
+
+ # Merge the two halves by taking elements alternatively from the second half and then the first half
+ result = []
+ for i in range(mid):
+ result.append(second_half[i])
+ result.append(first_half[i])
+
+ # If the length of the list is odd, append the last element of the second half
+ if len(lst) % 2 != 0:
+ result.append(second_half[-1])
+
+ return result
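+
+
+# Illustrative round trip (not part of the original change):
+#   reverse_ucurve_list([1, 2, 3, 4, 5])       -> [2, 4, 5, 3, 1]
+#   undo_reverse_ucurve_list([2, 4, 5, 3, 1])  -> [1, 2, 3, 4, 5]
+# i.e. the leading (typically most relevant) items of the input end up at the two ends of the
+# output, 'lost in the middle' style, and undo_reverse_ucurve_list() restores the original order.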
diff --git a/src/utils_langchain.py b/src/utils_langchain.py
new file mode 100644
index 0000000000000000000000000000000000000000..7483cca69443691de773196ba6c5134438e113aa
--- /dev/null
+++ b/src/utils_langchain.py
@@ -0,0 +1,152 @@
+import copy
+import os
+import types
+import uuid
+from typing import Any, Dict, List, Union, Optional
+import time
+import queue
+import pathlib
+from datetime import datetime
+
+from src.utils import hash_file, get_sha
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import LLMResult
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.docstore.document import Document
+
+
+class StreamingGradioCallbackHandler(BaseCallbackHandler):
+ """
+ Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
+ """
+ def __init__(self, timeout: Optional[float] = None, block=True):
+ super().__init__()
+ self.text_queue = queue.SimpleQueue()
+ self.stop_signal = None
+ self.do_stop = False
+ self.timeout = timeout
+ self.block = block
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ """Run when LLM starts running. Clean the queue."""
+ while not self.text_queue.empty():
+ try:
+ self.text_queue.get(block=False)
+ except queue.Empty:
+ continue
+
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ """Run on new LLM token. Only available when streaming is enabled."""
+ self.text_queue.put(token)
+
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ """Run when LLM ends running."""
+ self.text_queue.put(self.stop_signal)
+
+ def on_llm_error(
+ self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+ ) -> None:
+ """Run when LLM errors."""
+ self.text_queue.put(self.stop_signal)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ while True:
+ try:
+                value = self.stop_signal  # default so value is defined if do_stop hits before the queue read (PyCharm flags it unused, but it is needed)
+ if self.do_stop:
+ print("hit stop", flush=True)
+ # could raise or break, maybe best to raise and make parent see if any exception in thread
+ raise StopIteration()
+ # break
+ value = self.text_queue.get(block=self.block, timeout=self.timeout)
+ break
+ except queue.Empty:
+ time.sleep(0.01)
+ if value == self.stop_signal:
+ raise StopIteration()
+ else:
+ return value
+
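+# Illustrative usage (not part of the original change): run the LangChain call in a worker thread
+# and iterate the handler in the main thread to stream tokens into the UI.
+# `llm_chain` and `user_query` are hypothetical names; Chain.run accepts callbacks=[...]:
+#   handler = StreamingGradioCallbackHandler(timeout=60)
+#   worker = threading.Thread(target=llm_chain.run, args=(user_query,),
+#                             kwargs=dict(callbacks=[handler]))
+#   worker.start()
+#   partial_text = ''
+#   for token in handler:
+#       partial_text += token
+#   worker.join()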
+
+def _chunk_sources(sources, chunk=True, chunk_size=512, language=None, db_type=None):
+ assert db_type is not None
+
+ if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
+ # if just one document
+ sources = [sources]
+ if not chunk:
+ [x.metadata.update(dict(chunk_id=0)) for chunk_id, x in enumerate(sources)]
+ if db_type in ['chroma', 'chroma_old']:
+ # make copy so can have separate summarize case
+ source_chunks = [Document(page_content=x.page_content,
+ metadata=copy.deepcopy(x.metadata) or {})
+ for x in sources]
+ else:
+ source_chunks = sources # just same thing
+ else:
+ if language and False:
+ # Bug in langchain, keep separator=True not working
+ # https://github.com/hwchase17/langchain/issues/2836
+ # so avoid this for now
+ keep_separator = True
+ separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
+ else:
+ separators = ["\n\n", "\n", " ", ""]
+ keep_separator = False
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
+ separators=separators)
+ source_chunks = splitter.split_documents(sources)
+
+ # currently in order, but when pull from db won't be, so mark order and document by hash
+ [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
+
+ if db_type in ['chroma', 'chroma_old']:
+ # also keep original source for summarization and other tasks
+
+ # assign chunk_id=-1 for original content
+ # this assumes, as is currently true, that splitter makes new documents and list and metadata is deepcopy
+ [x.metadata.update(dict(chunk_id=-1)) for chunk_id, x in enumerate(sources)]
+
+ # in some cases sources is generator, so convert to list
+ return list(sources) + source_chunks
+ else:
+ return source_chunks
+
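+# Illustrative example (not part of the original change): chunking a single Document for chroma;
+# the result contains the chunk_id >= 0 chunks plus a chunk_id == -1 copy of the original:
+#   doc = Document(page_content='some long text ...', metadata=dict(source='a.txt'))
+#   chunks = _chunk_sources([doc], chunk=True, chunk_size=512, db_type='chroma')
+#   [c.metadata['chunk_id'] for c in chunks]  # e.g. [-1, 0, 1, ...]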
+
+def add_parser(docs1, parser):
+ [x.metadata.update(dict(parser=x.metadata.get('parser', parser))) for x in docs1]
+
+
+def _add_meta(docs1, file, headsize=50, filei=0, parser='NotSet'):
+ if os.path.isfile(file):
+ file_extension = pathlib.Path(file).suffix
+ hashid = hash_file(file)
+ else:
+ file_extension = str(file) # not file, just show full thing
+ hashid = get_sha(file)
+ doc_hash = str(uuid.uuid4())[:10]
+ if not isinstance(docs1, (list, tuple, types.GeneratorType)):
+ docs1 = [docs1]
+ [x.metadata.update(dict(input_type=file_extension,
+ parser=x.metadata.get('parser', parser),
+ date=str(datetime.now()),
+ time=time.time(),
+ order_id=order_id,
+ hashid=hashid,
+ doc_hash=doc_hash,
+ file_id=filei,
+ head=x.page_content[:headsize].strip())) for order_id, x in enumerate(docs1)]
+
+
+def fix_json_meta(docs1):
+ if not isinstance(docs1, (list, tuple, types.GeneratorType)):
+ docs1 = [docs1]
+ # fix meta, chroma doesn't like None, only str, int, float for values
+ [x.metadata.update(dict(sender_name=x.metadata.get('sender_name') or '')) for x in docs1]
+ [x.metadata.update(dict(timestamp_ms=x.metadata.get('timestamp_ms') or '')) for x in docs1]