"""Gradio chat front-end for a ChatGLM base model with LoRA workflow weights.

The script loads a base model plus LoRA adapter, turns a "summary" instruction
and optional "details" text into per-node prompts, and shows the generated
node definitions and their connections in a chat widget.
"""
import json
import os
import re
import sys
from textwrap import indent

import fire
import gradio as gr
import mdtex2html
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import GenerationConfig, AutoModel, AutoTokenizer

# SECURITY: Hugging Face tokens were previously hard-coded at this point.
# Anything committed to source control must be treated as leaked, so those
# tokens should be revoked. Credentials now come from the environment.
access_token_read = (
    os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
)
access_token_write = os.environ.get("HF_TOKEN_WRITE")
if access_token_read:
    login(token=access_token_read)

device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # This torch build does not expose torch.backends.mps.
    pass

with open("node_map.json", encoding="utf-8") as json_file:
    data = json.load(json_file)
node_type_map = data.get('node_type_map') or {}
node_name_map = data.get('node_name_map') or {}

# A node reference: a run of letters optionally followed by one digit.
_NODE_REF = re.compile(r'([a-zA-Z]+)(\d?)')
# "and" as a whole word, so names that merely contain "and" are not split.
_AND_SPLIT = re.compile(r'\band\b')

# NOTE(review): in the reviewed copy every replacement below was an identity
# (e.g. "<" -> "<") and the tag literals were bare newlines, which looks like
# HTML un-escaping/stripping damage. Entities and tags are restored here;
# confirm against the original repository.
_HTML_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})

# NOTE(review): line breaks restored; normalize_input only recognises a
# "#node" header when it is terminated by a newline.
_DETAILS_PLACEHOLDER = (
    "Example:\n"
    "#service1 gets ubtc_trip_in_aidsid key from cookies\n"
    "serviceName: 'userInfoReportService'\n"
    "serviceCode: '18768'\n"
    "method: 'reportOrderAttribution'\n"
    "#service2 transfers all of cookies\n"
    "serviceName: 'userInfoReportService'\n"
    "serviceCode: '18768'\n"
)


def postprocess(self, y):
    """Convert chat history pairs into HTML for ``gr.Chatbot``.

    String messages/responses go through ``mdtex2html.convert``; any other
    response is dumped as sorted, indented JSON and passed to ``format_json``.

    A new list is returned instead of mutating ``y``: ``evaluate`` returns the
    same list object as chatbot value and as state, so in-place conversion
    would re-convert already converted entries on every later turn.
    """
    if y is None:
        return []
    rendered = []
    for message, response in y:
        if response is None:
            html_response = None
        elif isinstance(response, str):
            html_response = mdtex2html.convert(response)
        else:
            html_response = format_json(
                json.dumps(response, indent=4, sort_keys=True))
        rendered.append((
            None if message is None else mdtex2html.convert(message),
            html_response,
        ))
    return rendered


