| import asyncio |
| import importlib |
| import inspect |
| import logging |
| import os |
| import os.path as osp |
| import sys |
| import time |
| from functools import partial |
| from logging.handlers import RotatingFileHandler |
| from typing import Any, Dict, Generator, Iterable, List, Optional, Union |
|
|
|
|
def load_class_from_string(class_path: str, path=None):
    """Import ``module.attr`` from a dotted path and return the attribute.

    Args:
        class_path (str): dotted path such as ``'pkg.module.ClassName'``;
            everything before the last dot is imported as a module and the
            last component is looked up on it.
        path (str, optional): directory to search first. If given and not
            already on ``sys.path``, it is prepended for the duration of the
            import and removed again afterwards.

    Returns:
        The attribute named by the last component of ``class_path``.

    Raises:
        ValueError: if ``class_path`` contains no dot.
        ImportError / AttributeError: propagated from the import or lookup.
    """
    # True only when *we* put ``path`` on sys.path and must take it off again.
    inserted = bool(path) and path not in sys.path
    if inserted:
        sys.path.insert(0, path)
    try:
        module_name, attr_name = class_path.rsplit('.', 1)
        return getattr(importlib.import_module(module_name), attr_name)
    finally:
        # A pre-existing sys.path entry is left untouched.
        if inserted:
            sys.path.remove(path)
|
|
|
|
def create_object(config: Union[Dict, Any] = None):
    """Instantiate an object from a ``{'type': ..., **kwargs}`` config.

    The reserved key ``'type'`` names what to build: either a class/callable
    object, or a dotted import path string resolved via
    ``load_class_from_string``. The remaining keys are passed as keyword
    arguments. A class is instantiated; any other callable is wrapped in
    ``functools.partial`` with those keyword arguments bound. Anything that
    is not a ``dict`` (including ``None``) is returned unchanged.

    The caller's ``config`` dict is not modified.
    """
    # Non-dict input: behave as the identity function.
    if not isinstance(config, dict):
        return config
    assert 'type' in config

    kwargs = config.copy()
    target = kwargs.pop('type')
    if isinstance(target, str):
        target = load_class_from_string(target)

    if not inspect.isclass(target):
        assert callable(target)
        return partial(target, **kwargs)
    return target(**kwargs)
|
|
|
|
async def async_as_completed(futures: Iterable[asyncio.Future]):
    """An asynchronous-generator counterpart of ``asyncio.as_completed``.

    Yields the *original* future objects (not their results) in the order
    in which they complete. Each input must be an ``asyncio.Future``
    (``asyncio.Task`` qualifies).
    """
    loop = asyncio.get_event_loop()

    def _mirror(source):
        # A proxy future on this loop. ``add_done_callback`` invokes the
        # callback with ``source`` itself, so the proxy's result is the
        # completed source future.
        assert isinstance(source, asyncio.Future)
        proxy = loop.create_future()
        source.add_done_callback(proxy.set_result)
        return proxy

    proxies = [_mirror(source) for source in futures]
    for pending in asyncio.as_completed(proxies):
        finished = await pending
        yield finished
|
|
|
|
def filter_suffix(response: Union[str, List[str]],
                  suffixes: Optional[List[str]] = None
                  ) -> Union[str, List[str]]:
    """Truncate generated text at stop suffixes.

    For each response the suffixes are applied in order: whenever a suffix
    occurs in the (possibly already truncated) text, everything from its
    first occurrence onwards is dropped.

    Args:
        response (Union[str, List[str]]): generated response(s) by LLMs.
        suffixes (Optional[List[str]]): suffixes marking where to cut.
            ``None`` returns ``response`` unchanged; empty strings are
            ignored.

    Returns:
        Union[str, List[str]]: a single cleaned string when ``response`` is
        a ``str``, otherwise a list with one cleaned string per input.
    """
    if suffixes is None:
        return response
    batched = not isinstance(response, str)
    responses = response if batched else [response]
    processed = []
    for resp in responses:
        for item in suffixes:
            # An empty separator would make str.split/partition raise
            # ValueError, so empty suffixes are skipped. ``partition``
            # keeps the text before the first occurrence and leaves
            # ``resp`` unchanged when ``item`` is absent.
            if item:
                resp = resp.partition(item)[0]
        processed.append(resp)
    return processed if batched else processed[0]
|
|
|
|
def get_logger(
    name: str = 'lagent',
    level: str = 'debug',
    fmt:
    str = '%(asctime)s %(levelname)8s %(filename)20s %(lineno)4s - %(message)s',
    add_file_handler: bool = False,
    log_dir: str = 'log',
    log_file: Optional[str] = None,
    max_bytes: int = 5 * 1024 * 1024,
    backup_count: int = 3,
):
    """Return the named logger with a console and optional rotating file handler.

    Handlers are attached at most once per logger. Calling this again for
    the same ``name`` does not duplicate output; handlers that are already
    attached keep their existing formatter.

    Args:
        name: logger name passed to ``logging.getLogger``.
        level: level name (case-insensitive); unknown names fall back to
            ``DEBUG``.
        fmt: format string for newly created handlers.
        add_file_handler: also log to ``log_dir/log_file`` through a
            ``RotatingFileHandler``.
        log_dir: directory for the log file; created if missing. An empty
            string means the current working directory.
        log_file: log file name. Defaults to ``'%Y-%m-%d.log'`` formatted
            with the local date at the time of the call.
        max_bytes: rotation size threshold for the file handler.
        backup_count: number of rotated backup files to keep.

    Returns:
        logging.Logger: the configured logger. It does not propagate to
        ancestor loggers.
    """
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(getattr(logging, level.upper(), logging.DEBUG))

    formatter = logging.Formatter(fmt)

    # getLogger returns the same object for the same name, so handlers from
    # an earlier call are still attached; adding more would emit every
    # record several times. The exact type check is needed because file
    # handlers subclass StreamHandler.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    if add_file_handler:
        if log_file is None:
            # Evaluated per call (not once at import) so long-running
            # processes get the current date.
            log_file = time.strftime('%Y-%m-%d.log', time.localtime())
        if log_dir:
            # exist_ok avoids the check-then-create race.
            os.makedirs(log_dir, exist_ok=True)
        log_file_path = osp.join(log_dir, log_file)
        # FileHandler stores the absolute path in ``baseFilename``.
        target = osp.abspath(log_file_path)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers)
        if not already_attached:
            file_handler = RotatingFileHandler(
                log_file_path,
                maxBytes=max_bytes,
                backupCount=backup_count,
                encoding='utf-8')
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
|
|
|
|
class GeneratorWithReturn:
    """Iterable wrapper that records a generator's ``return`` value.

    Iterating the wrapper re-yields every item of the wrapped generator
    (via ``yield from``, so ``send``/``throw`` on the iterator are forwarded
    too). When the wrapped generator finishes, its return value is stored in
    ``self.ret``.

    Iterating a second time runs ``yield from`` on the already-exhausted
    generator, which returns ``None`` and therefore overwrites ``ret``.
    """

    def __init__(self, generator: Generator):
        # The wrapped generator; consumed by the first full iteration.
        self.generator = generator
        # Return value of ``generator``; None until iteration completes.
        self.ret = None

    def __iter__(self):
        returned = yield from self.generator
        self.ret = returned
        return self.ret
|
|