Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import ctypes | |
| import pathlib | |
| from typing import Optional, List | |
| import enum | |
| from pathlib import Path | |
class DataType(enum.IntEnum):
    """Tensor data types accepted by ``minigpt4_quantize_model``.

    The integer values are handed to the native library unchanged, so they
    must stay exactly as listed here.
    """

    F16 = 0
    F32 = 1
    I32 = 2
    L64 = 3
    Q4_0 = 4
    Q4_1 = 5
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15

    def __str__(self):
        # Render as the bare member name (e.g. "Q4_0") rather than the
        # default enum representation.
        return str(self.name)
class Verbosity(enum.IntEnum):
    """Logging level passed to the native library as a plain integer."""

    SILENT = 0  # no output
    ERR = 1     # errors only
    INFO = 2    # errors and informational messages
    DEBUG = 3   # most verbose level
class ImageFormat(enum.IntEnum):
    """Pixel storage format stored in the ``format`` field of ``MiniGPT4Image``."""

    UNKNOWN = 0
    F32 = 1
    U8 = 2
# Short aliases for the ctypes types used in the native function signatures.

# Scalars.
I32 = ctypes.c_int32
U32 = ctypes.c_uint32
F32 = ctypes.c_float
SIZE_T = ctypes.c_size_t

# Pointers. ctypes.POINTER caches the types it creates, so building these
# from the scalar aliases yields the very same type objects as spelling out
# the ctypes names in full.
VOID_PTR = ctypes.c_void_p
CHAR_PTR = ctypes.POINTER(ctypes.c_char)
FLOAT_PTR = ctypes.POINTER(F32)
INT_PTR = ctypes.POINTER(I32)
CHAR_PTR_PTR = ctypes.POINTER(CHAR_PTR)

# The native context is only ever handled as an opaque pointer.
MiniGPT4ContextP = VOID_PTR
class MiniGPT4Context:
    """Thin holder for the native context handle.

    The wrapped value is stored untouched in ``ptr`` and is what the library
    wrapper passes as the ``ctx`` argument of the native functions.
    """

    def __init__(self, ptr: ctypes.pointer):
        # Kept as-is; this class never dereferences or frees the handle.
        self.ptr = ptr
class MiniGPT4Image(ctypes.Structure):
    """ctypes mirror of the native image struct exchanged with the library.

    The order and types of the fields define the memory layout shared with
    the native side, so they must not be reordered.
    """

    _fields_ = [
        ('data', VOID_PTR),   # pointer to the pixel buffer
        ('width', I32),
        ('height', I32),
        ('channels', I32),
        ('format', I32),      # callers in this file store an ImageFormat value here
    ]
class MiniGPT4Embedding(ctypes.Structure):
    """ctypes mirror of the native embedding struct filled in by the library.

    The order and types of the fields define the memory layout shared with
    the native side, so they must not be reordered.
    """

    _fields_ = [
        ('data', FLOAT_PTR),        # pointer to the float values
        ('n_embeddings', SIZE_T),   # count reported by the library
    ]
