import copy
import os
import types
import uuid
from typing import Any, Dict, List, Union, Optional
import time
import queue
import pathlib
from datetime import datetime

from src.utils import hash_file, get_sha
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document


class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """
    Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend

    The LLM callbacks push tokens onto an internal queue; iterating over the
    handler pops them until the stop signal is received or ``do_stop`` is set.
    """

    def __init__(self, timeout: Optional[float] = None, block=True):
        """
        :param timeout: passed as ``timeout`` to each queue ``get`` in ``__next__``
        :param block: passed as ``block`` to each queue ``get`` in ``__next__``
        """
        super().__init__()
        self.text_queue = queue.SimpleQueue()
        # sentinel placed on the queue by on_llm_end/on_llm_error to end iteration
        self.stop_signal = None
        # when set to True, the next __next__ attempt raises StopIteration
        self.do_stop = False
        self.timeout = timeout
        self.block = block

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running. Clean the queue."""
        while not self.text_queue.empty():
            try:
                self.text_queue.get(block=False)
            except queue.Empty:
                # emptied between the empty() check and the get(); re-check
                continue

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.text_queue.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.text_queue.put(self.stop_signal)

    def on_llm_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.text_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        """
        Return the next queued token.

        Raises StopIteration when ``do_stop`` is set or when the stop signal
        is read from the queue.  An empty queue is retried after a short sleep.
        """
        while True:
            try:
                if self.do_stop:
                    print("hit stop", flush=True)
                    # could raise or break, maybe best to raise and make parent see if any exception in thread
                    raise StopIteration()
                value = self.text_queue.get(block=self.block, timeout=self.timeout)
                break
            except queue.Empty:
                time.sleep(0.01)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value


def _chunk_sources(sources, chunk=True, chunk_size=512, language=None, db_type=None):
    """
    Tag documents with ``chunk_id`` metadata and optionally split them into chunks.

    :param sources: a single document, or a list/tuple/generator of documents
    :param chunk: whether to split documents with RecursiveCharacterTextSplitter
    :param chunk_size: chunk size passed to the splitter
    :param language: currently unused for splitting (language-specific separators are disabled)
    :param db_type: must not be None; for 'chroma'/'chroma_old' the original
        documents (chunk_id=-1) are returned in front of the chunks
    :return: the chunks, preceded by the original documents for chroma db types
    """
    assert db_type is not None

    if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
        # if just one document
        sources = [sources]
    elif isinstance(sources, types.GeneratorType):
        # sources is iterated more than once below, and a generator would be
        # exhausted after the first pass, silently dropping the originals
        sources = list(sources)

    if not chunk:
        for doc in sources:
            doc.metadata.update(dict(chunk_id=0))
        if db_type in ['chroma', 'chroma_old']:
            # make copy so can have separate summarize case
            source_chunks = [Document(page_content=doc.page_content,
                                      metadata=copy.deepcopy(doc.metadata) or {})
                             for doc in sources]
        else:
            source_chunks = sources  # just same thing
    else:
        if language and False:
            # Bug in langchain, keep separator=True not working
            # https://github.com/hwchase17/langchain/issues/2836
            # so avoid this for now
            keep_separator = True
            separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
        else:
            separators = ["\n\n", "\n", " ", ""]
            keep_separator = False
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0,
                                                  keep_separator=keep_separator,
                                                  separators=separators)
        source_chunks = splitter.split_documents(sources)

        # currently in order, but when pull from db won't be, so mark order and document by hash
        for chunk_id, doc in enumerate(source_chunks):
            doc.metadata.update(dict(chunk_id=chunk_id))

    if db_type in ['chroma', 'chroma_old']:
        # also keep original source for summarization and other tasks

        # assign chunk_id=-1 for original content
        # this assumes, as is currently true, that splitter makes new documents and list and metadata is deepcopy
        for doc in sources:
            doc.metadata.update(dict(chunk_id=-1))

        # sources may be a tuple, so convert to list before concatenating
        return list(sources) + source_chunks
    else:
        return source_chunks


def add_parser(docs1, parser):
    """Set ``parser`` in each document's metadata unless the document already has one."""
    for doc in docs1:
        doc.metadata.update(dict(parser=doc.metadata.get('parser', parser)))


def _add_meta(docs1, file, headsize=50, filei=0, parser='NotSet'):
    """
    Update the metadata of each document in place with provenance information.

    :param docs1: a single document, or a list/tuple/generator of documents
    :param file: if an existing file path, its suffix becomes ``input_type`` and
        ``hash_file(file)`` the ``hashid``; otherwise ``str(file)`` and ``get_sha(file)`` are used
    :param headsize: number of leading characters of page_content stored (stripped) as ``head``
    :param filei: stored as ``file_id``
    :param parser: stored as ``parser`` unless the document already has one
    """
    if os.path.isfile(file):
        file_extension = pathlib.Path(file).suffix
        hashid = hash_file(file)
    else:
        file_extension = str(file)  # not file, just show full thing
        hashid = get_sha(file)
    # one random id shared by all documents of this call
    doc_hash = str(uuid.uuid4())[:10]
    if not isinstance(docs1, (list, tuple, types.GeneratorType)):
        docs1 = [docs1]
    for order_id, doc in enumerate(docs1):
        doc.metadata.update(dict(input_type=file_extension,
                                 parser=doc.metadata.get('parser', parser),
                                 date=str(datetime.now()),
                                 time=time.time(),
                                 order_id=order_id,
                                 hashid=hashid,
                                 doc_hash=doc_hash,
                                 file_id=filei,
                                 head=doc.page_content[:headsize].strip()))


def fix_json_meta(docs1):
    """
    Replace missing/falsy ``sender_name`` and ``timestamp_ms`` metadata with ''.

    :param docs1: a single document, or a list/tuple/generator of documents
    """
    if not isinstance(docs1, (list, tuple, types.GeneratorType)):
        docs1 = [docs1]
    # fix meta, chroma doesn't like None, only str, int, float for values
    # single pass so that a generator input gets both fields fixed
    for doc in docs1:
        doc.metadata.update(dict(sender_name=doc.metadata.get('sender_name') or ''))
        doc.metadata.update(dict(timestamp_ms=doc.metadata.get('timestamp_ms') or ''))