#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio web demo that streams text generated by a GPT2 language model."""
import argparse
from collections import defaultdict
import json
import os
import platform
import re
import string
from typing import List

from project_settings import project_path

# Set before the transformers imports below so the hub cache lives inside
# the project directory.
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

import gradio as gr
from threading import Thread
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.generation.streamers import TextIteratorStreamer
import torch


# Text that ends with an ASCII letter, digit or punctuation character.
# ``re.escape`` keeps every punctuation character literal inside the class;
# interpolating ``string.punctuation`` unescaped turns ``\]`` into an escaped
# bracket and silently drops the backslash from the set.
_ASCII_TAIL_PATTERN = re.compile(
    "[a-zA-Z0-9{}]$".format(re.escape(string.punctuation))
)
# Token that starts with an ASCII letter or digit.
_ASCII_HEAD_PATTERN = re.compile(r"^[a-zA-Z0-9]")


def get_args():
    """Parse the command line options of the demo.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # NOTE: ``main`` only reads ``device`` and ``examples_json_file``; the
    # sampling options are parsed here but the UI sliders use their own
    # fixed defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_new_tokens", default=512, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--temperature", default=0.35, type=float)
    parser.add_argument("--repetition_penalty", default=1.0, type=float)
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        type=str,
    )
    parser.add_argument(
        "--examples_json_file",
        default="examples.json",
        type=str,
    )
    return parser.parse_args()


def repl1(match):
    """Return capture groups 1 and 2 of ``match`` concatenated."""
    return "{}{}".format(match.group(1), match.group(2))


def repl2(match):
    """Return capture group 1 of ``match``."""
    return "{}".format(match.group(1))


def remove_space_between_cn_en(text):
    """Drop the spaces in ``text`` except those separating ASCII words.

    The text is split on single spaces and the non-empty pieces are glued
    back together. A space is re-inserted only when the text built so far
    ends with an ASCII letter, digit or punctuation character and the next
    piece starts with an ASCII letter or digit. One trailing space of the
    input is preserved.

    Args:
        text: the string to clean up.

    Returns:
        The cleaned string; ``text`` itself when it contains no space.
    """
    if " " not in text:
        return text

    result = ""
    for token in text.split(" "):
        if not token:
            # Consecutive, leading or trailing spaces give empty pieces.
            continue
        if _ASCII_TAIL_PATTERN.search(result) and _ASCII_HEAD_PATTERN.search(token):
            result += " "
        result += token

    if text.endswith(" "):
        result += " "
    return result


def main():
    """Build the Gradio interface and serve it."""
    args = get_args()

    description = """
## GPT2 Chat
"""

    # example json
    with open(args.examples_json_file, "r", encoding="utf-8") as f:
        examples = json.load(f)

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = "qgyd2021/lip_service_4chan",
                  is_chat: bool = True,
                  ):
        """Generate a continuation of ``text`` and yield it incrementally.

        Args:
            text: the prompt typed by the user.
            max_new_tokens: forwarded to ``model.generate``.
            top_p: forwarded to ``model.generate``.
            temperature: forwarded to ``model.generate``.
            repetition_penalty: forwarded to ``model.generate``.
            model_name: name or path given to ``from_pretrained``.
            is_chat: when true, a ``[SEP]`` token is appended to the prompt
                and also used as the end-of-sequence token.

        Yields:
            The cleaned-up text generated so far (cumulative, not a delta).
        """
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        # The prompt tensor is moved to ``device`` below, so the model has
        # to live on the same device or ``generate`` fails on CUDA hosts.
        model = model.to(device)
        model = model.eval()

        text_encoded = tokenizer(text, add_special_tokens=False)
        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(text_encoded["input_ids"])
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)
        input_ids = torch.tensor([input_ids], dtype=torch.long)
        input_ids = input_ids.to(device)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        # Generation runs in a worker thread; this generator consumes the
        # streamer while the model is still producing tokens.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output: str = ""
        chunks = iter(streamer)
        # The first chunk is discarded, exactly as before (presumably it is
        # the echoed prompt -- confirm against TextIteratorStreamer defaults).
        next(chunks, None)
        for chunk in chunks:
            # Order matters: strip the "token + space" form first.
            for special_token in ("[UNK] ", "[UNK]", "[CLS] ", "[CLS]"):
                chunk = chunk.replace(special_token, "")
            output += chunk

            if output.startswith("[SEP]"):
                output = output[5:]

            output = output.lstrip(" ,.!?")
            output = remove_space_between_cn_en(output)
            # Remaining separators mark turn boundaries; show as newlines.
            output = output.replace("[SEP] ", "\n")
            output = output.replace("[SEP]", "\n")

            yield output

        # The streamer is exhausted once ``generate`` has finished, so this
        # only reaps the worker thread.
        thread.join()

    if platform.system() == "Windows":
        model_name_choices = [
            "trained_models/lip_service_4chan",
            "trained_models/chinese_porn_novel",
        ]
    else:
        model_name_choices = [
            "qgyd2021/lip_service_4chan",
            "qgyd2021/chinese_chitchat",
            "qgyd2021/chinese_porn_novel",
            "qgyd2021/few_shot_intent_gpt2",
            "qgyd2021/similar_question_generation",
        ]

    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
            gr.Checkbox(value=True, label="is_chat"),
        ],
        outputs=[output_text_box],
        examples=examples,
        cache_examples=False,
        examples_per_page=50,
        title="GPT2 Chat",
        description=description,
    )
    demo.queue().launch(
        share=False,
        # Bind to loopback on Windows, to all interfaces elsewhere.
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860,
    )


if __name__ == '__main__':
    main()