# Pointer types for the two structs above, used in argtypes declarations.
MiniGPT4ImageP = ctypes.POINTER(MiniGPT4Image)
MiniGPT4EmbeddingP = ctypes.POINTER(MiniGPT4Embedding)
class MiniGPT4SharedLibrary:
    """
    Python wrapper around minigpt4.cpp shared library.

    ``__init__`` loads the library and declares ``argtypes``/``restype`` for
    every native function used here. The ``minigpt4_*`` methods convert Python
    values to their ctypes form, call the native function and turn non-zero
    status codes into ``RuntimeError`` via ``panic_if_error``.
    """

    def __init__(self, shared_library_path: str):
        """
        Loads the shared library from specified file.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to minigpt4.cpp shared library. On Windows, it would look like 'minigpt4.dll'. On UNIX, 'minigpt4.so'.
        """
        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        # NOTE(review): several parameters below are commented as `size_t`
        # in the C signature but declared as I32 here. Left unchanged;
        # confirm against the C header before widening them to SIZE_T.

        self.library.minigpt4_model_load.argtypes = [
            CHAR_PTR,  # const char *path
            CHAR_PTR,  # const char *llm_model
            I32,  # int verbosity
            I32,  # int seed
            I32,  # int n_ctx
            I32,  # int n_batch
            I32,  # int numa
        ]
        self.library.minigpt4_model_load.restype = MiniGPT4ContextP

        self.library.minigpt4_image_load_from_file.argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR,  # const char *path
            MiniGPT4ImageP,  # struct MiniGPT4Image *image
            I32,  # int flags
        ]
        self.library.minigpt4_image_load_from_file.restype = I32

        # Without explicit argtypes ctypes passes a Python int as a C `int`,
        # which would truncate the 64-bit context pointer. The signature
        # mirrors the call made in minigpt4_preprocess_image below.
        self.library.minigpt4_preprocess_image.argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            MiniGPT4ImageP,  # const struct MiniGPT4Image *image
            MiniGPT4ImageP,  # struct MiniGPT4Image *preprocessed_image
            I32,  # int flags
        ]
        self.library.minigpt4_preprocess_image.restype = I32

        self.library.minigpt4_encode_image.argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            MiniGPT4ImageP,  # const struct MiniGPT4Image *image
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
            I32,  # size_t n_threads
        ]
        self.library.minigpt4_encode_image.restype = I32

        self.library.minigpt4_begin_chat_image.argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
            CHAR_PTR,  # const char *s
            I32,  # size_t n_threads
        ]
        self.library.minigpt4_begin_chat_image.restype = I32

        # minigpt4_end_chat_image and minigpt4_end_chat take the same
        # parameter list, so it is declared once and copied for each.
        sampling_argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR_PTR,  # const char **token
            I32,  # size_t n_threads
            F32,  # float temp
            I32,  # int32_t top_k
            F32,  # float top_p
            F32,  # float tfs_z
            F32,  # float typical_p
            I32,  # int32_t repeat_last_n
            F32,  # float repeat_penalty
            F32,  # float alpha_presence
            F32,  # float alpha_frequency
            I32,  # int mirostat
            F32,  # float mirostat_tau
            F32,  # float mirostat_eta
            I32,  # int penalize_nl
        ]
        self.library.minigpt4_end_chat_image.argtypes = list(sampling_argtypes)
        self.library.minigpt4_end_chat_image.restype = I32

        self.library.minigpt4_system_prompt.argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            I32,  # size_t n_threads
        ]
        self.library.minigpt4_system_prompt.restype = I32

        self.library.minigpt4_begin_chat.argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR,  # const char *s
            I32,  # size_t n_threads
        ]
        self.library.minigpt4_begin_chat.restype = I32

        self.library.minigpt4_end_chat.argtypes = list(sampling_argtypes)
        self.library.minigpt4_end_chat.restype = I32

        self.library.minigpt4_reset_chat.argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
        ]
        self.library.minigpt4_reset_chat.restype = I32

        self.library.minigpt4_contains_eos_token.argtypes = [
            CHAR_PTR,  # const char *s
        ]
        self.library.minigpt4_contains_eos_token.restype = I32

        self.library.minigpt4_is_eos.argtypes = [
            CHAR_PTR,  # const char *s
        ]
        self.library.minigpt4_is_eos.restype = I32

        self.library.minigpt4_free.argtypes = [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
        ]
        self.library.minigpt4_free.restype = I32

        self.library.minigpt4_free_image.argtypes = [
            MiniGPT4ImageP,  # struct MiniGPT4Image *image
        ]
        self.library.minigpt4_free_image.restype = I32

        self.library.minigpt4_free_embedding.argtypes = [
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
        ]
        self.library.minigpt4_free_embedding.restype = I32

        self.library.minigpt4_error_code_to_string.argtypes = [
            I32,  # int error_code
        ]
        self.library.minigpt4_error_code_to_string.restype = CHAR_PTR

        self.library.minigpt4_quantize_model.argtypes = [
            CHAR_PTR,  # const char *in_path
            CHAR_PTR,  # const char *out_path
            I32,  # int data_type
        ]
        self.library.minigpt4_quantize_model.restype = I32

        self.library.minigpt4_set_verbosity.argtypes = [
            I32,  # int verbosity
        ]
        self.library.minigpt4_set_verbosity.restype = None

    def panic_if_error(self, error_code: int) -> None:
        """
        Raises an exception if the error code is not 0.

        Parameters
        ----------
        error_code : int
            Error code to check.

        Raises
        ------
        RuntimeError
            If error_code is non-zero; the message is the library's text for that code.
        """
        if error_code != 0:
            # Go through the decoding wrapper so the exception carries the
            # readable message rather than a ctypes pointer object.
            raise RuntimeError(self.minigpt4_error_code_to_string(error_code))

    def minigpt4_model_load(self, model_path: str, llm_model_path: str, verbosity: int = 1, seed: int = 1337, n_ctx: int = 2048, n_batch: int = 512, numa: int = 0) -> MiniGPT4Context:
        """
        Loads a model from a file.

        Args:
            model_path (str): Path to model file.
            llm_model_path (str): Path to LLM model file.
            verbosity (int): Verbosity level: 0 = silent, 1 = error, 2 = info, 3 = debug. Defaults to 1.
            seed (int): Seed for llm model. Defaults to 1337.
            n_ctx (int): Size of context for llm model. Defaults to 2048.
            n_batch (int): Forwarded as the native function's n_batch argument. Defaults to 512.
            numa (int): NUMA node to use (0 = NUMA disabled, 1 = NUMA enabled). Defaults to 0.

        Returns:
            MiniGPT4Context: Context.

        Raises:
            AssertionError: If the native function returns a NULL context.
        """
        ptr = self.library.minigpt4_model_load(
            model_path.encode('utf-8'),
            llm_model_path.encode('utf-8'),
            I32(verbosity),
            I32(seed),
            I32(n_ctx),
            I32(n_batch),
            I32(numa),
        )

        # A c_void_p restype maps NULL to None. Raised explicitly (same
        # exception type and message as the former assert statement) so the
        # check is not stripped when Python runs with -O.
        if ptr is None:
            raise AssertionError('minigpt4_model_load failed')
        return MiniGPT4Context(ptr)

    def minigpt4_image_load_from_file(self, ctx: MiniGPT4Context, path: str, flags: int) -> MiniGPT4Image:
        """
        Loads an image from a file

        Args:
            ctx (MiniGPT4Context): context
            path (str): path
            flags (int): flags

        Returns:
            MiniGPT4Image: image

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        image = MiniGPT4Image()
        self.panic_if_error(self.library.minigpt4_image_load_from_file(ctx.ptr, path.encode('utf-8'), ctypes.pointer(image), I32(flags)))
        return image

    def minigpt4_preprocess_image(self, ctx: MiniGPT4Context, image: MiniGPT4Image, flags: int = 0) -> MiniGPT4Image:
        """
        Preprocesses an image

        Args:
            ctx (MiniGPT4Context): Context
            image (MiniGPT4Image): Image
            flags (int): Flags. Defaults to 0.

        Returns:
            MiniGPT4Image: Preprocessed image

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        preprocessed_image = MiniGPT4Image()
        self.panic_if_error(self.library.minigpt4_preprocess_image(ctx.ptr, ctypes.pointer(image), ctypes.pointer(preprocessed_image), I32(flags)))
        return preprocessed_image

    def minigpt4_encode_image(self, ctx: MiniGPT4Context, image: MiniGPT4Image, n_threads: int = 0) -> MiniGPT4Embedding:
        """
        Encodes an image into embedding

        Args:
            ctx (MiniGPT4Context): Context.
            image (MiniGPT4Image): Image.
            n_threads (int): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            embedding (MiniGPT4Embedding): Output embedding.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        embedding = MiniGPT4Embedding()
        self.panic_if_error(self.library.minigpt4_encode_image(ctx.ptr, ctypes.pointer(image), ctypes.pointer(embedding), n_threads))
        return embedding

    def minigpt4_begin_chat_image(self, ctx: MiniGPT4Context, image_embedding: MiniGPT4Embedding, s: str, n_threads: int = 0):
        """
        Begins a chat with an image.

        Args:
            ctx (MiniGPT4Context): Context.
            image_embedding (MiniGPT4Embedding): Image embedding.
            s (str): Question to ask about the image.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            None

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        self.panic_if_error(self.library.minigpt4_begin_chat_image(ctx.ptr, ctypes.pointer(image_embedding), s.encode('utf-8'), n_threads))

    def minigpt4_end_chat_image(self, ctx: MiniGPT4Context, n_threads: int = 0, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1) -> str:
        """
        Ends a chat with an image.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): Tfs Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat Tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat Eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Returns:
            str: Token generated.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        # The library writes the address of the token text into this pointer.
        token = CHAR_PTR()
        self.panic_if_error(self.library.minigpt4_end_chat_image(ctx.ptr, ctypes.pointer(token), n_threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl))
        # Reinterpret as a NUL-terminated C string to read the bytes.
        return ctypes.cast(token, ctypes.c_char_p).value.decode('utf-8')

    def minigpt4_system_prompt(self, ctx: MiniGPT4Context, n_threads: int = 0):
        """
        Generates a system prompt.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        self.panic_if_error(self.library.minigpt4_system_prompt(ctx.ptr, n_threads))

    def minigpt4_begin_chat(self, ctx: MiniGPT4Context, s: str, n_threads: int = 0):
        """
        Begins a chat continuing after minigpt4_begin_chat_image

        Args:
            ctx (MiniGPT4Context): Context.
            s (str): Question to ask about the image.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            None

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        self.panic_if_error(self.library.minigpt4_begin_chat(ctx.ptr, s.encode('utf-8'), n_threads))

    def minigpt4_end_chat(self, ctx: MiniGPT4Context, n_threads: int = 0, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1) -> str:
        """
        Ends a chat.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): Tfs Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat Tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat Eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Returns:
            str: Token generated.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        # The library writes the address of the token text into this pointer.
        token = CHAR_PTR()
        self.panic_if_error(self.library.minigpt4_end_chat(ctx.ptr, ctypes.pointer(token), n_threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl))
        # Reinterpret as a NUL-terminated C string to read the bytes.
        return ctypes.cast(token, ctypes.c_char_p).value.decode('utf-8')

    def minigpt4_reset_chat(self, ctx: MiniGPT4Context):
        """
        Resets the chat.

        Args:
            ctx (MiniGPT4Context): Context.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        self.panic_if_error(self.library.minigpt4_reset_chat(ctx.ptr))

    def minigpt4_contains_eos_token(self, s: str) -> bool:
        """
        Checks if a string contains an EOS token.

        Args:
            s (str): String to check.

        Returns:
            bool: True if the string contains an EOS token, False otherwise.
        """
        # The native function returns an int; convert to match the annotation.
        return bool(self.library.minigpt4_contains_eos_token(s.encode('utf-8')))

    def minigpt4_is_eos(self, s: str) -> bool:
        """
        Checks if a string is EOS.

        Args:
            s (str): String to check.

        Returns:
            bool: True if the string contains an EOS, False otherwise.
        """
        # The native function returns an int; convert to match the annotation.
        return bool(self.library.minigpt4_is_eos(s.encode('utf-8')))

    def minigpt4_free(self, ctx: MiniGPT4Context) -> None:
        """
        Frees a context.

        Args:
            ctx (MiniGPT4Context): Context.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        self.panic_if_error(self.library.minigpt4_free(ctx.ptr))

    def minigpt4_free_image(self, image: MiniGPT4Image) -> None:
        """
        Frees an image.

        Args:
            image (MiniGPT4Image): Image.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        self.panic_if_error(self.library.minigpt4_free_image(ctypes.pointer(image)))

    def minigpt4_free_embedding(self, embedding: MiniGPT4Embedding) -> None:
        """
        Frees an embedding.

        Args:
            embedding (MiniGPT4Embedding): Embedding.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        self.panic_if_error(self.library.minigpt4_free_embedding(ctypes.pointer(embedding)))

    def minigpt4_error_code_to_string(self, error_code: int) -> str:
        """
        Converts an error code to a string.

        Args:
            error_code (int): Error code.

        Returns:
            str: Error string.
        """
        # restype is POINTER(c_char), so the call yields a pointer object,
        # not bytes; view it as a C string before decoding.
        message_ptr = self.library.minigpt4_error_code_to_string(I32(error_code))
        raw = ctypes.cast(message_ptr, ctypes.c_char_p).value
        if raw is None:
            # NULL pointer: the library supplied no text for this code.
            return f'unknown error code {error_code}'
        return raw.decode('utf-8', errors='replace')

    def minigpt4_quantize_model(self, in_path: str, out_path: str, data_type: DataType):
        """
        Quantizes a model file.

        Args:
            in_path (str): Path to input model file.
            out_path (str): Path to write output model file.
            data_type (DataType): Must be one DataType enum values.

        Raises:
            RuntimeError: If the native call returns a non-zero status.
        """
        self.panic_if_error(self.library.minigpt4_quantize_model(in_path.encode('utf-8'), out_path.encode('utf-8'), data_type))

    def minigpt4_set_verbosity(self, verbosity: Verbosity):
        """
        Sets verbosity.

        Args:
            verbosity (int): Verbosity.
        """
        self.library.minigpt4_set_verbosity(I32(verbosity))
