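# Gradio demo app for the Japanese StableLM Tuned Pre-Alpha model from Stability AI Japan.
# It loads the model (from the Hugging Face Hub when running on Spaces, or from local paths
# via the CLI), streams generated tokens to the UI, and logs user feedback ratings.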
import gradio as gr
import os
import threading
import arrow
import time
import argparse
import logging
from dataclasses import dataclass
import torch
import sentencepiece as spm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPTNeoXForCausalLM, GPTNeoXConfig
from transformers.generation.streamers import BaseStreamer
from huggingface_hub import hf_hub_download, login

logger = logging.getLogger()
logger.setLevel("INFO")

gr_interface = None

VERSION = "0.1.0"

@dataclass
class DefaultArgs:
    hf_model_name_or_path: str = None
    hf_tokenizer_name_or_path: str = None
    spm_model_path: str = None
    env: str = "dev"
    port: int = 7860
    make_public: bool = False

if os.getenv("RUNNING_ON_HF_SPACE"):
    # Running on Hugging Face Spaces: read configuration from environment variables.
    login(token=os.getenv("HF_TOKEN"))
    hf_repo = os.getenv("HF_MODEL_REPO")
    args = DefaultArgs()
    args.hf_model_name_or_path = hf_repo
    args.hf_tokenizer_name_or_path = os.path.join(hf_repo, "tokenizer")
    args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
else:
    # Running locally: read configuration from the command line.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--hf_model_name_or_path", type=str, required=True)
    parser.add_argument("--hf_tokenizer_name_or_path", type=str, required=False)
    parser.add_argument("--spm_model_path", type=str, required=True)
    parser.add_argument("--env", type=str, default="dev")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--make_public", action='store_true')
    args = parser.parse_args()

def load_model(model_dir):
    # Load the GPT-NeoX causal LM in bfloat16 and move it to the GPU when available.
    config = GPTNeoXConfig.from_pretrained(model_dir)
    config.is_decoder = True
    model = GPTNeoXForCausalLM.from_pretrained(model_dir, config=config, torch_dtype=torch.bfloat16)
    if torch.cuda.is_available():
        model = model.to("cuda:0")
    return model


logging.info("Loading model")
model = load_model(args.hf_model_name_or_path)
sp = spm.SentencePieceProcessor(model_file=args.spm_model_path)
logging.info("Finished loading model")

tokenizer = AutoTokenizer.from_pretrained(
    args.hf_model_name_or_path,
    subfolder="tokenizer",
    use_fast=False,
)
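
# transformers calls BaseStreamer.put() first with the prompt ids and then with each newly
# generated token, and end() when generation finishes. This streamer decodes and accumulates
# the text so the polling loop in generate() below can stream partial results to the UI.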
class TokenizerStreamer(BaseStreamer):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.num_invoked = 0
        self.prompt = ""
        self.generated_text = ""
        self.ended = False

    def put(self, t: torch.Tensor):
        d = t.dim()
        if d == 1:
            pass
        elif d == 2:
            t = t[0]
        else:
            raise NotImplementedError

        # Move to CPU before converting to Python ints: generation may run on the GPU.
        t = [int(x) for x in t.cpu().numpy()]
        text = self.tokenizer.decode(t)

        if text in [self.tokenizer.bos_token, self.tokenizer.eos_token]:
            text = ""

        # The first call delivers the prompt tokens; remember them but do not stream them.
        if self.num_invoked == 0:
            self.prompt = text
            self.num_invoked += 1
            return

        self.generated_text += text
        logging.debug(f"[streamer]: {self.generated_text}")

    def end(self):
        self.ended = True

INPUT_PROMPT = """以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
{instruction}
### 入力:
{input}
### 応答: """

NO_INPUT_PROMPT = """以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
{instruction}
### 応答: """
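
# The two templates above are the Japanese Alpaca-style instruction format. In English:
# "Below is a combination of an instruction that describes a task and an input that provides
#  context. Write a response that appropriately satisfies the request.", followed by the
# "### 指示:" (Instruction), "### 入力:" (Input), and "### 応答:" (Response) sections.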

def postprocess_output(output):
    # Keep only the text after the "### 応答:" marker and strip special tokens
    # and stray section markers.
    output = (
        output
        .split('### 応答:')[1]
        .split('###')[0]
        .split('##')[0]
        .removeprefix(tokenizer.bos_token)
        .removesuffix(tokenizer.eos_token)
        .replace("###", "")
        .strip()
    )
    return output
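
# Example with a hypothetical model output (assuming the response marker is present):
#   postprocess_output("...### 応答: 東京です。### 指示: ...")  ->  "東京です。"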

def generate(
    prompt,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    logging.debug(log)

    input_text = NO_INPUT_PROMPT.format(instruction=prompt)
    input_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
    streamer = TokenizerStreamer(tokenizer=tokenizer)

    # Cap the request so prompt plus generation fits in the model's context window.
    # input_ids has shape (1, seq_len), so the prompt length is dimension 1.
    max_possible_new_tokens = model.config.max_position_embeddings - input_ids.shape[1]
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)

    # Run generation in a background thread so partial text can be streamed to the UI.
    thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
        input_ids=input_ids.to(model.device),
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_possible_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bad_words_ids=[[tokenizer.unk_token_id]],
        streamer=streamer,
    ))
    thr.start()

    while not streamer.ended:
        time.sleep(0.05)
        yield streamer.generated_text

    # TODO: optimize for final few tokens
    gen = streamer.generated_text
    log.update(dict(
        generation=gen,
        version=VERSION,
        time=str(arrow.now("+09:00")),
    ))
    logging.info(log)
    yield gen
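
# Gradio treats generator functions as streaming handlers: with the queue enabled, each
# yielded string progressively replaces the contents of the output textbox.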

def process_feedback(
    rating,
    prompt,
    generation,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    # Log the rating together with the prompt, generation, and sampling parameters.
    log = dict(locals())
    log.update(dict(
        time=str(arrow.now("+09:00")),
        version=VERSION,
    ))
    logging.info(log)

if gr_interface:
    gr_interface.close(verbose=False)

with gr.Blocks() as gr_interface:
    with gr.Row():
        gr.Markdown(f"# 日本語 StableLM Tuned Pre-Alpha ({VERSION})")
        # gr.Markdown(f"バージョン：{VERSION}")
    with gr.Row():
        gr.Markdown("この言語モデルは Stability AI Japan が開発した初期バージョンの日本語モデルです。モデルは「プロンプト」に入力した聞きたいことに対して、それらしい応答をすることができます。")

    with gr.Row():
        # left panel
        with gr.Column(scale=1):
            # generation params
            with gr.Box():
                gr.Markdown("パラメータ")

                # hidden default params
                do_sample = gr.Checkbox(True, label="Do Sample", info="サンプリング生成", visible=True)
                no_repeat_ngram_size = gr.Slider(0, 10, value=3, step=1, label="No Repeat Ngram Size", visible=False)

                # visible params
                max_new_tokens = gr.Slider(
                    128,
                    min(512, model.config.max_position_embeddings),
                    value=128,
                    step=128,
                    label="max tokens",
                    info="生成するトークンの最大数を指定する",
                )
                temperature = gr.Slider(
                    0, 1, value=0.1, step=0.05, label="temperature",
                    info="低い値は出力をより集中させて決定論的にする")
                repetition_penalty = gr.Slider(
                    1, 1.5, value=1.2, step=0.05, label="frequency penalty",
                    info="高い値はAIが繰り返す可能性を減少させる")

            # grouping params for easier reference
            gr_params = [
                max_new_tokens,
                temperature,
                repetition_penalty,
                do_sample,
                no_repeat_ngram_size,
            ]

        # right panel
        with gr.Column(scale=2):
            # user input block
            with gr.Box():
                textbox_prompt = gr.Textbox(
                    label="プロンプト",
                    placeholder="日本の首都は？",
                    interactive=True,
                    lines=5,
                    value="",
                )
            with gr.Box():
                with gr.Row():
                    btn_stop = gr.Button(value="キャンセル", variant="secondary")
                    btn_submit = gr.Button(value="実行", variant="primary")

            # model output block
            with gr.Box():
                textbox_generation = gr.Textbox(
                    label="生成結果",
                    lines=5,
                    value="",
                )

            # rating block
            with gr.Row():
                gr.Markdown("より良い言語モデルを皆様に提供できるよう、生成品質についてのご意見をお聞かせください。")
            with gr.Box():
                with gr.Row():
                    rating_options = [
                        "最悪",
                        "不合格",
                        "中立",
                        "合格",
                        "最高",
                    ]
                    btn_ratings = [gr.Button(value=v) for v in rating_options]

            # TODO: we might not need this for sharing with close groups
            # with gr.Box():
            #     gr.Markdown("TODO：For more feedback link for google form")

    # event handling
    inputs = [textbox_prompt] + gr_params
    click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
    btn_stop.click(None, None, None, cancels=click_event, queue=False)
    for btn_rating in btn_ratings:
        btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)

gr_interface.queue(max_size=32, concurrency_count=2)
gr_interface.launch(server_port=args.port, share=args.make_public)
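
# Local usage sketch (outside HF Spaces); the paths below are placeholders and assume a
# model directory that also contains the "tokenizer" subfolder and a sentencepiece.model file:
#   python app.py --hf_model_name_or_path ./model --spm_model_path ./model/sentencepiece.model --port 7860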