import streamlit as st from transformers import AutoModelForCausalLM, AutoTokenizer, STOKEStreamer from threading import Thread import json import torch import matplotlib.pyplot as plt from matplotlib.colors import to_hex import numpy as np import os import urllib.request import zipfile class MLP(torch.nn.Module): def __init__(self, input_dim, output_dim, hidden_dim=1024, layer_id=0, cuda=False): super(MLP, self).__init__() self.fc1 = torch.nn.Linear(input_dim, hidden_dim) # Input layer to hidden layer self.fc3 = torch.nn.Linear(hidden_dim, output_dim) # Hidden layer to output layer self.layer_id = layer_id if cuda: self.device = "cuda" else: self.device = "cpu" self.to(self.device) def forward(self, x): x = torch.flatten(x, start_dim=1) x = torch.relu(self.fc1(x)) x = self.fc3(x) return torch.argmax(x, dim=-1).cpu().detach(), torch.softmax(x, dim=-1).cpu().detach() def map_value_to_color(value, colormap_name='tab20c'): """ Map a value between 0 and 1 to a CSS color using a Python colormap. Args: value (float): A value between 0 and 1. colormap_name (str): The name of the colormap to use (e.g., 'viridis'). Returns: str: A CSS color string in the form 'rgb(r, g, b)'. """ # Ensure the value is within the range [0, 1] value = np.clip(value, 0.0, 1.0) # Get the colormap colormap = plt.get_cmap(colormap_name) # Map the value to a color rgba_color = colormap(value) # Convert the RGBA color to CSS format css_color = to_hex(rgba_color) return css_color + "88" @st.cache_resource def get_model_and_tokenizer(name): # Load pre-trained model and tokenizer tok = AutoTokenizer.from_pretrained(name) model = AutoModelForCausalLM.from_pretrained(name) return model, tok @st.cache_resource def get_classifiers_for_model(att_size, emb_size, device, config_paths): classifier_token = None #print(config) config = { "classifier_token": json.load(open(os.path.join(config_paths["classifier_token"], "config.json"), "r")), "classifier_span": json.load(open(os.path.join(config_paths["classifier_span"], "config.json"), "r")) } layer_id = config["classifier_token"]["layer"] classifier_span = MLP(att_size, 2, hidden_dim=config["classifier_span"]["classifier_dim"]).to(device) classifier_span.load_state_dict(torch.load(os.path.join(config_paths["classifier_span"], "checkpoint.pt"), map_location=device)) classifier_token = MLP(emb_size, len(config["classifier_token"]["label_map"]), layer_id=layer_id, hidden_dim=config["classifier_token"]["classifier_dim"]).to(device) classifier_token.load_state_dict(torch.load(os.path.join(config_paths["classifier_token"], "checkpoint.pt"), map_location=device)) print(sum(p.numel() for p in classifier_span.parameters()), sum(p.numel() for p in classifier_token.parameters())) return classifier_span, classifier_token, config["classifier_token"]["label_map"] def get_available_models(): available_models = [] for model_name in ["gpt2", "gpt2-xl"]: if os.path.isfile(f"checkpoints/{model_name}/config.json"): available_models.append(model_name) return available_models def get_available_datasets(model_name): available_datasets = [] config_path = f"checkpoints/{model_name}/config.json" if os.path.isfile(config_path): with open(config_path, "r") as f: config = json.load(f) # Assuming datasets are keys in config.json available_datasets = list(config.keys()) return available_datasets def download_and_extract_zip(url, extract_dir): # Determine the parent directory parent_dir = os.path.split(os.path.dirname(extract_dir))[-2] print(parent_dir) # Download the zip file to the parent directory zip_file_path = os.path.join(parent_dir, "data.zip") urllib.request.urlretrieve(url, zip_file_path) # Extract the zip file with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: zip_ref.extractall(parent_dir) # Remove the zip file os.remove(zip_file_path) def find_datasets_and_model_ids(root_dir): datasets = {} # Check if the root directory exists if not os.path.exists(root_dir): # If root directory doesn't exist, download a zip file and unpack it print("Root directory doesn't exist. Downloading zip file...") url = "https://drive.usercontent.google.com/download?id=1i5UkWikRZGhsbv21ZZSjEZl6-VwNC0lp&export=download&authuser=0&confirm=t&uuid=c33ef625-9ec8-4dbf-bdb0-ad6cabc70a33&at=APZUnTWWJSzU9pV2XV-sMPtbgdgj%3A1711096726305" # Replace with your actual download URL download_and_extract_zip(url, root_dir) print("Zip file downloaded and unpacked successfully.") for root, dirs, files in os.walk(root_dir): if 'config.json' in files and 'stoke_config.json' in files: config_path = os.path.join(root, 'config.json') stoke_config_path = os.path.join(root, 'stoke_config.json') with open(config_path, 'r') as f: config_data = json.load(f) model_id = config_data.get('model_id') if model_id: dataset_name = os.path.basename(os.path.dirname(config_path)) with open(stoke_config_path, 'r') as f: stoke_config_data = json.load(f) if model_id: dataset_name = os.path.basename(os.path.dirname(stoke_config_path)) datasets.setdefault(model_id, {})[dataset_name] = stoke_config_data return datasets # Main content st.title("Playground") # Sidebar for model and dataset selection with st.sidebar: st.subheader("Model and Dataset Selection") datasets = find_datasets_and_model_ids("data/") available_models = datasets.keys() print(datasets) if available_models: model_selection = st.selectbox("Select Model", available_models) else: st.error("No models available. Please check the file paths.") # Select dataset based on selected model available_datasets = datasets[model_selection] if available_datasets: dataset_selection = st.selectbox("Select Dataset", sorted(available_datasets)) else: st.error("No datasets available for the selected model.") # Select dataset based on selected model available_configs = datasets[model_selection][dataset_selection] if available_configs: config_selection = st.selectbox("Select Config", available_configs.keys()) else: st.error("No configs available for the selected dataset.") # Load model and streamer based on selections model, tok = get_model_and_tokenizer(model_selection) if torch.cuda.is_available(): model.cuda() classifier_span, classifier_token, label_map = get_classifiers_for_model(model.config.n_head*model.config.n_layer, model.config.n_embd, model.device, datasets[model_selection][dataset_selection][config_selection]) streamer = STOKEStreamer(tok, classifier_token, classifier_span) new_tags = label_map def filter_spans(spans_and_values): if spans_and_values == []: return [], [] # Create a dictionary to store spans based on their second index values span_dict = {} spans, values = [x[0] for x in spans_and_values], [x[1] for x in spans_and_values] # Iterate through the spans and update the dictionary with the highest value for span, value in zip(spans, values): start, end = span if start > end or end - start > 15 or start == 0: continue current_value = span_dict.get(end, None) if current_value is None or current_value[1] < value: span_dict[end] = (span, value) if span_dict == {}: return [], [] # Extract the filtered spans and values filtered_spans, filtered_values = zip(*span_dict.values()) return list(filtered_spans), list(filtered_values) def remove_overlapping_spans(spans): # Sort the spans based on their end points sorted_spans = sorted(spans, key=lambda x: x[0][1]) non_overlapping_spans = [] last_end = float('-inf') # Iterate through the sorted spans for span in sorted_spans: start, end = span[0] value = span[1] # If the current span does not overlap with the previous one if start >= last_end: non_overlapping_spans.append(span) last_end = end else: # If it overlaps, choose the one with the highest value existing_span_index = -1 for i, existing_span in enumerate(non_overlapping_spans): if existing_span[0][1] <= start: existing_span_index = i break if existing_span_index != -1 and non_overlapping_spans[existing_span_index][1] < value: non_overlapping_spans[existing_span_index] = span return non_overlapping_spans def generate_html_no_overlap(tokenized_text, spans): current_index = 0 html_content = "" for (span_start, span_end), value in spans: # Add text before the span html_content += "".join(tokenized_text[current_index:span_start]) # Add the span with underlining html_content += "" html_content += "".join(tokenized_text[span_start:span_end]) html_content += " " current_index = span_end # Add any remaining text after the last span html_content += "".join(tokenized_text[current_index:]) return html_content css = """ """ def generate_html_spanwise(token_strings, tokenwise_preds, spans, tokenizer): # spanwise annotated text annotated = [] span_ends = -1 in_span = False out_of_span_tokens = [] for i in reversed(range(len(tokenwise_preds))): if in_span: if i >= span_ends: continue else: in_span = False predicted_class = "" style = "" span = None for s in spans: if s[1] == i+1: span = s if tokenwise_preds[i] != 0 and span is not None: predicted_class = f"highlight spanhighlight" style = f"background-color: {map_value_to_color((tokenwise_preds[i]-1)/(len(new_tags)-1))}" if tokenizer.convert_tokens_to_string([token_strings[i]]).startswith(" "): annotated.append("Ġ") span_opener = f"Ġ".replace(" ", "Ġ") span_end = f"{new_tags[tokenwise_preds[i]]}" annotated.extend(out_of_span_tokens) out_of_span_tokens = [] span_ends = span[0] in_span = True annotated.append(span_end) annotated.extend([token_strings[x] for x in reversed(range(span[0], span[1]))]) annotated.append(span_opener) else: out_of_span_tokens.append(token_strings[i]) annotated.extend(out_of_span_tokens) return [x for x in reversed(annotated)] # Define function to generate text based on input def generate_text(generation_kwargs, output_field): # Start text generation in a separate thread thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() # Display generated text as it becomes available text_tokenwise = "" text_spans = "" removed_spans = "" tags = [] spans = [] for new_text in streamer: if new_text[1] is not None and new_text[2] != ['']: text_tokenwise = "" tags.extend(new_text[1]) spans.extend(new_text[-1]) # Tokenwise Classification for tk, pred in zip(new_text[2],tags): if pred != 0: style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}" if tk.startswith(" "): text_tokenwise += " " text_tokenwise += f"{tk}" else: text_tokenwise += tk # Span Classification text_spans = "" if len(spans) > 0: filtered_spans = remove_overlapping_spans(spans) text_spans = generate_html_no_overlap(new_text[2], filtered_spans) if len(spans) - len(filtered_spans) > 0: removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap." else: for tk in new_text[2]: text_spans += f"{tk}" # Spanwise Classification annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok) generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "") output_field.empty() output = f"{css}" output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n
" output += "
Show tokenwise classification\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$") #output += "
Show spans\n" + text_spans.replace("\n", " ").replace("$", "\\$") if removed_spans != "": output += f"

({removed_spans})" output += "
" output_field.write(output, unsafe_allow_html=True) # Input field input_text = st.text_area("Enter prompt for completion", "") # Sidebar for customizing generation parameters with st.sidebar: st.subheader("Generation Parameters") max_new_tokens = st.slider("Max New Tokens", min_value=1, max_value=100, value=30) repetition_penalty = st.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2) do_sample = st.checkbox("Do Sample", value=True) temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.0) top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.3) top_k = st.slider("Top-k", min_value=10, max_value=100, value=50) typical_p = st.slider("Typical P", min_value=0.1, max_value=1.0, value=1.0) # Button to generate text if st.button("Generate"): if input_text: output_field = st.empty() inputs = tok([" " + input_text], return_tensors="pt").to(model.device) generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, temperature=temperature, top_p=top_p, top_k=top_k, do_sample=do_sample, typical_p=typical_p) generate_text(generation_kwargs, output_field) else: st.warning("Please enter some text first.")