def load_library() -> MiniGPT4SharedLibrary:
    """
    Attempts to find minigpt4.cpp shared library and load it.

    A fixed list of candidate locations is checked in order and the first one
    that is an existing file is loaded. If none exists, the last candidate
    (the library name inside the current working directory) is handed to the
    loader anyway, so the loader's own error is what the caller sees.
    """
    platform = sys.platform
    if 'win32' in platform or 'cygwin' in platform:
        file_name = 'minigpt4.dll'
    elif 'darwin' in platform:
        file_name = 'libminigpt4.dylib'
    else:
        file_name = 'libminigpt4.so'

    cwd = pathlib.Path(os.getcwd())
    # Two levels up from this file.
    repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent

    # Relative entries are resolved against the current working directory.
    candidates = [
        # If we are in "minigpt4" directory
        f'../bin/Release/{file_name}',
        # If we are in repo root directory
        f'bin/Release/{file_name}',
        # If we compiled in build directory
        f'build/bin/Release/{file_name}',
        # If we compiled in build directory
        f'build/{file_name}',
        f'../build/{file_name}',
        # Search relative to this file
        str(repo_root_dir / 'bin' / 'Release' / file_name),
        # Fallback
        str(repo_root_dir / file_name),
        str(cwd / file_name),
    ]

    # First existing candidate wins; otherwise fall back to the last entry.
    chosen = next((candidate for candidate in candidates if os.path.isfile(candidate)), candidates[-1])
    return MiniGPT4SharedLibrary(chosen)
