"""Gradio demo that streams an LLM chat answer as generated Kanji images.

Tokens streamed from a hosted Mixtral-Instruct endpoint are fed one by one to a
StreamDiffusionIO latent-consistency pipeline. Each resulting image is added to
an ``ImageStitcher``, whose output file path is shown in the chat window.
"""

import argparse
import os
from queue import SimpleQueue
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from gradio import Chatbot
from huggingface_hub import InferenceClient

from image_utils import ImageStitcher
from StreamDiffusionIO import LatentConsistencyModelStreamIO

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Mixtral-Instruct sequence markers. Named once so the prompt template
# (format_prompt) and the token-stream filter (generate) cannot drift apart.
BOS_TOKEN = "<s>"
EOS_TOKEN = "</s>"

DESCRIPTION = """\
# Kanji-Streaming Chat

🌍 This Space is adapted from [Llama-2-7b-chat](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat) space, demonstrating how to "chat" with LLM with [Kanji-Streaming](https://github.com/AgainstEntropy/kanji).

🔨 The technique behind Kanji-Streaming is [StreamDiffusionIO](https://github.com/AgainstEntropy/StreamDiffusionIO), which is based on [StreamDiffusion](https://github.com/cumulo-autumn/StreamDiffusion), *but especially allows to render text streams into image streams*.

🔎 For more details about Kanji-Streaming, take a look at the [github repository](https://github.com/AgainstEntropy/kanji).
"""

