# Source: chuangpt/modules/overwrites.py
# Uploaded by eggacheb ("Upload 105 files"), commit 1ea2ba0 (verified), 3.63 kB; scan: no virus.
from __future__ import annotations
import inspect
import gradio as gr
from gradio.components.chatbot import ChatbotData, FileMessage
from gradio.data_classes import FileData
from gradio_client import utils as client_utils
from modules.utils import convert_bot_before_marked, convert_user_before_marked
def postprocess(
    self,
    value: list[list[str | tuple[str] | tuple[str, str] | None] | tuple] | None,
) -> ChatbotData:
    """
    Turn a chat history into the ChatbotData payload sent to the frontend.

    Parameters:
        value: None, or a sequence of (user_message, bot_message) pairs, each
            pair being a list or tuple of exactly two items. Each item is either
            a string, a tuple/list ``(filepath_or_url, [alt_text])`` for a file,
            or None (the message is not displayed).
    Returns:
        a ChatbotData whose root holds one ``[user, bot]`` list per input pair,
        each side converted by ``self._postprocess_chat_messages`` with the
        role "user" or "bot" respectively.
    Raises:
        TypeError: if a pair is not a list/tuple, or does not have exactly two items.
    """
    if value is None:
        return ChatbotData(root=[])

    rows = []
    for pair in value:
        # Validate the shape of every pair before touching its contents.
        if not isinstance(pair, (tuple, list)):
            raise TypeError(
                f"Expected a list of lists or list of tuples. Received: {pair}"
            )
        if len(pair) != 2:
            raise TypeError(
                f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {pair}"
            )
        user_msg, bot_msg = pair
        user_out = self._postprocess_chat_messages(user_msg, "user")
        bot_out = self._postprocess_chat_messages(bot_msg, "bot")
        rows.append([user_out, bot_out])
    return ChatbotData(root=rows)
def postprocess_chat_messages(
    self, chat_message: str | tuple | list | None, role: str
) -> str | FileMessage | None:
    """
    Convert one side of a chat exchange into a value the Chatbot can render.

    Installed as ``gr.Chatbot._postprocess_chat_messages`` by ``patch_gradio``.

    Parameters:
        chat_message: None (nothing displayed), a string, or a non-empty
            tuple/list ``(filepath_or_url, [alt_text])`` describing a file.
        role: "bot" strings go through ``convert_bot_before_marked``, "user"
            strings through ``convert_user_before_marked``; strings for any
            other role are returned unchanged.
    Returns:
        None, the (converted) string, or a FileMessage for file entries.
    Raises:
        ValueError: if chat_message is an empty tuple/list or of any other
            unsupported type.
    """
    if chat_message is None:
        return None

    if isinstance(chat_message, (tuple, list)):
        if not chat_message:
            # No file path to read: report it like any other invalid message
            # instead of letting chat_message[0] raise a bare IndexError.
            raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
        filepath = str(chat_message[0])
        mime_type = client_utils.get_mimetype(filepath)
        return FileMessage(
            file=FileData(path=filepath, mime_type=mime_type),
            alt_text=chat_message[1] if len(chat_message) > 1 else None,
        )

    if isinstance(chat_message, str):
        if role == "bot":
            return convert_bot_before_marked(chat_message)
        if role == "user":
            return convert_user_before_marked(chat_message)
        return chat_message

    raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
def init_with_class_name_as_elem_classes(original_func):
    """
    Wrap an ``__init__`` so every instance gets a CSS class named after its class.

    The wrapper normalises the ``elem_classes`` keyword argument into a fresh
    list and appends ``"gradio-<class name, lowercased>"``. It also appends
    ``"multiselect"`` when ``multiselect`` was passed truthy as a keyword. It
    then calls ``original_func`` and returns its result.

    Parameters:
        original_func: the ``__init__`` function to wrap.
    Returns:
        the wrapping function, carrying ``original_func``'s metadata
        (``__name__``, ``__doc__``, ``__wrapped__``).
    """
    from functools import wraps  # local import: keeps the module import block unchanged

    @wraps(original_func)
    def wrapper(self, *args, **kwargs):
        # NOTE(review): only a keyword elem_classes is inspected here; a value
        # passed positionally would not be seen by this wrapper.
        classes = kwargs.get("elem_classes")
        if classes is None:
            classes = []
        elif isinstance(classes, str):
            classes = [classes]
        else:
            # Copy: keeps the caller's classes and never mutates a list the
            # caller may be reusing for other components.
            classes = list(classes)
        classes.append("gradio-" + type(self).__name__.lower())
        if kwargs.get("multiselect", False):
            classes.append("multiselect")
        kwargs["elem_classes"] = classes
        return original_func(self, *args, **kwargs)

    return wrapper
def patch_gradio():
    """
    Install this module's overrides into gradio by monkey-patching.

    ``__init__`` of ``gr.components.Component`` and ``gr.blocks.BlockContext``
    is wrapped with ``init_with_class_name_as_elem_classes``. ``gr.Chatbot``
    gets ``postprocess_chat_messages`` and ``postprocess`` as its
    ``_postprocess_chat_messages`` / ``postprocess`` methods.

    NOTE: not idempotent — every call wraps the current ``__init__`` again,
    so call it once at startup.
    """
    for base_cls in (gr.components.Component, gr.blocks.BlockContext):
        base_cls.__init__ = init_with_class_name_as_elem_classes(base_cls.__init__)

    chatbot_cls = gr.Chatbot
    chatbot_cls._postprocess_chat_messages = postprocess_chat_messages
    chatbot_cls.postprocess = postprocess