class MiniGPT4ChatBot:
    """
    Stateful chat helper on top of the shared-library wrapper.

    It loads the library and a model context on construction, keeps the
    embedding of the most recently uploaded image, and streams answers token
    by token through ``generate``.
    """

    def __init__(self, model_path: str, llm_model_path: str, verbosity: Verbosity = Verbosity.SILENT, n_threads: int = 0):
        """
        Creates a new MiniGPT4ChatBot instance.

        Args:
            model_path (str): Path to model file.
            llm_model_path (str): Path to language model model file.
            verbosity (Verbosity, optional): Verbosity. Defaults to Verbosity.SILENT.
            n_threads (int, optional): Number of threads to use. Defaults to 0.
        """
        self.library = load_library()
        self.ctx = self.library.minigpt4_model_load(model_path, llm_model_path, verbosity)
        self.n_threads = n_threads

        # Imported here so the heavy optional dependencies are only needed
        # when a chat bot is actually created.
        from PIL import Image
        from torchvision import transforms
        from torchvision.transforms.functional import InterpolationMode

        self.image_size = 224

        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        # NOTE(review): RandomResizedCrop picks a random crop on each call,
        # so the same picture can produce different embeddings — confirm
        # this is intended for inference.
        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.image_size,
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]
        )
        self.embedding: Optional[MiniGPT4Embedding] = None
        self.is_image_chat = False
        self.chat_history = []

    def free(self):
        """
        Releases the cached embedding (if any) and the model context.

        Safe to call more than once: both handles are cleared after they are
        released, so a repeated call does nothing.
        """
        if self.embedding is not None:
            self.library.minigpt4_free_embedding(self.embedding)
            self.embedding = None
        # The wrapper object is always truthy, so compare against None and
        # clear the attribute to prevent a second native free.
        if self.ctx is not None:
            self.library.minigpt4_free(self.ctx)
            self.ctx = None

    def generate(self, message: str, limit: int = 1024, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1):
        """
        Generates a chat response.

        This is a generator: nothing is sent to the model until iteration
        starts. The first call after ``upload_image`` starts an image chat;
        later calls continue the conversation as plain text.

        Args:
            message (str): Message.
            limit (int, optional): Maximum number of tokens to request. Defaults to 1024.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): TFS Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Yields:
            str: The next generated token.
        """
        sampling = (temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl)

        # Pick the begin/end pair once; the token loop is the same for both.
        if self.is_image_chat:
            self.is_image_chat = False
            self.library.minigpt4_begin_chat_image(self.ctx, self.embedding, message, self.n_threads)
            next_token = self.library.minigpt4_end_chat_image
        else:
            self.library.minigpt4_begin_chat(self.ctx, message, self.n_threads)
            next_token = self.library.minigpt4_end_chat

        chat = ''
        for _ in range(limit):
            token = next_token(self.ctx, self.n_threads, *sampling)
            chat += token
            # Tokens containing an EOS marker are accumulated but not
            # yielded; generation stops once the accumulated text is EOS.
            if self.library.minigpt4_contains_eos_token(token):
                continue
            if self.library.minigpt4_is_eos(chat):
                break
            yield token

    def reset_chat(self):
        """
        Resets the chat.

        Frees the cached image embedding, resets the native chat state and
        issues the system prompt again.
        """
        self.is_image_chat = False
        if self.embedding is not None:
            self.library.minigpt4_free_embedding(self.embedding)
            self.embedding = None

        self.library.minigpt4_reset_chat(self.ctx)
        self.library.minigpt4_system_prompt(self.ctx, self.n_threads)

    def upload_image(self, image):
        """
        Uploads an image.

        Resets the chat, runs the image through ``self.transform``, encodes
        it and stores the embedding for the next ``generate`` call.

        Args:
            image (Image): Image.
        """
        self.reset_chat()

        batch = self.transform(image).unsqueeze(0)
        # Keep the array bound to a local so the buffer behind the raw
        # pointer is guaranteed to stay alive for the whole native call.
        pixels = batch.numpy()
        pixel_ptr = pixels.ctypes.data_as(ctypes.c_void_p)
        minigpt4_image = MiniGPT4Image(pixel_ptr, self.image_size, self.image_size, 3, ImageFormat.F32)
        self.embedding = self.library.minigpt4_encode_image(self.ctx, minigpt4_image, self.n_threads)

        self.is_image_chat = True
