Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import logging | |
| import os | |
| import platform | |
| import re | |
| import string | |
| from typing import List, Tuple | |
| from project_settings import project_path | |
# Keep the Hugging Face Hub cache inside the project directory. This must run
# before gradio / transformers are imported further down: huggingface_hub
# resolves its cache location from environment variables at import time.
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

# INFO level on every platform (the previous per-platform switch picked INFO
# in both branches).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
| import dingtalk_stream | |
| from dingtalk_stream import AckMessage | |
| import gradio as gr | |
| from threading import Thread | |
| import torch | |
| from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel | |
| from transformers.models.bert.tokenization_bert import BertTokenizer | |
| from project_settings import environment | |
def get_args():
    """Parse the command-line options of the bot.

    Defaults:
      * ``--client_id`` / ``--client_secret``: looked up in ``environment``.
      * ``--model_name``: a local checkpoint directory on Windows, otherwise
        the Hugging Face Hub id ``qgyd2021/lip_service_4chan``.
      * ``--dingtalk_develop_md_file``: ``dingtalk_develop.md``.

    :return: the parsed ``argparse.Namespace``.
    """
    running_on_windows = platform.system() == "Windows"
    if running_on_windows:
        # NOTE(review): the local directory is spelled "lib_service_4chan"
        # while the Hub id is "lip_service_4chan" — confirm the local path.
        model_default = (project_path / "trained_models/lib_service_4chan").as_posix()
    else:
        model_default = "qgyd2021/lip_service_4chan"

    option_defaults = (
        ("--client_id", environment.get("client_id")),
        ("--client_secret", environment.get("client_secret")),
        ("--model_name", model_default),
        ("--dingtalk_develop_md_file", "dingtalk_develop.md"),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        parser.add_argument(flag, default=default_value, type=str)
    return parser.parse_args()
class LipService4ChanHandler(dingtalk_stream.ChatbotHandler):
    """DingTalk chatbot handler that answers messages with a GPT-2 text generator.

    A ``BertTokenizer`` and a ``GPT2LMHeadModel`` are loaded once from
    ``model_name``. Each incoming message text is used as the prompt, and the
    sampled continuation is sent back as a text reply.
    """

    # The text built so far ends with an ASCII letter, digit or punctuation mark.
    # re.escape keeps characters such as "]", "\\", "^" and "-" literal inside the class.
    _ASCII_TAIL = re.compile(f"[a-zA-Z0-9{re.escape(string.punctuation)}]$")
    # The next piece starts with an ASCII letter or digit (match() anchors at the start).
    _ASCII_HEAD = re.compile("[a-zA-Z0-9]")

    def __init__(
        self,
        model_name: str = "qgyd2021/lip_service_4chan",
        max_input_len: int = 512,
        max_new_tokens: int = 512,
        top_p: float = 0.9,
        temperature: float = 0.35,
        repetition_penalty: float = 1.0,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Load the tokenizer and model and store the generation settings.

        :param model_name: local directory or Hub id passed to ``from_pretrained``.
        :param max_input_len: maximum prompt length in tokens; longer prompts
            keep only their last ``max_input_len`` tokens.
        :param max_new_tokens: maximum number of tokens generated per reply.
        :param top_p: nucleus-sampling threshold forwarded to ``generate``.
        :param temperature: sampling temperature forwarded to ``generate``.
        :param repetition_penalty: repetition penalty forwarded to ``generate``.
        :param device: torch device the model and prompt tensors are placed on.
        """
        super().__init__()
        self.model_name = model_name
        self.max_input_len = max_input_len
        self.max_new_tokens = max_new_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty
        self.device = device

        tokenizer = BertTokenizer.from_pretrained(model_name)
        # Generation stops at [SEP]; expose it as the EOS token once here
        # (get_answer strips it from the decoded text) instead of re-assigning
        # it on every request.
        tokenizer.eos_token = tokenizer.sep_token
        tokenizer.eos_token_id = tokenizer.sep_token_id

        model = GPT2LMHeadModel.from_pretrained(model_name)
        # Previously `device` was stored but never applied, so the model always ran on CPU.
        model = model.to(device).eval()

        self.model = model
        self.tokenizer = tokenizer

    async def process(self, callback: dingtalk_stream.CallbackMessage):
        """Answer one incoming chatbot message and acknowledge the callback.

        NOTE(review): ``get_answer`` runs model generation synchronously, so
        the event loop executing this coroutine is blocked until it finishes.

        :return: ``(AckMessage.STATUS_OK, "OK")``.
        """
        incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
        text = incoming_message.text.content.strip()
        logger.info("incoming message: {};".format(text))

        answer = self.get_answer(text)
        self.reply_text(answer, incoming_message)
        logger.info("incoming message: {}; reply text: {};".format(text, answer))
        return AckMessage.STATUS_OK, "OK"

    @staticmethod
    def remove_space_between_cn_en(text: str) -> str:
        """Collapse the spaces between pieces of decoded text, keeping them between ASCII words.

        ``text`` is split on single spaces and the non-empty pieces are
        re-joined. One space is re-inserted only when the text built so far
        ends with an ASCII letter, digit or punctuation mark AND the next piece
        starts with an ASCII letter or digit. For example,
        ``"你 好 ， hello world"`` becomes ``"你好，hello world"``.

        Runs of spaces collapse and leading spaces are dropped. A single
        trailing space is kept if ``text`` ended with a space. Text without any
        space is returned unchanged.

        (Declared static because it uses no instance state. It was previously
        defined without ``self`` yet called as ``self.remove_space_between_cn_en``,
        which raised TypeError.)
        """
        pieces = text.split(" ")
        if len(pieces) < 2:
            return text

        tail_is_ascii = LipService4ChanHandler._ASCII_TAIL.search
        head_is_ascii = LipService4ChanHandler._ASCII_HEAD.match

        result = ""
        for piece in pieces:
            if piece == "":
                continue
            if tail_is_ascii(result) and head_is_ascii(piece):
                result += " "
            result += piece

        if text.endswith(" "):
            result += " "
        return result

    def get_answer(self, text: str) -> str:
        """Generate a reply to ``text`` by sampling from the language model.

        :param text: the user's message, used as the generation prompt.
        :return: the decoded continuation, with the EOS ([SEP]) token removed
            and inter-token spaces collapsed by ``remove_space_between_cn_en``.
        """
        prompt_encoded = self.tokenizer(text, add_special_tokens=True)
        input_ids = torch.tensor([prompt_encoded["input_ids"]], dtype=torch.long)
        # Over-long prompts keep only their most recent tokens.
        input_ids = input_ids[:, -self.max_input_len:].to(self.device)
        prompt_len = input_ids.shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=self.max_new_tokens,
                do_sample=True,
                top_p=self.top_p,
                temperature=self.temperature,
                repetition_penalty=self.repetition_penalty,
                eos_token_id=self.tokenizer.sep_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        new_token_ids = outputs[0, prompt_len:].tolist()
        answer = self.tokenizer.decode(new_token_ids)
        answer = answer.strip().replace(self.tokenizer.eos_token, "").strip()
        return self.remove_space_between_cn_en(answer)
def dingtalk_server(client: dingtalk_stream.DingTalkStreamClient):
    """Background-thread target: run the DingTalk stream client's main loop.

    :param client: the configured stream client to run.
    """
    run_forever = client.start_forever
    run_forever()
def main():
    """Start the DingTalk chatbot in a background thread and serve a Gradio info page.

    The markdown page is read first and the chatbot handler (which loads the
    model) is built next. The DingTalk stream client is then run on a
    background thread, and the Gradio app is launched on the main thread.
    """
    args = get_args()

    # Read the page content before starting the (non-daemon) bot thread, so a
    # missing or unreadable file fails fast instead of raising while that
    # thread keeps the process alive.
    with open(args.dingtalk_develop_md_file, "r", encoding="utf-8") as f:
        dingtalk_develop_md = f.read()

    # DingTalk stream client with the model-backed chatbot handler.
    credential = dingtalk_stream.Credential(
        client_id=args.client_id,
        client_secret=args.client_secret,
    )
    client = dingtalk_stream.DingTalkStreamClient(credential, logger)
    client.register_callback_handler(
        dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
        LipService4ChanHandler(model_name=args.model_name),
    )

    # start_forever() presumably blocks, hence running it in the background
    # while Gradio owns the main thread.
    thread = Thread(
        target=dingtalk_server,
        kwargs={"client": client},
        name="dingtalk_server",
    )
    thread.start()

    # Gradio UI: a static page rendering the markdown file.
    with gr.Blocks() as blocks:
        gr.Markdown(value=dingtalk_develop_md)

    on_windows = platform.system() == "Windows"
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if on_windows else "0.0.0.0",
        server_port=7860,
    )


if __name__ == '__main__':
    main()