#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo that streams text generated by GPT-2 checkpoints."""
import argparse
from collections import defaultdict
from functools import lru_cache
import json
import os
import platform
import re
from typing import List

from project_settings import project_path

# Set before the Hugging Face imports below; presumably the hub cache location
# has to be in the environment by the time those modules are loaded.
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

import gradio as gr
from threading import Thread
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.generation.streamers import TextIteratorStreamer
import torch


# Compiled once: both patterns are applied to the accumulated output on every
# streamed chunk.
_CHAR_SPACE_CHAR_RE = re.compile(r"([,。!?\u4e00-\u9fa5]) ([,。!?\u4e00-\u9fa5])")
_CHAR_SPACE_RE = re.compile(r"([,。!?\u4e00-\u9fa5]) ")


def get_args():
    """Parse the command-line options of the demo.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_new_tokens", default=512, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--temperature", default=0.35, type=float)
    parser.add_argument("--repetition_penalty", default=1.0, type=float)
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        type=str,
    )
    parser.add_argument(
        "--examples_json_file",
        default="examples.json",
        type=str
    )
    args = parser.parse_args()
    return args


def repl1(match):
    """Return groups 1 and 2 of ``match`` joined without a separator."""
    return f"{match.group(1)}{match.group(2)}"


def repl2(match):
    """Return group 1 of ``match``."""
    return f"{match.group(1)}"


@lru_cache(maxsize=4)
def _load_model(model_name: str, device: str):
    """Load the tokenizer and model for ``model_name`` and place the model on ``device``.

    Results are cached so repeated requests for the same checkpoint do not
    reload the weights; the cache is bounded to the size of the longest model
    choice list offered by the UI.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    # The input ids are moved to ``device`` before generation, so the model
    # has to live on the same device.
    model = model.eval().to(device)
    return tokenizer, model


def _clean_chunk(chunk: str) -> str:
    """Remove ``[UNK]`` and ``[CLS]`` markers from a streamed chunk."""
    for marker in ("[UNK] ", "[UNK]", "[CLS] ", "[CLS]"):
        chunk = chunk.replace(marker, "")
    return chunk


def _normalize_output(output: str) -> str:
    """Tidy the accumulated text before it is shown to the user.

    Drops a leading ``[SEP]``, strips leading spaces/punctuation, removes the
    space that follows a character matched by the patterns above, and turns
    remaining ``[SEP]`` markers into newlines.
    """
    if output.startswith("[SEP]"):
        output = output[5:]
    output = output.lstrip(" ,.!?")
    output = _CHAR_SPACE_CHAR_RE.sub(repl1, output)
    output = _CHAR_SPACE_RE.sub(repl2, output)
    output = output.replace("[SEP] ", "\n")
    output = output.replace("[SEP]", "\n")
    return output


def main():
    """Build the Gradio interface and launch it."""
    args = get_args()

    description = """ ## GPT2 Chat """

    # example json
    with open(args.examples_json_file, "r", encoding="utf-8") as f:
        examples = json.load(f)

    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = "qgyd2021/lib_service_4chan",
                  is_chat: bool = True,
                  ):
        """Generate a continuation of ``text`` and yield it incrementally.

        Args:
            text: Prompt typed by the user.
            max_new_tokens: Upper bound on the number of generated tokens.
            top_p: Nucleus-sampling probability mass.
            temperature: Sampling temperature.
            repetition_penalty: Penalty applied to repeated tokens.
            model_name: Checkpoint name or path passed to ``from_pretrained``.
            is_chat: When true, a ``[SEP]`` token is appended to the prompt
                and also used as the end-of-sequence token.

        Yields:
            The cleaned-up text generated so far (cumulative, not a delta).
        """
        tokenizer, model = _load_model(model_name, device)

        # Slider payloads may arrive as ints (e.g. ``1`` instead of ``1.0``);
        # the sampling parameters below are expected to be floats.
        max_new_tokens = int(max_new_tokens)
        top_p = float(top_p)
        temperature = float(temperature)
        repetition_penalty = float(repetition_penalty)

        text_encoded = tokenizer.__call__(text, add_special_tokens=False)
        input_ids_ = text_encoded["input_ids"]

        # Prompt layout: [CLS] + text tokens (+ [SEP] in chat mode).
        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(input_ids_)
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)
        input_ids = torch.tensor([input_ids], dtype=torch.long)
        input_ids = input_ids.to(device)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        # Generation runs in a worker thread so this generator can consume
        # the streamer while tokens are still being produced.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output: str = ""
        first_answer = True
        for output_ in streamer:
            # The first chunk is dropped; presumably it is the echoed prompt,
            # since the streamer is created without ``skip_prompt``.
            if first_answer:
                first_answer = False
                continue

            # The normalized text is stored back into ``output`` so the next
            # chunk is appended to the already cleaned-up string.
            output = _normalize_output(output + _clean_chunk(output_))

            yield output

        # The streamer is exhausted, so generation is done; reap the worker.
        thread.join()

    model_name_choices = ["trained_models/lib_service_4chan", "trained_models/chinese_porn_novel"] \
        if platform.system() == "Windows" else \
        [
            "qgyd2021/lib_service_4chan",
            "qgyd2021/chinese_chitchat",
            "qgyd2021/chinese_porn_novel",
            "qgyd2021/few_shot_intent"
        ]

    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            # Minimums are the smallest valid step rather than 0: generation
            # rejects zero new tokens, a zero temperature when sampling, and
            # a zero repetition penalty, and such an error would be raised in
            # the worker thread while the UI keeps waiting on the streamer.
            gr.Slider(minimum=1, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            gr.Slider(minimum=0.01, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
            gr.Checkbox(value=True, label="is_chat")
        ],
        outputs=[output_text_box],
        examples=examples,
        cache_examples=False,
        examples_per_page=50,
        title="GPT2 Chat",
        description=description,
    )
    demo.queue().launch()

    return


if __name__ == '__main__':
    main()