if __name__ == "__main__":
    # Smoke test: load the models, ask the first prompt about the image and
    # then ask every remaining prompt as a follow-up, streaming tokens to
    # stdout.
    import argparse

    parser = argparse.ArgumentParser(description='Test loading minigpt4')
    parser.add_argument('model_path', help='Path to model file')
    parser.add_argument('llm_model_path', help='Path to llm model file')
    parser.add_argument('-i', '--image_path', help='Image to test', default='images/llama.png')
    parser.add_argument('-p', '--prompts', help='Text to test', default='what is the text in the picture?,what is the color of it?')
    args = parser.parse_args()

    model_path = args.model_path
    llm_model_path = args.llm_model_path
    image_path = args.image_path
    prompts = args.prompts

    # sys.exit instead of the site-provided exit(), which is not guaranteed
    # to exist (e.g. under `python -S`).
    if not Path(model_path).exists():
        print(f'Model does not exist: {model_path}')
        sys.exit(1)
    if not Path(llm_model_path).exists():
        print(f'LLM Model does not exist: {llm_model_path}')
        sys.exit(1)

    # Prompts arrive as one comma-separated string.
    prompts = prompts.split(',')

    print('Loading minigpt4 shared library...')
    library = load_library()
    print(f'Loaded library {library}')

    ctx = library.minigpt4_model_load(model_path, llm_model_path, Verbosity.DEBUG)
    image = library.minigpt4_image_load_from_file(ctx, image_path, 0)
    preprocessed_image = library.minigpt4_preprocess_image(ctx, image, 0)

    question = prompts[0]
    n_threads = 0
    embedding = library.minigpt4_encode_image(ctx, preprocessed_image, n_threads)
    library.minigpt4_system_prompt(ctx, n_threads)

    def _stream_reply(next_token):
        """Print tokens from next_token(ctx, n_threads) until the reply is EOS."""
        chat = ''
        while True:
            token = next_token(ctx, n_threads)
            chat += token
            # Tokens holding an EOS marker are accumulated but not printed.
            if library.minigpt4_contains_eos_token(token):
                continue
            if library.minigpt4_is_eos(chat):
                break
            print(token, end='')
        # Finish the line so the next answer or the shell prompt starts cleanly.
        print()

    library.minigpt4_begin_chat_image(ctx, embedding, question, n_threads)
    _stream_reply(library.minigpt4_end_chat_image)

    for prompt in prompts[1:]:
        library.minigpt4_begin_chat(ctx, prompt, n_threads)
        _stream_reply(library.minigpt4_end_chat)

    # Release native resources; the embedding was previously never freed.
    library.minigpt4_free_embedding(embedding)
    library.minigpt4_free_image(image)
    library.minigpt4_free_image(preprocessed_image)
    library.minigpt4_free(ctx)