"""Prompt-building and text-generation helpers for the Discord bot.

``build_ppm`` reconstructs a conversation from a Discord reply chain,
``build_prompt`` passes a copy of a chat manager through
``CtxLastWindowStrategy``, and ``sync_task`` runs ``global_vars.model`` on a
prompt and decodes the newly generated tokens.
"""
import re
import copy

import global_vars
from discordbot.utils import (
    get_chat_manager,
    get_global_context
)
from discordbot.flags import (
    parse_known_flags,
    known_flags_def
)
from pingpong import PingPong
from pingpong.context import CtxLastWindowStrategy
from discord import NotFound
from transformers import GenerationConfig

# Flag names (keys of the ``args`` mapping given to ``sync_task``) paired with
# the generation-config attribute each one overrides.
_GEN_CONFIG_OVERRIDES = (
    ("max-new-tokens", "max_new_tokens"),
    ("temperature", "temperature"),
    ("do-sample", "do_sample"),
    ("top-p", "top_p"),
)

# Everything up to and including the first occurrence of this marker is
# discarded from referenced messages before they enter the history.
_PROMPT_MARKER = "💬"


def sync_task(prompt, args):
    """Generate a completion for ``prompt`` and return only the new text.

    This is a synchronous call that tokenizes the prompt, runs
    ``global_vars.model.generate`` and decodes the result.

    Args:
        prompt: Prompt string fed to ``global_vars.tokenizer``.
        args: Mapping of flag name -> value. For each of ``"max-new-tokens"``,
            ``"temperature"``, ``"do-sample"`` and ``"top-p"``, a non-None
            value overrides the matching field on a copy of
            ``global_vars.gen_config``. Missing or None entries keep the
            default.

    Returns:
        The decoded text of the tokens generated after the prompt.
    """
    input_ids = global_vars.tokenizer(
        prompt, return_tensors="pt"
    ).input_ids.to(global_vars.device)

    # Deep-copy so per-request overrides never leak into the shared default.
    gen_config = copy.deepcopy(global_vars.gen_config)
    for flag, attr in _GEN_CONFIG_OVERRIDES:
        value = args.get(flag)
        if value is not None:
            setattr(gen_config, attr, value)

    generated_ids = global_vars.model.generate(
        input_ids=input_ids,
        generation_config=gen_config,
    )
    # Slice off the echoed prompt tokens. This assumes a decoder-only model
    # whose output begins with the input ids (TODO confirm for every
    # configured model type).
    new_tokens = generated_ids[0][input_ids.shape[-1]:]
    return global_vars.tokenizer.decode(new_tokens)


async def build_prompt(ppmanager, ctx_include=True, win_size=3):
    """Apply ``CtxLastWindowStrategy(win_size)`` to a copy of ``ppmanager``.

    The manager is deep-copied first, so the caller's object (including its
    ``ctx``) is left untouched.

    Args:
        ppmanager: Chat manager holding the conversation.
        ctx_include: When true, set the copy's ``ctx`` to
            ``get_global_context(global_vars.model_type)``; otherwise set it
            to an empty string.
        win_size: Argument passed to ``CtxLastWindowStrategy``. It is
            presumably the number of most recent exchanges kept (confirm).

    Returns:
        Whatever the strategy returns for the copy (presumably the prompt
        string).
    """
    dummy_ppm = copy.deepcopy(ppmanager)
    dummy_ppm.ctx = (
        get_global_context(global_vars.model_type) if ctx_include else ""
    )
    return CtxLastWindowStrategy(win_size)(dummy_ppm)


def _clean_referenced_content(content, username, user_id):
    """Remove mentions, drop the text up to the first marker, and strip."""
    content = content.replace(f"@{username} ", "").replace(f"<@{user_id}> ", "")
    _, marker, tail = content.partition(_PROMPT_MARKER)
    # Without a marker the whole (mention-free) text is kept.
    return (tail if marker else content).strip()


async def build_ppm(msg, msg_content, username, user_id):
    """Build a chat manager from ``msg``'s reply chain plus the new message.

    The function follows ``msg.reference`` back through ``msg.channel``. Each
    referenced message is cleaned: ``@username `` / ``<@user_id> `` mentions
    are removed, text up to the first "💬" is dropped, and known flags are
    parsed out. The collected texts, oldest first, are paired into
    ``PingPong`` turns. Finally, ``msg_content`` is appended as a last turn
    with an empty reply.

    The chain is assumed to alternate user prompt / bot reply (TODO confirm).
    If it has an odd length, the oldest message has no partner and is
    skipped, so the most recent exchanges always stay paired.

    Args:
        msg: The triggering Discord message. Its ``channel`` is used to fetch
            the referenced messages.
        msg_content: Text of the new message. It becomes the final ping.
        username: Name whose ``@username `` mention is stripped from history.
        user_id: Id whose ``<@user_id> `` mention is stripped from history.

    Returns:
        The manager from ``get_chat_manager(global_vars.model_type)``,
        populated with the history and the pending turn.
    """
    ppm = get_chat_manager(global_vars.model_type)
    channel = msg.channel

    # Walk the reply chain from newest to oldest. A deleted message
    # (NotFound) or a reference without a message id ends the walk early
    # instead of failing the whole request.
    history = []
    while msg.reference is not None:
        ref_id = msg.reference.message_id
        if ref_id is None:
            break
        try:
            msg = await channel.fetch_message(ref_id)
        except NotFound:
            break
        content = _clean_referenced_content(msg.content, username, user_id)
        content, _ = parse_known_flags(
            content, known_flags_def, global_vars.gen_config
        )
        print(content)
        history.append(content)
    history.reverse()  # oldest first

    # Indexing ``history[idx + 1]`` on an odd-length chain used to raise
    # IndexError. Drop the unpaired oldest entry so pairs align with the
    # newest end.
    if len(history) % 2:
        history = history[1:]
    for ping, pong in zip(history[::2], history[1::2]):
        ppm.add_pingpong(PingPong(ping, pong))

    ppm.add_pingpong(PingPong(msg_content, ""))
    print(ppm.pingpongs)
    return ppm