# Copyright (c) OpenMMLab. All rights reserved. import json from enum import Enum from types import SimpleNamespace import redis import requests from loguru import logger from .config import redis_host, redis_passwd, redis_port class TaskCode(Enum): FS_ADD_DOC = 'add_doc' FS_UPDATE_SAMPLE = 'update_sample' FS_UPDATE_PIPELINE = 'update_pipeline' CHAT = 'chat' CHAT_RESPONSE = 'chat_response' class ErrorCode(Enum): """Define an enumerated type for error codes, each has a numeric value and a description. Each enum member is associated with a numeric code and a description string. The numeric code is used as the return code in function calls, and the description provides a human-readable explanation of the error. """ SUCCESS = 0, 'success' NOT_A_QUESTION = 1, 'query is not a question' NO_TOPIC = 2, 'The question does not have a topic. It might be a meaningless sentence.' # noqa E501 UNRELATED = 3, 'Topics unrelated to the knowledge base. Updating good_questions and bad_questions can improve accuracy.' # noqa E501 NO_SEARCH_KEYWORDS = 4, 'Cannot extract keywords.' NO_SEARCH_RESULT = 5, 'Cannot retrieve results.' BAD_ANSWER = 6, 'Irrelevant answer.' SECURITY = 7, 'Reply has a high relevance to prohibited topics.' NOT_WORK_TIME = 8, 'Non-working hours. The config.ini file can be modified to adjust this. **In scenarios where speech may pose risks, let the robot operate under human supervision**' # noqa E501 PARAMETER_ERROR = 9, "HTTP interface parameter error. Query cannot be empty; the format of history is list of lists, like [['question1', 'reply1'], ['question2'], ['reply2']]" # noqa E501 PARAMETER_MISS = 10, 'Missing key in http json input parameters.' WORK_IN_PROGRESS = 11, 'not finish' FAILED = 12, 'fail' BAD_PARAMETER = 13, 'bad parameter' INTERNAL_ERROR = 14, 'internal error' SEARCH_FAIL = 15, 'Search fail, please check TOKEN and quota' ANNOTATECLUSTER = 16, 'Annotate cluster' def __new__(cls, value, description): """Create new instance of ErrorCode.""" obj = object.__new__(cls) obj._value_ = value obj.description = description return obj def __int__(self): """Return the integer representation of the error code.""" return self.value def describe(self): """Return the description of the error code.""" return self.description @classmethod def format(cls, code): """Format the error code into a JSON result. Args: code (ErrorCode): Error code to be formatted. Returns: dict: A dictionary that includes the error code and its description. # noqa E501 Raises: TypeError: If the input is not an instance of ErrorCode. """ if isinstance(code, cls): return {'code': int(code), 'message': code.describe()} raise TypeError(f'Expected type {cls}, got {type(code)}') class Queue: def __init__(self, name, namespace='HuixiangDou', **redis_kwargs): self.__db = redis.Redis(host=redis_host(), port=redis_port(), password=redis_passwd(), charset='utf-8', decode_responses=True) self.key = '%s:%s' % (namespace, name) def qsize(self): """Return the approximate size of the queue.""" return self.__db.llen(self.key) def empty(self): """Return True if the queue is empty, False otherwise.""" return self.qsize() == 0 def put(self, item): """Put item into the queue.""" self.__db.rpush(self.key, item) def peek_tail(self): return self.__db.lrange(self.key, -1, -1) def get(self, block=True, timeout=None): """Remove and return an item from the queue. If optional args block is true and timeout is None (the default), block if necessary until an item is available. """ if block: item = self.__db.blpop(self.key, timeout=timeout) else: item = self.__db.lpop(self.key) if item: item = item[1] return item def get_nowait(self): """Equivalent to get(False).""" return self.get(False) class QueryTracker: """A class to track queries and log them into a file. This class provides functionality to keep track of queries and write them into a log file. Whenever a query is made, it can be logged using this class, and when the instance of this class is destroyed, all logged queries are written to the file. """ def __init__(self, log_file_path): """Initialize the QueryTracker with the path of the log file.""" self.log_file_path = log_file_path self.log_list = [] def log(self, key, value=''): """Log a query. Args: key (str): The key associated with the query. value (str): The value or result associated with the query. """ self.log_list.append((key, value)) def __del__(self): """Write all logged queries into the file when the QueryTracker instance is destroyed. It opens the log file in append mode, writes all logged queries into the file, and then closes the file. If any exception occurs during this process, it will be caught and printed to standard output. """ try: with open(self.log_file_path, 'a', encoding='utf8') as log_file: for key, value in self.log_list: log_file.write(f'{key}: {value}\n') log_file.write('\n') except Exception as e: print(e) def parse_json_str(json_str: str): try: logger.info(json_str) return json.loads(json_str, object_hook=lambda d: SimpleNamespace(**d)), None except Exception as e: logger.error(str(e)) return None, e def multimodal(filepath: str, timeout=5): header = {'Content-Type': 'application/json'} data = {'image_path': filepath} try: resp = requests.post('http://127.0.0.1:9999/api', headers=header, data=json.dumps(data), timeout=timeout) resp_json = resp.json() content = resp_json['content'] # check bad encode ratio useful_char_cnt = 0 scopes = [['a', 'z'], ['\u4e00', '\u9fff'], ['A', 'Z'], ['0', '9']] for char in content: for scope in scopes: if char >= scope[0] and char <= scope[1]: useful_char_cnt += 1 break if useful_char_cnt / len(content) <= 0.5: # Garbled characters return None if len(content) <= 100: return None return content except Exception as e: logger.error(str(e)) return None