def parse_text(text):
    """Turn plain text with ``` fences into an HTML fragment.

    Empty lines are dropped. An odd-numbered fence line opens a
    ``<pre><code>`` block (class taken from the text after the last
    backtick), an even-numbered one closes it. Every other line except the
    very first is HTML-escaped and prefixed with ``<br>``; the first line is
    emitted unchanged.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split('`')[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = '<br></code></pre>'
        elif i > 0:
            lines[i] = "<br>" + line.translate(_HTML_ESCAPES)
    return "".join(lines)


def format_json(json_data):
    """Post-process a pretty-printed JSON string for display.

    Keys lose their surrounding double quotes; values are substituted with
    themselves, and the result is wrapped in newlines.

    NOTE(review): the replacement templates and the wrapper look like they
    once carried HTML markup (the app loads ``style.css``) that is missing
    from the reviewed copy. The visible behaviour is kept as-is; restore the
    markup from the original repository if available.
    """
    key_pattern = re.compile(r'"(.*)"(?=:)')
    value_pattern = re.compile(r'(?<=: )("(.*)"|\d+)')
    without_key_quotes = key_pattern.sub(lambda m: m.group(1), json_data)
    formatted = value_pattern.sub(lambda m: m.group(1), without_key_quotes)
    return f'\n{formatted}\n'


def _resolve_node_name(text):
    """Return the mapped node name plus version digit for *text*, or None.

    None is returned when *text* has no letters or when its name is not a key
    of ``node_name_map``.
    """
    match = _NODE_REF.search(text.strip())
    if not match:
        return None
    base_name = node_name_map.get(match.group(1))
    if base_name is None:
        return None
    return base_name + match.group(2)


def parse_parallel_str(instruction, nodes, connections, type):
    """Build connection entries for one side of an ``output`` clause.

    Args:
        instruction: text naming one node, or several joined by the word
            "and".
        nodes: list of node dicts; a merge node is appended to it when
            several inputs are joined and no node of that name exists yet.
        connections: connections collected so far; existing entries are
            reused (same object) so later appends extend them.
        type: ``'input'`` or ``'output'`` (name kept for compatibility even
            though it shadows the builtin).

    Returns:
        dict mapping node names to their connection entry. References whose
        name is unknown to ``node_name_map`` are skipped.
    """
    result = {}
    branches = _AND_SPLIT.split(instruction)

    if len(branches) == 1:
        node_name = _resolve_node_name(instruction)
        if node_name is not None:
            if node_name in connections:
                result[node_name] = connections.get(node_name)
            else:
                result[node_name] = {"type": "main", "index": 0, "main": [[]]}
        return result

    merge_node = "Merge-Node-" + '-'.join(b.strip() for b in branches)
    is_input = type == 'input'
    for i, branch in enumerate(branches):
        node_name = _resolve_node_name(branch)
        if node_name is None:
            continue
        targets = (
            [{"node": merge_node, "type": "main", "index": i}]
            if is_input else []
        )
        result[node_name] = {"main": [targets]}

    if is_input:
        if merge_node in connections:
            result[merge_node] = connections.get(merge_node)
        else:
            result[merge_node] = {"main": [[]]}
        # Register the merge node once, even if the same group of inputs
        # appears in several clauses.
        if not any(node.get("name") == merge_node for node in nodes):
            nodes.append({
                "node": "BFF-Merge",
                "name": merge_node,
                "param": "mode: chooseBranch\noutput: empty",
            })
    return result


def parse_serial_str(input_nodes, output_nodes):
    """Link every source in *input_nodes* to every node in *output_nodes*.

    With exactly one input entry, that entry is the source. Otherwise only
    entries whose key starts with ``Merge-Node`` act as sources. Links are
    appended to the first list under each source's ``"main"`` key.

    Returns:
        *input_nodes* (mutated in place), or ``{}`` if the structures are
        malformed, so the caller's ``connections.update(...)`` never receives
        ``None``.
    """
    try:
        merge_only = len(input_nodes) != 1
        for input_key, input_value in input_nodes.items():
            if merge_only and not input_key.startswith("Merge-Node"):
                continue
            for output_key in output_nodes:
                input_value.get("main")[0].append({
                    "index": 0,
                    "node": output_key,
                    "type": "main",
                })
        return input_nodes
    except (AttributeError, TypeError, IndexError) as err:
        print('[parse_serial_str] skipped malformed connection:', err)
        return {}


def normalize_input(instruction, input):
    """Extract node definitions and connections from the user's text.

    Args:
        instruction: summary text; comma-separated clauses containing the
            word "output" describe connections (left side feeds right side).
        input: details text (may be None) with ``#<node><digit> ...`` header
            lines, each followed by that node's parameters.

    Returns:
        ``(nodes, connections)``: a list of dicts with ``node``, ``name`` and
        ``param`` keys, and a dict of connection entries.
    """
    nodes = []
    connections = {}
    text = "" if input is None else str(input)

    # get nodes
    match_node_pattern = rf'({data.get("node_name")})(\d)'
    node_list = [
        "".join(groups) for groups in re.findall(match_node_pattern, text)
    ]
    for node in node_list:
        header = re.search(rf"#{re.escape(node)}(.*)\n", text)
        start_index = header.end() if header else 0

        match_node = _NODE_REF.search(node.strip())
        if not match_node:
            continue
        node_key, node_version = match_node.group(1), match_node.group(2)
        if node_key not in node_type_map or node_key not in node_name_map:
            continue

        next_index = node_list.index(node) + 1
        if next_index == len(node_list):
            # Last node: take everything to the end. Only a trailing newline
            # is dropped, never a real final character.
            section = text[start_index:]
            if section.endswith("\n"):
                section = section[:-1]
        else:
            section = text[start_index:text.find(f"#{node_list[next_index]}")]

        pieces = re.split(rf"{match_node_pattern}.*?\n", section)
        param = next((p for p in pieces if p and p.strip()), "")
        nodes.append({
            "node": f"{node_type_map[node_key]}",
            "name": f"{node_name_map[node_key]}{node_version}",
            "param": param,
        })

    # get connections
    for clause in instruction.split(','):
        if "output" not in clause:
            continue
        split_input, split_output = clause.split("output", 1)
        merge_input = parse_parallel_str(
            split_input, nodes, connections, 'input')
        merge_output = parse_parallel_str(
            split_output, nodes, connections, 'output')
        connections.update(parse_serial_str(merge_input, merge_output))
    return nodes, connections


def _parse_node_output(output):
    """Parse the JSON object in a model answer, or return an error marker."""
    error = [{"error": "errorFormat"}]
    if not output:
        return error
    try:
        return json.loads(output[output.find("{"):])
    except json.JSONDecodeError:
        return error


def main(
    base_model: str = "THUDM/chatglm-6b",
    lora_weights: str = "JIAFENG7/BFF-workflow-glm",
    share_gradio: bool = True,
):
    """Load the model and launch the Gradio chat interface.

    Args:
        base_model: Hugging Face id or path of the base model.
        lora_weights: Hugging Face id or path of the LoRA adapter.
        share_gradio: passed to ``launch(share=...)``.

    Raises:
        ValueError: if *base_model* is empty.
    """
    if not base_model:
        raise ValueError(
            "Please specify a --base_model, e.g. "
            "--base_model='THUDM/chatglm-6b'"
        )

    tokenizer = AutoTokenizer.from_pretrained(
        base_model, trust_remote_code=True)
    if device == "cuda":
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = AutoModel.from_pretrained(
            base_model, trust_remote_code=True).half().cuda()
    else:
        # Move the model to the selected device so it matches the inputs
        # that predict() sends there (a no-op on CPU).
        model = AutoModel.from_pretrained(
            base_model, trust_remote_code=True).float().to(device)
    model = PeftModel.from_pretrained(
        model, lora_weights, torch_dtype=torch.float16)
    model.eval()

    gr.Chatbot.postprocess = postprocess

    def predict(
        instruction,
        input_content,
        temperature,
        top_p,
        top_k,
        max_new_tokens,
    ):
        """Generate the text that follows ``Answer:`` for one prompt."""
        input_text = f"Instruction: {instruction}\n"
        if input_content is not None:
            input_text += f"Input: {input_content}\n"
        input_text += "Answer: "
        print('---', input_text)
        ids = tokenizer.encode(input_text)
        inputs = torch.LongTensor([ids]).to(device)
        output = model.generate(
            input_ids=inputs,
            # Limit newly generated tokens; max_length would also count
            # the prompt against the budget.
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        decoded = tokenizer.decode(output[0])
        parts = decoded.split("Answer:")
        # Fall back to the whole text if the marker is missing.
        return parts[1] if len(parts) > 1 else decoded

    def evaluate(
        instruction,
        input_content=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        max_new_tokens=256,
        history=None,
    ):
        """Answer one request and append the exchange to *history*.

        Without recognised nodes the instruction is answered directly;
        otherwise each node is generated separately and the response is a
        dict with ``nodes`` and ``connections``.

        Returns:
            ``(history, history)`` for the chatbot and state outputs.
        """
        if history is None:
            history = []
        with torch.autocast(device):
            merged_nodes, connections = normalize_input(
                instruction, input_content)
            nodes = []
            if not merged_nodes:
                output = predict(
                    instruction,
                    input_content,
                    temperature,
                    top_p,
                    top_k,
                    max_new_tokens,
                )
                print('[normal output]:', output)
            else:
                for node_data in merged_nodes:
                    merged_instruction = node_data["node"]
                    merged_input = (
                        node_data["param"] + "\n"
                        + f"name: \"{node_data['name']}\""
                    )
                    output = predict(
                        merged_instruction,
                        merged_input,
                        temperature,
                        top_p,
                        top_k,
                        max_new_tokens,
                    )
                    print('[node output]:', output)
                    nodes.append(_parse_node_output(output))
                print('[nodes output]:', nodes)

        details = (
            indent(input_content, ' ') if input_content is not None else ''
        )
        request = (
            "Summary: \n"
            + f"{indent(instruction, ' ')} \n"
            + "Details: \n"
            + details
        )
        if merged_nodes:
            response = {'nodes': nodes, 'connections': connections}
        else:
            response = output
        history.append((parse_text(request), response))
        return history, history

    def chat(instruction, input_content, history):
        """Adapter for Gradio: route the state input to ``history``.

        Inputs are passed positionally, so without this adapter the third
        input (the state) would land in ``temperature``.
        """
        return evaluate(instruction, input_content, history=history)

    gr.Interface(
        fn=chat,
        inputs=[
            gr.components.Textbox(
                lines=2,
                label="Summary",
                placeholder="Tell me the task you want to do with bff.",
            ),
            gr.components.Textbox(
                lines=2,
                label="Details",
                placeholder=_DETAILS_PLACEHOLDER,
            ),
            'state',
            # Sampling controls are disabled; evaluate()'s defaults apply.
            # gr.components.Slider(
            #     minimum=0, maximum=1, value=0.1, label="Temperature",
            #     info="Controls randomness, higher values increase diversity."
            # ),
            # gr.components.Slider(
            #     minimum=0, maximum=1, value=0.75, label="Top p",
            #     info="The cumulative probability cutoff for token selection. Lower values mean sampling from a smaller, more top-weighted nucleus."
            # ),
            # gr.components.Slider(
            #     minimum=0, maximum=100, step=1, value=40, label="Top k",
            #     info="Sample from the k most likely next tokens at each step. Lower k focuses on higher probability tokens."
            # ),
            # gr.components.Slider(
            #     minimum=1, maximum=4, step=1, value=4, label="Beams"
            # ),
            # gr.components.Slider(
            #     minimum=1, maximum=2000, step=1, value=512, label="Max tokens"
            # )
        ],
        # outputs=['json'],
        outputs=[
            gr.Chatbot(),
            'state',
        ],
        examples=[
            ["How many nodes in tripflow?"],
            ["How many parameters in Cargo?"],
            ["How many parameters in shark?"],
        ],
        title="Workflow BFF Chat",
        description="""
The bot was trained to answer questions based on tripflow. Ask anything!

""",
        css="style.css",
    ).launch(share=share_gradio)

    # testing code for readme
    # for instruction in [
    #     "What is the n8n",
    #     "Tell me about the president of Mexico in 2019.",
    #     "Tell me about the king of France in 2019.",
    #     "List all Canadian provinces in alphabetical order.",
    #     "Write a Python program that prints the first 10 Fibonacci numbers.",
    #     "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",  # noqa: E501
    #     "Tell me five words that rhyme with 'shock'.",
    #     "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
    #     "Count up from 1 to 500.",
    # ]:
    #     print("Instruction:", instruction)
    #     print("Response:", evaluate(instruction))
    #     print()


if __name__ == "__main__":
    fire.Fire(main)