"""Gradio demo: generate text with a causal language model and score its toxicity.

The user picks (or searches for) a text-generation model, writes or samples a
prompt, inspects the generated continuation, optionally runs Detoxify on the
input and output, and can flag an output to a Hugging Face dataset.
"""

import os
from functools import lru_cache
from random import sample

import torch
import numpy as np
import gradio as gr
from detoxify import Detoxify
from datasets import load_dataset
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM

# Token used by the flagging callback; None when the variable is not set.
HF_AUTH_TOKEN = os.environ.get('hf_token')

DATASET = "allenai/real-toxicity-prompts"

# Radio label reserved for a user-chosen checkpoint.
CUSTOM_MODEL = "Custom Model"

CHECKPOINTS = {
    "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
    "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
    CUSTOM_MODEL: None,
}

MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
    CUSTOM_MODEL: (AutoModelForCausalLM, AutoTokenizer),
}

MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop

# Number of prompts offered in the inspiration dropdown.
PROMPT_SAMPLE_SIZE = 10


def load_model(model_name, custom_model_path):
    """Load the model and tokenizer for a radio label or a custom checkpoint.

    Args:
        model_name: A key of ``CHECKPOINTS``, or any other value (for example
            a hub path), in which case the custom-model classes are used.
        custom_model_path: Checkpoint path used when ``model_name`` has no
            checkpoint of its own.

    Returns:
        A ``(model, tokenizer)`` pair with the model in eval mode.

    Raises:
        ValueError: If neither ``model_name`` nor ``custom_model_path``
            yields a checkpoint path.
    """
    model_class, tokenizer_class = MODEL_CLASSES.get(
        model_name, MODEL_CLASSES[CUSTOM_MODEL]
    )
    # "Custom Model" maps to None in CHECKPOINTS, and unknown names are absent;
    # both must fall back to the user-supplied path.
    model_path = CHECKPOINTS.get(model_name) or custom_model_path
    if not model_path:
        raise ValueError(
            "No model checkpoint selected: choose a model or pick a custom one."
        )

    model = model_class.from_pretrained(model_path)
    tokenizer = tokenizer_class.from_pretrained(model_path)

    # Reuse EOS as the padding token for both the tokenizer and the model.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    model.eval()
    return model, tokenizer


