#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from threading import Thread

import gradio as gr
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.generation.streamers import TextIteratorStreamer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_new_tokens", default=512, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--temperature", default=0.35, type=float)
    parser.add_argument("--repetition_penalty", default=1.0, type=float)
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        type=str,
    )
    args = parser.parse_args()
    return args


description = """
## GPT2 Chat
"""


def main():
    args = get_args()

    # Resolve the target device; "auto" falls back to CUDA when available.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = "qgyd2021/lib_service_4chan",
                  is_chat: bool = True,
                  ):
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        model = model.eval()
        # The model must live on the same device as the input ids.
        model = model.to(device)

        text_encoded = tokenizer(text, add_special_tokens=False)
        input_ids_ = text_encoded["input_ids"]

        # Build the prompt as [CLS] <text>, plus a trailing [SEP] in chat mode.
        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(input_ids_)
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)
        input_ids = torch.tensor([input_ids], dtype=torch.long)
        input_ids = input_ids.to(device)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            # In chat mode, stop generating when the model emits [SEP].
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        # Run generation in a background thread; the streamer yields decoded
        # text pieces as they are produced.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output = ""
        for output_ in streamer:
            # The BERT tokenizer decodes with spaces between (Chinese)
            # characters; strip them along with the special tokens.
            output_ = output_.replace(" ", "")
            output_ = output_.replace("[CLS]", "")
            output_ = output_.replace("[SEP]", "\n")
            output_ = output_.replace("[UNK]", "")
            output += output_
            yield output

    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            gr.Dropdown(choices=["qgyd2021/lib_service_4chan"], label="model_name"),
            gr.Checkbox(value=True, label="is_chat"),
        ],
        outputs=[output_text_box],
        examples=[
            # Chinese example prompt ("How do you wipe your butt so it comes clean?").
            ["怎样擦屁股才能擦的干净", 512, 0.75, 0.35, 1.2, "qgyd2021/lib_service_4chan", True],
        ],
        cache_examples=False,
        examples_per_page=50,
        title="H Novel Generate",
        description=description,
    )
    demo.queue().launch()
    return


if __name__ == "__main__":
    main()
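
# Note that fn_stream reloads the tokenizer and model on every request, which
# is slow once the demo is under load. A per-process cache is one way to avoid
# that; the following is a minimal sketch assuming functools.lru_cache and a
# hypothetical load_model helper (not part of the original script):
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=2)
#     def load_model(model_name: str):
#         tokenizer = BertTokenizer.from_pretrained(model_name)
#         model = GPT2LMHeadModel.from_pretrained(model_name).eval()
#         return tokenizer, model
#
# fn_stream would then begin with `tokenizer, model = load_model(model_name)`
# instead of calling from_pretrained on every invocation.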