# Source: AddLat2D / Data_Plotting / Plot_TSNE.py (2.55 kB)
# Commit 7e70cfa by marta-marta: "Added TSNE Plot"
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
# Latent Feature Cluster for Training Data using T-SNE
def TSNE_reduction(latent_points, perplexity=30, learning_rate=20):
    """Project latent vectors onto a 2-D plane with t-SNE.

    Parameters
    ----------
    latent_points : array-like, shape (n_samples, n_features)
        Points to embed, one row per sample. t-SNE runtime grows quickly with
        the number of samples, so for large datasets pass only a subset
        (e.g. the first few hundred rows).
    perplexity : float, default 30
        Roughly a guess of the number of close neighbours each point has; it
        balances the attention t-SNE gives to local vs. global structure of
        the data (typical range 5-50, denser data usually needs a higher
        value). scikit-learn requires it to be smaller than n_samples.
    learning_rate : float, default 20
        Optimisation step size (typical range 10-1000). If it is too high the
        result may look like a 'ball' with every point roughly equidistant
        from its nearest neighbours; if too low, most points may be compressed
        into a dense cloud with few outliers.

    Returns
    -------
    x : numpy.ndarray
        First t-SNE coordinate of each sample.
    y : numpy.ndarray
        Second t-SNE coordinate of each sample.
    title : str
        Plot title, always "T-SNE of Data".
    embedding : sklearn.manifold.TSNE
        The fitted TSNE estimator itself (not the embedded coordinates).
    """
    # n_components=2 -> embed into a plane; random_state is fixed so repeated
    # runs on the same data produce the same layout.
    model = TSNE(n_components=2, random_state=0, perplexity=perplexity,
                 learning_rate=learning_rate)
    tsne_data = model.fit_transform(latent_points)
    x = tsne_data[:, 0]
    y = tsne_data[:, 1]
    title = "T-SNE of Data"
    return x, y, title, model
########################################################################################################################
import pandas as pd
import json
# Load the lattice dataset and decode the JSON-encoded box stored in one row.
df = pd.read_csv('2D_Lattice.csv')
row = 0  # positional index of the row to inspect
box = df.iat[row, 1]  # column position 1 holds a JSON string (decoded below)
array = np.asarray(json.loads(box))
"""
# For plotting CSV data
# define a function to flatten a box
def flatten_box(box_str):
box = json.loads(box_str)
return np.array(box).flatten()
# apply the flatten_box function to each row of the dataframe and create a list of flattened arrays
flattened_arrays = df['Array'].apply(flatten_box).tolist()
x, y, title, embedding = TSNE_reduction(flattened_arrays)
plt.scatter(x,y)
plt.title(title)
plt.show()
"""
# def plot_dimensionality_reduction(x, y, label_set, title):
# plt.title(title)
# if label_set[0].dtype == float:
# plt.scatter(x, y, c=label_set)
# plt.colorbar()
# print("using scatter")
# else:
# for label in set(label_set):
# cond = np.where(np.array(label_set) == str(label))
# plt.plot(x[cond], y[cond], marker='o', linestyle='none', label=label)
#
# plt.legend(numpoints=1)
#
# plt.show()
# plt.close()