import re
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import openai  # for OpenAI API calls (only needed by the commented-out OpenAI backend below)
import pandas as pd
import requests
import streamlit as st
import tiktoken
import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GPT2LMHeadModel
######################################
def find_indices(arr, target):
    """Return (start, end) index pairs splitting `arr` at elements containing `target`."""
    indices = []
    start_index = None
    for i, element in enumerate(arr):
        if target in element:
            if start_index is None:
                start_index = i
            else:
                indices.append((start_index, i - 1))
                start_index = i
    if start_index is not None:
        indices.append((start_index, len(arr) - 1))
    return indices
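# Illustrative only (this helper is not called elsewhere in this file):
# find_indices(["NP", "NP", "VP", "NP"], "NP") -> [(0, 0), (1, 2), (3, 3)];
# each matching element opens a span that runs until just before the next match.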
######################################
def colorize_tokens(token_data, sentence):
    colored_sentence = ""
    start = 0
    for token in token_data:
        entity_group = token["entity_group"]
        word = token["word"]
        tag = f"[{entity_group}]"
        tag_color = tag_colors.get(entity_group, "white")  # default to white if color not found
        colored_chunk = f'<span style="color:black;background-color:{tag_color}">{word} {tag}</span>'
        colored_sentence += sentence[start:token["start"]] + colored_chunk
        start = token["end"]
    # Add the remaining part of the sentence
    colored_sentence += sentence[start:]
    return colored_sentence
# Background colors for the shallow-parsing (chunk) tags
tag_colors = {
    "ADJP": "#8F6B9F",   # adjective phrase
    "ADVP": "#7275A7",   # adverb phrase
    "CONJP": "#5BA4BB",  # conjunction phrase
    "INTJ": "#95CA73",   # interjection
    "LST": "#DFDA70",    # list marker
    "NP": "#EFBC65",     # noun phrase
    "PP": "#FC979B",     # prepositional phrase
    "PRT": "#F1C5C1",    # particle
    "SBAR": "#FAEBE8",   # subordinate clause
    "VP": "#90DFD2",     # verb phrase
}
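# colorize_tokens expects the chunker's aggregated output format, e.g. (illustrative values):
# [{"entity_group": "NP", "word": "Robert", "start": 0, "end": 6, "score": 0.99}, ...]
# where "start"/"end" are character offsets into the original sentence.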
##################
def generate_tagged_sentence(sentence, entity_tags):
    # Annotate each chunk of the sentence with its tag
    tagged_tokens = []
    for tag in entity_tags:
        start = tag['start']
        end = tag['end']
        if end < len(sentence) - 1:
            token = sentence[start:end]
        else:
            token = sentence[start:end + 1]  # include the final character of the sentence
        tag_name = f"[{tag['entity_group']}]"
        tagged_tokens.append(f"{token} {tag_name}")
    return " ".join(tagged_tokens)
def replace_pp_with_pause(sentence, entity_tags):
    # Rebuild the sentence chunk by chunk, appending [PAUSE] after each prepositional phrase (PP)
    tagged_tokens = []
    for tag in entity_tags:
        start = tag['start']
        end = tag['end']
        if end < len(sentence) - 1:
            token = sentence[start:end]
        else:
            token = sentence[start:end + 1]  # include the final character of the sentence
        tag_name = '[PAUSE]' if tag['entity_group'] == 'PP' else ''
        tagged_tokens.append(f"{token}{tag_name}")
    # Re-attach clitics like "'s" to the preceding chunk so spacing is preserved
    modified_words = []
    for word in tagged_tokens:
        if word.startswith("'s"):
            modified_words[-1] = modified_words[-1] + word
        else:
            modified_words.append(word)
    return " ".join(modified_words)
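# Illustrative (actual chunk boundaries come from the model): for
# "Robert used PDF for his math homework." the PP chunk is typically just the
# preposition, so the output looks like "Robert used PDF for[PAUSE] his math homework."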
def get_split_sentences(sentence, entity_tags):
    # Split the sentence into segments, closing a segment after each PP chunk
    split_sentences = []
    current_sentence = []
    for tag in entity_tags:
        start = tag['start']
        end = tag['end']
        if end < len(sentence) - 1:
            token = sentence[start:end]
        else:
            token = sentence[start:end + 1]  # include the final character of the sentence
        current_sentence.append(token)
        if tag['entity_group'] == 'PP':
            split_sentences.append(" ".join(current_sentence))
            current_sentence = []  # reset for the next segment
    # If the sentence ends without a PP chunk, add the final segment
    if current_sentence:
        split_sentences.append(" ".join(current_sentence))
    return split_sentences
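# Illustrative (assuming chunks NP "Robert" / VP "used" / NP "PDF" / PP "for" / NP "his math homework"):
# get_split_sentences returns ["Robert used PDF for", "his math homework"].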
######################################
st.set_page_config(page_title="Hallucination", layout="wide")
st.title(':blue[Sorry, come again! This time slowly, please]')
st.header("Rephrasing LLM Prompts for Better Comprehension Reduces :blue[Hallucination]")
############################
# Side-by-side demo videos: original prompt vs. paraphrased prompt with [PAUSE] tokens
with open('machine.mp4', 'rb') as video_file1:
    video_bytes1 = video_file1.read()
with open('Pause 3 Out1.mp4', 'rb') as video_file2:
    video_bytes2 = video_file2.read()
col1a, col1b = st.columns(2)
with col1a:
    st.caption("Original")
    st.video(video_bytes1)
with col1b:
    st.caption("Paraphrased and added [PAUSE]")
    st.video(video_bytes2)
#############################
HF_SPACES_API_KEY = st.secrets["HF_token"]
# API_URL = "https://api-inference.huggingface.co/models/openlm-research/open_llama_3b"
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
headers = {"Authorization": HF_SPACES_API_KEY}

def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

API_URL_chunk = "https://api-inference.huggingface.co/models/flair/chunk-english"

def query_chunk(payload):
    response = requests.post(API_URL_chunk, headers=headers, json=payload)
    return response.json()
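# Illustrative request/response shapes (values are made up):
#   query({"inputs": "Once upon a time", "parameters": {"max_new_tokens": 100}})
#     -> [{"generated_text": "Once upon a time ..."}]
#   query_chunk({"inputs": "Robert used PDF."})
#     -> [{"entity_group": "NP", "word": "Robert", "start": 0, "end": 6, "score": 0.99}, ...]
# While a model is still loading, the Inference API returns {"error": ..., "estimated_time": ...},
# which is why the loaders below retry until the "error" key disappears.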
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # for exponential backoff

# Alternative backend: OpenAI chat completions (disabled; the app currently uses the HF Inference API)
# openai.api_key = f"{st.secrets['OpenAI_API']}"
# model_engine = "gpt-4"
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
# def get_answers(prompt):
#     completion = openai.ChatCompletion.create(
#         model='gpt-3.5-turbo',
#         messages=[
#             {'role': 'user', 'content': prompt}
#         ],
#         temperature=0, max_tokens=200,
#     )
#     return completion['choices'][0]['message']['content']

prompt = '''Generate a story from the given text.
Text : '''
# paraphrase_prompt = '''Rephrase the given text: '''
# _gpt3tokenizer = tiktoken.get_encoding("cl100k_base")
##########################
def render_heatmap(original_text, importance_scores_df):
    # Extract the importance scores
    importance_values = importance_scores_df['importance_value'].values
    # Guard against division by zero during normalization
    min_val = np.min(importance_values)
    max_val = np.max(importance_values)
    if max_val - min_val != 0:
        normalized_importance_values = (importance_values - min_val) / (max_val - min_val)
    else:
        normalized_importance_values = np.zeros_like(importance_values)  # fallback: all-zero array
    # Colormap for the heatmap ("Blues")
    cmap = matplotlib.colormaps['Blues']
    # Choose black or white text depending on background brightness
    def get_text_color(bg_color):
        brightness = 0.299 * bg_color[0] + 0.587 * bg_color[1] + 0.114 * bg_color[2]
        return 'white' if brightness < 0.5 else 'black'
    # Walk the original text and the token list in parallel, wrapping each token in a colored span
    original_pointer = 0
    token_pointer = 0
    html = ""
    while original_pointer < len(original_text):
        token = importance_scores_df.loc[token_pointer, 'token']
        if original_pointer == original_text.find(token, original_pointer):
            importance = normalized_importance_values[token_pointer]
            rgba = cmap(importance)
            bg_color = rgba[:3]
            text_color = get_text_color(bg_color)
            html += f'<span style="background-color: rgba({int(bg_color[0]*255)}, {int(bg_color[1]*255)}, {int(bg_color[2]*255)}, 1); color: {text_color};">{token}</span>'
            original_pointer += len(token)
            token_pointer += 1
        else:
            html += original_text[original_pointer]
            original_pointer += 1
    st.markdown(html, unsafe_allow_html=True)
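# Usage sketch with a hypothetical dataframe:
#   render_heatmap("the cat", pd.DataFrame({"token": ["the", " cat"], "importance_value": [0.1, 0.9]}))
# renders the sentence with a darker blue background behind higher-importance tokens.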
##########################
# Create selectbox
prompt_list = ["Which individuals possessed the ships that were part of the Boston Tea Party?",
               "Freddie Frith", "Robert used PDF for his math homework."
               ]
options = [f"Prompt #{i+1}: {prompt_list[i]}" for i in range(len(prompt_list))] + ["Another Prompt..."]
selection = st.selectbox("Choose a prompt from the dropdown below. Click on :blue['Another Prompt...'] if you want to enter your own custom prompt.", options=options)
check = []
if selection == "Another Prompt...":
    custom_prompt = st.text_input("Enter your custom prompt...")
    check = " " + custom_prompt
    if custom_prompt:
        st.caption(f""":white_check_mark: Your input prompt is: {check}""")
        st.caption(':green[Kindly hold on for a few minutes while the AI text is being generated]')
else:
    check = re.split(r'#\d+:', selection, maxsplit=1)[1]
    if check:
        st.caption(f""":white_check_mark: Your input prompt is: {check}""")
        st.caption(':green[Kindly hold on for a few minutes while the AI text is being generated]')
# @st.cache_data
def load_chunk_model(check):
    # Retry until the Inference API stops returning a "model loading" error
    iden = ['error']
    while 'error' in iden:
        time.sleep(1)
        try:
            output = query_chunk({"inputs": f"""{check}""", })
            iden = output  # a successful call returns a list of chunk dicts, not an error dict
        except Exception as e:
            print(f"An exception occurred: {e}")
    return output
##################################
# @st.cache_resource
def load_text_gen_model(check):
    # Retry until the Inference API stops returning a "model loading" error
    iden = ['error']
    while 'error' in iden:
        time.sleep(1)
        try:
            output = query({
                "inputs": f"""{check}""",
                "parameters": {
                    "min_new_tokens": 30,
                    "max_new_tokens": 100,
                    "do_sample": True,
                    # "remove_invalid_values": True,
                    # "temperature": 0.6,
                    # "top_k": 1,
                    # "num_beams": 2,
                    # "no_repeat_ngram_size": 2,
                    # "early_stopping": True
                }
            })
            iden = output  # a successful call returns a list with the generated text
        except Exception as e:
            print(f"An exception occurred: {e}")
    return output[0]['generated_text']

# Alternative backend using the OpenAI API (disabled):
# @st.cache_data
# def load_text_gen_model(check):
#     return get_answers(prompt + check)
def decoded_tokens(string, tokenizer):
    # Encode to token ids, then decode each id back to its string piece
    return [tokenizer.decode([x]) for x in tokenizer.encode(string)]
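# e.g. decoded_tokens("hello world", _gpt2tokenizer) -> ['hello', ' world']
# (continuation tokens keep their leading space, which render_heatmap relies on for alignment).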
###########################
def analyze_heatmap(df_input):
    df = df_input.copy()
    df["Position"] = range(len(df))
    # Blues colormap, shading each bar by its normalized importance
    blues = matplotlib.colormaps["Blues"]
    fig, ax = plt.subplots(figsize=(10, 6))
    # Normalize the importance values
    min_val = df["importance_value"].min()
    max_val = df["importance_value"].max()
    normalized_values = (df["importance_value"] - min_val) / (max_val - min_val)
    # One bar per token, colored by its normalized importance_value
    for i, (token, norm_value) in enumerate(zip(df["token"], normalized_values)):
        color = blues(norm_value)
        ax.bar(
            x=[i],  # use index for x-axis
            height=[df["importance_value"].iloc[i]],
            width=1.0,  # full width so bars touch each other
            color=[color],
        )
    ax.set_xticks(range(len(df["token"])))
    ax.set_xticklabels(df["token"], rotation=45)
    return fig
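# Returns a Matplotlib figure with one bar per token, shaded by normalized importance;
# callers pass it straight to st.pyplot(...), as SentenceAnalyzer.analyze() does below.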
############################
# @st.cache_data
def integrated_gradients(input_ids, baseline, model, n_steps=10):  # 100
    # Convert input_ids and baseline to LongTensors
    input_ids = input_ids.long()
    baseline = baseline.long()
    # Accumulator for gradients across interpolation steps
    accumulated_grads = None
    # Create interpolated inputs along the straight-line path from baseline to input
    alphas = torch.linspace(0, 1, n_steps)
    delta = input_ids - baseline
    interpolates = [(baseline + (alpha * delta).long()).long() for alpha in alphas]  # explicitly cast to LongTensor
    # Progress bar over the interpolation steps
    pbar = tqdm(total=n_steps, desc="Calculating Integrated Gradients")
    for interpolate in interpolates:
        pbar.update(1)
        # Convert the interpolated token ids to embeddings that require gradients
        interpolate_embedding = model.transformer.wte(interpolate).clone().detach().requires_grad_(True)
        # Forward pass
        output = model(inputs_embeds=interpolate_embedding, output_attentions=False)[0]
        # Aggregate the logits across all positions (sum)
        aggregated_logit = output.sum()
        # Backward pass to calculate gradients
        aggregated_logit.backward()
        # Accumulate gradients
        if accumulated_grads is None:
            accumulated_grads = interpolate_embedding.grad.clone()
        else:
            accumulated_grads += interpolate_embedding.grad
        # Clear gradients
        model.zero_grad()
        interpolate_embedding.grad.zero_()
    pbar.close()
    # Average gradients over the path
    avg_grads = accumulated_grads / n_steps
    # Attributions: (input embedding - baseline embedding) * average gradient
    with torch.no_grad():
        input_embedding = model.transformer.wte(input_ids)
        baseline_embedding = model.transformer.wte(baseline)
        attributions = (input_embedding - baseline_embedding) * avg_grads
    return attributions
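# For reference, this approximates Integrated Gradients (Sundararajan et al., 2017):
#   IG_i(x) = (x_i - baseline_i) * (1/m) * sum_{k=1..m} dF(baseline + (k/m)(x - baseline)) / dx_i
# Note the path here is taken in discrete token-id space and gradients flow through the
# embedding layer -- a simplification of the usual embedding-space interpolation.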
# @st.cache_data
def process_integrated_gradients(input_text, _gpt2tokenizer, model):
    inputs = torch.tensor([_gpt2tokenizer.encode(input_text)])
    gpt2tokens = decoded_tokens(input_text, _gpt2tokenizer)
    # Initialize a baseline of all-zero token ids
    baseline = torch.zeros_like(inputs).long()
    # Compute Integrated Gradients targeting the aggregated sequence output
    attributions = integrated_gradients(inputs, baseline, model)
    # Sum across the embedding dimensions to get a single attribution score per token
    attributions_sum = attributions.sum(axis=2).squeeze(0).detach().numpy()
    # L2-normalize, then clamp negative scores to zero
    l2_norm_attributions = np.linalg.norm(attributions_sum, 2)
    normalized_attributions_sum = attributions_sum / l2_norm_attributions
    clamped_attributions_sum = np.where(normalized_attributions_sum < 0, 0, normalized_attributions_sum)
    attribution_df = pd.DataFrame({
        'token': gpt2tokens,
        'importance_value': clamped_attributions_sum
    })
    return attribution_df
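# Usage sketch: process_integrated_gradients(" Freddie Frith", _gpt2tokenizer, model)
# would return a DataFrame with one row per GPT-2 token and columns 'token' / 'importance_value'.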
########################
model_version = 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_version, output_attentions=True)
_gpt2tokenizer = tiktoken.get_encoding("gpt2")
#######################
para_tokenizer = AutoTokenizer.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base")
para_model = AutoModelForSeq2SeqLM.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base")
######################
def paraphrase(
    question,
    num_beams=5,
    num_beam_groups=5,
    num_return_sequences=5,
    repetition_penalty=10.0,
    diversity_penalty=3.0,
    no_repeat_ngram_size=2,
    temperature=0.7,
    max_length=64  # 128
):
    input_ids = para_tokenizer(
        f'paraphrase: {question}',
        return_tensors="pt", padding="longest",
        max_length=max_length,
        truncation=True,
    ).input_ids
    outputs = para_model.generate(
        input_ids, temperature=temperature, repetition_penalty=repetition_penalty,
        num_return_sequences=num_return_sequences, no_repeat_ngram_size=no_repeat_ngram_size,
        num_beams=num_beams, num_beam_groups=num_beam_groups,
        max_length=max_length, diversity_penalty=diversity_penalty
    )
    res = para_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return res
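# Usage sketch: paraphrase("Robert used PDF for his math homework.")
# returns num_return_sequences (5) diverse-beam-search paraphrases of the input.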
###########################
class SentenceAnalyzer:
    def __init__(self, check, original, _gpt2tokenizer, model):
        self.check = check
        self.original = original
        self._gpt2tokenizer = _gpt2tokenizer
        self.model = model
        self.entity_tags = load_chunk_model(check)
        self.tagged_sentence = generate_tagged_sentence(check, self.entity_tags)
        self.sentence_with_pause = replace_pp_with_pause(check, self.entity_tags)
        self.split_sentences = get_split_sentences(check, self.entity_tags)
        self.colored_output = colorize_tokens(self.entity_tags, check)

    def analyze(self):
        # Heatmap and per-token bar chart for the full prompt
        attribution_df1 = process_integrated_gradients(self.check, self._gpt2tokenizer, self.model)
        st.caption(f":blue[{self.original}]:")
        render_heatmap(self.check, attribution_df1)
        st.pyplot(analyze_heatmap(attribution_df1))
        # Re-run attribution on each PP-delimited segment, inserting [PAUSE] rows between segments
        dataframes_list = []
        for i, split_sentence in enumerate(self.split_sentences):
            attribution_df1 = process_integrated_gradients(split_sentence, self._gpt2tokenizer, self.model)
            if i < len(self.split_sentences) - 1:
                # Add a row with [PAUSE] and value 0 at the end of the segment
                pause_row = pd.DataFrame({'token': '[PAUSE]', 'importance_value': 0}, index=[len(attribution_df1)])
                attribution_df1 = pd.concat([attribution_df1, pause_row], ignore_index=True)
            dataframes_list.append(attribution_df1)
        # Concatenate the per-segment dataframes
        if dataframes_list:
            combined_dataframe = pd.concat(dataframes_list, axis=0)
            combined_dataframe = combined_dataframe[combined_dataframe['token'] != " "].reset_index(drop=True)
            combined_dataframe1 = combined_dataframe[combined_dataframe['token'] != "[PAUSE]"].reset_index(drop=True)
            st.write("Sentence with [PAUSE] Replacement:")
            render_heatmap(self.sentence_with_pause, combined_dataframe1)
            st.pyplot(analyze_heatmap(combined_dataframe))

paraphrase_list = paraphrase(check)
######################
# Side-by-side layout: attribution analysis on the left, generated text on the right
col1, col2 = st.columns(2)
with col1:
    analyzer = SentenceAnalyzer(check, "Original Prompt", _gpt2tokenizer, model)
    analyzer.analyze()
with col2:
    ai_gen_text = load_text_gen_model(check)
    st.caption(':blue[AI-generated text]')
    st.write(ai_gen_text)
st.markdown("""<hr style="height:5px;border:none;color:lightblue;background-color:lightblue;" /> """, unsafe_allow_html=True)

# Repeat the same layout for each of the five paraphrases
for i, paraphrased in enumerate(paraphrase_list, start=1):
    col_left, col_right = st.columns(2)
    with col_left:
        analyzer = SentenceAnalyzer(" " + paraphrased, f"Paraphrase {i}", _gpt2tokenizer, model)
        analyzer.analyze()
    with col_right:
        ai_gen_text = load_text_gen_model(paraphrased)
        st.caption(':blue[AI-generated text]')
        st.write(ai_gen_text)
    if i < len(paraphrase_list):
        st.markdown("""<hr style="height:5px;border:none;color:lightblue;background-color:skyblue;" /> """, unsafe_allow_html=True)