def set_seed(seed, n_gpu):
    """Seed NumPy and torch (and every CUDA device when ``n_gpu > 0``)."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def adjust_length_to_model(length, max_sequence_length):
    """Clamp a requested generation length to the model's maximum.

    A negative ``length`` means "no explicit limit": it becomes
    ``max_sequence_length`` when that is positive, otherwise ``MAX_LENGTH``.
    A ``length`` above a positive ``max_sequence_length`` is reduced to it.
    """
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length


def generate(model_name,
             custom_model_path,
             input_sentence,
             length=75,
             temperature=0.7,
             top_k=50,
             top_p=0.95,
             seed=42,
             no_cuda=False,
             num_return_sequences=1,
             stop_token='.'):
    """Sample a continuation of ``input_sentence`` and return the first one.

    The model is loaded on every call, moved to CUDA when available (unless
    ``no_cuda``), and sampled with the given temperature / top-k / top-p
    settings for up to ``length`` tokens beyond the prompt.

    Returns:
        The first generated sequence with the prompt removed and truncated
        after the last occurrence of ``stop_token``. When ``stop_token`` is
        empty or does not occur, the continuation is returned untruncated.
    """
    use_cuda = torch.cuda.is_available() and not no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    n_gpu = 0 if no_cuda else torch.cuda.device_count()

    set_seed(seed, n_gpu)

    model, tokenizer = load_model(model_name, custom_model_path)
    model.to(device)

    encoded_prompt = tokenizer.encode(
        input_sentence, add_special_tokens=False, return_tensors='pt'
    )
    encoded_prompt = encoded_prompt.to(device)

    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=length + len(encoded_prompt[0]),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )

    # The decoded prompt is the same for every sequence; compute its length once.
    prompt_length = len(
        tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
    )

    generated_sequences = []
    for generated_sequence in output_sequences:
        text = tokenizer.decode(
            generated_sequence.tolist(), clean_up_tokenization_spaces=True
        )
        # Remove the prompt.
        text = text[prompt_length:]
        # Remove all text after the last occurrence of stop_token. rfind
        # returns -1 when it is absent; slicing with that would wipe the text.
        if stop_token:
            cut = text.rfind(stop_token)
            if cut != -1:
                text = text[:cut + len(stop_token)]
        generated_sequences.append(text)

    return generated_sequences[0]


def prepare_dataset(dataset):
    """Load the ``train`` split of the named dataset."""
    return load_dataset(dataset, split='train')


def load_prompts(dataset):
    """Return the ``['prompt']['text']`` field of every row in ``dataset``."""
    return [row['prompt']['text'] for row in dataset]


def random_sample(prompt_list, k=PROMPT_SAMPLE_SIZE):
    """Return up to ``k`` prompts drawn without replacement."""
    # sample() raises ValueError when asked for more items than exist.
    return sample(prompt_list, min(k, len(prompt_list)))


def show_dataset(dataset):
    """Load the prompts dataset and reveal the dropdown with a random subset.

    Returns:
        Updates for the prompt dropdown and the randomize button, plus the
        full prompt list (stored so it can be re-sampled later).
    """
    raw_data = prepare_dataset(dataset)
    prompts = load_prompts(raw_data)

    return (
        gr.update(
            choices=random_sample(prompts),
            label='You can find below a random subset from the RealToxicityPrompts dataset',
            visible=True,
        ),
        gr.update(visible=True),
        prompts,
    )


def update_dropdown(prompts):
    """Replace the dropdown choices with a fresh random subset of prompts."""
    return gr.update(choices=random_sample(prompts))


def show_search_bar(value):
    """Forward the radio value and show the search bar only for custom models."""
    return (value, gr.update(visible=(value == CUSTOM_MODEL)))


def search_model(model_name):
    """Search the hub for PyTorch text-generation models matching ``model_name``.

    Returns:
        An update that shows the model dropdown filled with matching model ids.
    """
    api = HfApi()

    model_args = ModelSearchArguments()
    filt = ModelFilter(
        task=model_args.pipeline_tag.TextGeneration,
        library=model_args.library.PyTorch,
    )

    results = api.list_models(filter=filt, search=model_name)
    model_list = [model.modelId for model in results]
    return gr.update(
        visible=True,
        choices=model_list,
        label='Choose the model',
    )


def forward_model_choice(model_choice_path):
    """Store the chosen hub path as both the model choice and the custom path."""
    return (model_choice_path, model_choice_path)


def auto_complete(input, generated):
    """Join prompt and continuation and mark the continuation as ``OUTPUT``.

    Returns:
        A dict with the combined ``text`` and a single entity span covering
        everything after the prompt.
    """
    output = input + ' ' + generated
    output_spans = [{'entity': 'OUTPUT', 'start': len(input), 'end': len(output)}]
    return {"text": output, "entities": output_spans}


def process_user_input(model, custom_model_path, input):
    """Generate a continuation for the prompt and prepare the UI updates.

    Returns:
        A 5-tuple: highlighted text, visibility updates for the toxicity and
        flag buttons, the prompt, and the raw generated text. When the prompt
        is None, empty or whitespace-only, a warning is shown instead, both
        buttons are hidden and the generated text is None.
    """
    warning = 'Please enter a valid prompt.'
    if input is None or not input.strip():
        # Nothing to generate from: show the warning without calling the model
        # (concatenating None with the warning would raise TypeError).
        return (
            {"text": warning, "entities": []},
            gr.update(visible=False),
            gr.update(visible=False),
            input,
            None,
        )

    generated = generate(model, custom_model_path, input)
    generated_with_spans = auto_complete(input, generated)

    return (
        generated_with_spans,
        gr.update(visible=True),
        gr.update(visible=True),
        input,
        generated,
    )


def pass_to_textbox(input):
    """Copy the selected dropdown prompt into the textbox."""
    return gr.update(value=input)


@lru_cache(maxsize=1)
def _get_detoxify_model():
    """Build the Detoxify 'original' model once and reuse it across calls."""
    return Detoxify('original')


def run_detoxify(text):
    """Score ``text`` with Detoxify and return ``{category: float score}``."""
    results = _get_detoxify_model().predict(text)
    return {cat: float(score) for cat, score in results.items()}


def compute_toxi_output(output_text):
    """Score the model output and reveal the scores and the compare button."""
    scores = run_detoxify(output_text)
    return (
        gr.update(value=scores, visible=True),
        gr.update(visible=True),
    )


def compute_change(input, output):
    """Return the percentage change from ``input`` to ``output`` (2 decimals).

    A zero baseline has no defined relative change: 0.0 is returned when both
    values are zero and None otherwise.
    """
    output = float(output)
    if input == 0:
        return 0.0 if output == 0 else None
    return round(((output - input) / input) * 100, 2)


def compare_toxi_scores(input_text, output_scores):
    """Score the prompt and compare it with the already computed output scores.

    Returns:
        Updates holding the prompt's scores and, for every category present
        in both score sets, the percentage change from input to output.
    """
    input_scores = run_detoxify(input_text)
    output_scores = output_scores or {}

    compare_scores = {
        cat: compute_change(score, output_scores[cat])
        for cat, score in input_scores.items()
        if cat in output_scores
    }

    return (
        gr.update(value=input_scores, visible=True),
        gr.update(value=compare_scores, visible=True),
    )


def show_flag_choices():
    """Reveal the flag reason choices."""
    return gr.update(visible=True)


def update_flag(flag_value):
    """Store the flag reason, reveal confirm/comment and hide the flag button."""
    return (
        flag_value,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
    )


def upload_flag(*args):
    """Send the flagged data point and show the success message if it worked."""
    flagged = flagging_callback.flag(list(args), flag_option=None)
    return gr.update(visible=bool(flagged))


with gr.Blocks() as demo:
    gr.Markdown("# Project Interface proposal")
    gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")

    # Hidden state shared between callbacks.
    dataset = gr.Variable(value=DATASET)
    prompts_var = gr.Variable(value=None)
    input_var = gr.Variable(label="Input Prompt", value=None)
    output_var = gr.Variable(label="Output", value=None)
    model_choice = gr.Variable(label="Model", value=None)
    custom_model_path = gr.Variable(value=None)
    flag_choice = gr.Variable(label="Flag", value=None)

    flagging_callback = gr.HuggingFaceDatasetSaver(
        hf_token=HF_AUTH_TOKEN,
        dataset_name="fsdlredteam/flagged_2",
        organization="fsdlredteam",
        private=True,
    )

    # NOTE(review): the Row/Column nesting below was reconstructed from the
    # order of the statements; confirm it matches the intended layout.
    with gr.Row():
        with gr.Column(scale=1):  # input & prompts dataset exploration
            gr.Markdown("### 1. Select a prompt")

            input_text = gr.Textbox(label="Write your prompt below.", interactive=True, lines=4)
            gr.Markdown("— or —")
            inspo_button = gr.Button('Click here if you need some inspiration')

            prompts_drop = gr.Dropdown(visible=False)
            randomize_button = gr.Button('Show another subset', visible=False)

        with gr.Column(scale=1):  # Model choice & output
            gr.Markdown("### 2. Evaluate output")

            model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()), label='Model', interactive=True)
            search_bar = gr.Textbox(label="Search model", interactive=True, visible=False)
            model_drop = gr.Dropdown(visible=False)

            generate_button = gr.Button('Submit your prompt')

            output_spans = gr.HighlightedText(visible=True, label="Generated text")

            flag_button = gr.Button("Report output here", visible=False)

    with gr.Row():  # Flagging
        with gr.Column(scale=1):
            flag_radio = gr.Radio(
                choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other"],
                label="What's wrong with the output ?",
                interactive=True,
                visible=False,
            )
            user_comment = gr.Textbox(label="(Optional) Briefly describe the issue", visible=False, interactive=True)
            confirm_flag_button = gr.Button("Confirm report", visible=False)

    with gr.Row():  # Flagging success
        success_message = gr.Markdown("Your report has been successfully registered. Thank you!", visible=False)

    with gr.Row():  # Toxicity buttons
        toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
        toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False)

    with gr.Row():  # Toxicity scores
        toxi_scores_input = gr.JSON(label="Detoxify classification of your input", visible=False)
        toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output", visible=False)
        toxi_scores_compare = gr.JSON(label="Percentage change between Input and Output", visible=False)

    # Prompt inspiration.
    inspo_button.click(fn=show_dataset,
                       inputs=dataset,
                       outputs=[prompts_drop, randomize_button, prompts_var])
    prompts_drop.change(fn=pass_to_textbox,
                        inputs=prompts_drop,
                        outputs=input_text)
    randomize_button.click(fn=update_dropdown,
                           inputs=prompts_var,
                           outputs=prompts_drop)

    # Model selection.
    model_radio.change(fn=show_search_bar,
                       inputs=model_radio,
                       outputs=[model_choice, search_bar])
    search_bar.submit(fn=search_model,
                      inputs=search_bar,
                      outputs=model_drop,
                      show_progress=True)
    model_drop.change(fn=forward_model_choice,
                      inputs=model_drop,
                      outputs=[model_choice, custom_model_path])

    # Generation and toxicity analysis.
    generate_button.click(fn=process_user_input,
                          inputs=[model_choice, custom_model_path, input_text],
                          outputs=[output_spans, toxi_button, flag_button, input_var, output_var],
                          show_progress=True)
    toxi_button.click(fn=compute_toxi_output,
                      inputs=output_var,
                      outputs=[toxi_scores_output, toxi_button_compare],
                      show_progress=True)
    toxi_button_compare.click(fn=compare_toxi_scores,
                              inputs=[input_text, toxi_scores_output],
                              outputs=[toxi_scores_input, toxi_scores_compare],
                              show_progress=True)

    # Flagging.
    flag_button.click(fn=show_flag_choices,
                      inputs=None,
                      outputs=flag_radio)
    flag_radio.change(fn=update_flag,
                      inputs=flag_radio,
                      outputs=[flag_choice, confirm_flag_button, user_comment, flag_button])
    flagging_callback.setup([input_var, output_var, model_choice, user_comment, flag_choice],
                            "flagged_data_points")
    confirm_flag_button.click(fn=upload_flag,
                              inputs=[input_var, output_var, model_choice, user_comment, flag_choice],
                              outputs=success_message)

if __name__ == "__main__":
    demo.launch(enable_queue=False)