LICENSE = """

---
As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""

parser = argparse.ArgumentParser(description="Gradio launcher for Streaming-Kanji.")
parser.add_argument(
    "--sd_model_id_or_path",
    type=str,
    default="runwayml/stable-diffusion-v1-5",
    required=False,
    help="Path to downloaded sd-1-5 model or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lora_path",
    type=str,
    default="AgainstEntropy/kanji-lora-sd-v1-5",
    required=False,
    help="Path to downloaded LoRA weight or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lcm_lora_path",
    type=str,
    default="AgainstEntropy/kanji-lcm-lora-sd-v1-5",
    required=False,
    help="Path to downloaded LCM-LoRA weight or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--img_res",
    type=int,
    default=64,
    required=False,
    help="Image resolution for displaying Kanji characters in ChatBot.",
)
parser.add_argument(
    "--img_per_line",
    type=int,
    default=16,
    required=False,
    help="Number of Kanji characters to display in a single line.",
)
parser.add_argument(
    "--tmp_dir",
    type=str,
    default="./tmp",
    required=False,
    help="Path to save temporary images generated by StreamDiffusionIO.",
)
args = parser.parse_args()

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    # Must stay a single-line literal: a plain "..." string cannot span lines.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo works best on GPU.</p>"

client = InferenceClient(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
)


def format_prompt(message, history, system_prompt=''):
    """Build a Mixtral-Instruct prompt from the chat history.

    Args:
        message: The new user message.
        history: Previous ``(user_prompt, bot_response)`` pairs from
            ``gr.ChatInterface``. A bot response is a tuple when the bot
            replied with an image; element 1, when present, is the alt text,
            which ``generate`` fills with the original LLM text.
        system_prompt: Optional text placed right after the BOS token.

    Returns:
        ``"<s> {system}[INST] u1 [/INST] b1</s> ...[INST] message [/INST]"``.
    """
    prompt = f"{BOS_TOKEN} {system_prompt}"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        if isinstance(bot_response, tuple):
            # An image reply without alt text (a bare ``(path,)``) carries no
            # text to replay, so fall back to an empty turn instead of crashing.
            bot_response = bot_response[1] if len(bot_response) > 1 else None
        bot_response = bot_response or ""
        # Each finished assistant turn must be closed with EOS.
        if not bot_response.endswith(EOS_TOKEN):
            bot_response += EOS_TOKEN
        prompt += f" {bot_response} "
    prompt += f"[INST] {message} [/INST]"
    return prompt


lcm_stream = LatentConsistencyModelStreamIO(
    model_id_or_path=args.sd_model_id_or_path,
    lcm_lora_path=args.lcm_lora_path,
    lora_dict={args.lora_path: 1},
    resolution=128,
    device=device,
    # NOTE(review): xformers generally needs CUDA; confirm this is tolerated
    # by LatentConsistencyModelStreamIO when device == "cpu".
    use_xformers=True,
    verbose=True,
)

# Each response gets its own numbered temp directory for its stitched images.
tmp_dir_template = f"{args.tmp_dir}/%d"
response_num = 0
stitcher = ImageStitcher(
    tmp_dir=tmp_dir_template % response_num,
    img_res=args.img_res,
    img_per_line=args.img_per_line,
    verbose=True,
)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    show_original_response: bool,
    seed: int,
    system_prompt: str = '',
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> "Iterator[str | tuple[str, ...]]":
    """Stream an LLM answer to ``message`` as a sequence of Kanji images.

    A background thread consumes the LLM token stream and pushes each
    non-empty token, with any trailing period removed, onto a queue. This
    generator pops the tokens, renders them with ``lcm_stream``, and yields
    the stitched image path.

    Yields:
        ``(img_path,)`` while tokens are arriving.
        ``(img_path, full_text)`` while the pipeline is being drained; the
        alt text lets ``format_prompt`` replay this turn later.
        Finally the raw text, if ``show_original_response`` is set.

    Raises:
        Whatever the token stream raised. The error is re-raised after the
        images already produced have been shown.
    """
    global response_num

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        # Keep the sampling temperature strictly positive.
        temperature=max(temperature, 1e-2),
        repetition_penalty=repetition_penalty,
    )

    formatted_prompt = format_prompt(message, chat_history, system_prompt)
    print(formatted_prompt)
    streamer = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )

    outputs = []  # raw token texts, in order
    prompt_queue = SimpleQueue()  # producer -> consumer; None marks end of stream
    stream_errors = []  # exception raised by the producer thread, if any

    lcm_stream.reset(seed)
    stitcher.reset()
    response_num += 1
    stitcher.update_tmp_dir(tmp_dir_template % response_num)

    def append_to_queue():
        """Producer: forward streamed LLM tokens to ``prompt_queue``."""
        try:
            for response in streamer:
                text = response.token.text
                outputs.append(text)
                prompt = text.strip()
                # Skip whitespace-only tokens and the end-of-sequence marker.
                if prompt and prompt != EOS_TOKEN:
                    prompt_queue.put(prompt.removesuffix("."))
        except Exception as err:  # thread boundary: hand over to the consumer
            stream_errors.append(err)
        finally:
            # Always send the sentinel, even if the stream fails; otherwise
            # the consumer loop below would block on get() forever.
            prompt_queue.put(None)

    append_thread = Thread(target=append_to_queue)
    append_thread.start()

    def show_image(prompt: "str | None" = None):
        """Advance the diffusion stream with ``prompt`` (``None`` = no new text).

        Returns the stitched image path, or ``None`` if no image was ready.
        """
        image, text = lcm_stream(prompt)
        if image is None:
            return None
        return stitcher.add(image, text)

    while (prompt := prompt_queue.get()) is not None:
        img_path = show_image(prompt)
        if img_path is not None:
            yield (img_path,)

    append_thread.join()

    # Continue to display the remaining images
    while True:
        img_path = show_image()
        if img_path is not None:
            yield (img_path, ''.join(outputs))
        if lcm_stream.stop():
            break

    if stream_errors:
        raise stream_errors[0]

    print(outputs)
    if show_original_response:
        yield ''.join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=Chatbot(height=400),
    additional_inputs=[
        gr.Checkbox(
            label="Show original response",
            value=False,
        ),
        gr.Number(
            label="Seed",
            info="Random Seed for Kanji Generation (maybe some kind of accent 🤔)",
            step=1,
            value=1026,
            precision=0,  # deliver an int (not a float) to generate(seed=...)
        ),
        gr.Textbox(
            label="System prompt",
            value="",
            lines=4),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", share=False, show_api=False)