|
import torch |
|
import torch.nn |
|
import torchvision.models as models |
|
from copy import deepcopy |
|
import cv2 |
|
|
|
import cv2 |
|
import numpy as np |
|
import sys |
|
import itertools |
|
import os |
|
import IPython |
|
import matplotlib |
|
matplotlib.use("Agg") |
|
|
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
|
|
import openai |
|
from sklearn.manifold import TSNE |
|
from sklearn.decomposition import PCA, KernelPCA |
|
import seaborn as sns |
|
|
|
import time |
|
from matplotlib.offsetbox import OffsetImage, AnnotationBbox |
|
import colorsys |
|
from torchvision import datasets |
|
import argparse |
|
import matplotlib.patheffects as PathEffects |
|
from sklearn.cluster import KMeans |
|
|
|
|
|
# ---------------------------------------------------------------------------
# Global plotting style.
# The statements below are order-sensitive: each later call may overwrite
# rcParams written by an earlier one, so keep the sequence as is.
# ---------------------------------------------------------------------------

# Seaborn look: white axes style and the "muted" colour palette.
sns.set_style("white")
sns.set_palette("muted")

# Base matplotlib font settings, applied through matplotlib.rc.
font = dict(size=22)
matplotlib.rc("font", **font)

# Seaborn "paper" context with fonts scaled 3x.
sns.set_context("paper", font_scale=3.0)

# Final rcParams overrides for the very large figures this script produces.
# "font.size" here supersedes the size set through ``font`` above.
plt_param = {
    "legend.fontsize": 60,
    "axes.labelsize": 80,
    "axes.titlesize": 80,
    "font.size": 80,
    "xtick.labelsize": 80,
    "ytick.labelsize": 80,
    "lines.linewidth": 10,
    "lines.color": (0, 0, 0),
}
plt.rcParams.update(plt_param)
|
|
|
# SECURITY: the API key must never be committed to source control. It is read
# from the OPENAI_API_KEY environment variable instead; any key that was
# previously hard-coded in this file should be treated as leaked and revoked.
# If the variable is unset this is None and API calls will fail to authenticate.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# NOTE(review): "gpt4" does not look like a published OpenAI model id (the
# usual spelling is "gpt-4"). It is not used in the visible code, so the value
# is left unchanged here -- confirm against any other users before fixing.
GPT_MODEL = "gpt4"

# Model passed to the embedding endpoint in compute_embedding().
EMBEDDING_MODEL = "text-embedding-ada-002"
|
|
|
|
|
def normalize_numpy_array(arr):
    """Divide ``arr`` by its value range (max - min) along the last axis.

    Only the scale is changed: the minimum is *not* subtracted, so the result
    is not shifted to start at zero.

    Args:
        arr: Array-like of numbers. Reduced along its last axis.

    Returns:
        ``numpy.ndarray`` of the same shape as ``arr``. Slices whose range is
        zero (all values equal) are returned unscaled instead of producing
        inf/NaN from a division by zero.

    Raises:
        ValueError: If the last axis is empty (raised by ``max``/``min``).
    """
    arr = np.asarray(arr)
    span = arr.max(axis=-1, keepdims=True) - arr.min(axis=-1, keepdims=True)
    # A constant slice has zero range; divide by 1 there so the output stays
    # finite rather than turning into inf/NaN.
    safe_span = np.where(span == 0, 1, span)
    return arr / safe_span
|
|
|
def fashion_scatter(
    x, class_labels, fig_name, class_names, add_text=True
):
    """Draw a labelled 2-D scatter plot and save it as ``<fig_name>.pdf``.

    Args:
        x: Array-like of points; column 0 is plotted on the x axis and
            column 1 on the y axis.
        class_labels: One non-negative integer label per row of ``x``. Every
            label value from 0 to ``max(class_labels)`` that occurs gets its
            own ``scatter`` call.
        fig_name: Output path without extension; ``".pdf"`` is appended.
            Missing parent directories are created.
        class_names: Sequence of names. ``class_names[k]`` is the scatter
            label for label value ``k``; when ``add_text`` is true,
            ``class_names[i]`` is additionally drawn next to point ``x[i]``.
        add_text: Whether to annotate points with their names.

    Returns:
        None. The figure is written to disk and then closed.
    """
    x = np.array(x)
    class_labels = np.array(class_labels)
    num_classes = np.max(class_labels) + 1

    # savefig does not create directories, so make sure the target exists.
    out_dir = os.path.dirname(fig_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Use a dedicated figure and always close it: these figures are huge, and
    # clearing a figure (plt.clf) does not release it, so repeated calls would
    # otherwise accumulate open figures.
    fig, ax = plt.subplots(figsize=(140 * 0.8, 80 * 0.6))
    try:
        for label in range(num_classes):
            mask = class_labels == label
            if mask.sum() > 0:
                ax.scatter(
                    x[mask, 0],
                    x[mask, 1],
                    lw=0,
                    s=1500,
                    label=class_names[label],
                )

        if add_text:
            for i, name in enumerate(class_names):
                xtext, ytext = x[i, :]
                txt = ax.text(xtext, ytext, str(name), fontsize=40)
                # White outline behind the glyphs keeps the text readable on
                # top of the markers.
                txt.set_path_effects(
                    [PathEffects.Stroke(linewidth=5, foreground="w"), PathEffects.Normal()]
                )

        ax.axis("on")
        fig.savefig(fig_name + ".pdf")
    finally:
        plt.close(fig)

    print("save figure to ", fig_name)
|
|
|
def compute_embedding(response, max_retries=None, initial_delay=1.0, max_delay=60.0):
    """Embed ``response`` with the OpenAI embedding endpoint, retrying on error.

    Calls ``openai.Embedding.create`` with ``EMBEDDING_MODEL`` and returns the
    first embedding of the reply. Any exception is printed and the call is
    retried after an exponentially growing pause, so rate-limit or outage
    errors do not turn into a tight loop of requests.

    Args:
        response: Input passed as ``input=`` to the embedding endpoint.
        max_retries: Maximum number of retries after the first attempt.
            ``None`` (the default) retries forever, so the function never
            raises unless a bound is requested.
        initial_delay: Seconds to wait before the first retry.
        max_delay: Upper bound, in seconds, for the wait between retries.

    Returns:
        ``numpy.ndarray`` built from ``reply["data"][0]["embedding"]``.

    Raises:
        Exception: Re-raises the last error once ``max_retries`` retries have
            been used up (only when ``max_retries`` is not ``None``).
    """
    attempt = 0
    delay = initial_delay
    while True:
        attempt += 1
        try:
            print('ping openai api')
            reply = openai.Embedding.create(
                model=EMBEDDING_MODEL,
                input=response,
            )
            return np.array(reply["data"][0]['embedding'])
        except Exception as e:
            print(e)
            if max_retries is not None and attempt > max_retries:
                raise
            # Back off before trying again instead of hammering the API.
            time.sleep(delay)
            delay = min(delay * 2, max_delay)
|
|
|
def _list_task_files():
    """Return the paths in the two task folders that pass the name filter."""
    # A path is dropped when any of these substrings occurs anywhere in it.
    excluded = (
        'pycache', 'init', 'README', 'extended', 'gripper', 'primitive',
        'task.py', 'camera', 'seq',
    )
    candidates = []
    for folder in ("cliport/tasks", "cliport/generated_tasks"):
        candidates.extend(os.path.join(folder, x) for x in os.listdir(folder))
    return [t for t in candidates if not any(token in t for token in excluded)]


def draw_latent_plot(
    max_num=80,
    method="pca+tsne",
    fig_name="",
):
    """Embed every task file, project the embeddings to 2-D and plot them.

    Steps: list the task files, fetch each file's embedding (from the on-disk
    ``.npz`` cache when present, otherwise via ``compute_embedding``), save the
    updated cache, cluster the embeddings with KMeans, reduce them to 2-D with
    the chosen ``method`` and write two scatter plots with ``fashion_scatter``:
    one coloured per task (``fig_name``) and one coloured per cluster
    (``fig_name + "_cluster"``).

    Args:
        max_num: Upper bound on the number of PCA components used by the
            ``"pca+tsne"`` method (additionally capped at 50 and at the data
            dimensions).
        method: ``"pca+tsne"``, ``"pca"`` or ``"tsne"``.
        fig_name: Output path stem handed to ``fashion_scatter``.

    Raises:
        ValueError: If ``method`` is not one of the supported names, or if no
            task files are found.
    """
    # Fail fast, before any (paid) embedding request is made.
    if method not in ("pca+tsne", "pca", "tsne"):
        raise ValueError(
            "unknown method {!r}; expected 'pca+tsne', 'pca' or 'tsne'".format(method)
        )

    total_tasks = _list_task_files()
    cache_embedding_path = "output/output_embedding/task_cache_embedding.npz"
    cache_embedding = {}

    if os.path.exists(cache_embedding_path):
        # Materialise the arrays inside the context so the archive is closed.
        with np.load(cache_embedding_path) as cached:
            cache_embedding = dict(cached)

    print(total_tasks)
    if not total_tasks:
        raise ValueError("no task files found under cliport/tasks or cliport/generated_tasks")

    latents = []
    class_labels = []
    label_sets = []
    for idx, task_name in enumerate(total_tasks):
        if task_name in cache_embedding:
            code_embedding = cache_embedding[task_name]
        else:
            with open(task_name) as task_file:
                code = task_file.read()
            code_embedding = compute_embedding(code)
            cache_embedding[task_name] = code_embedding

        latents.append(code_embedding)
        # File name without directory and extension, e.g. ".../foo.py" -> "foo".
        label_sets.append(os.path.splitext(os.path.basename(task_name))[0])
        class_labels.append(idx)

    latents = np.array(latents)
    print("latents shape:", latents.shape)

    # Create the cache directory first so freshly computed embeddings are not
    # lost to a missing-folder error.
    os.makedirs(os.path.dirname(cache_embedding_path), exist_ok=True)
    np.savez(cache_embedding_path, **cache_embedding)

    n_samples, n_features = latents.shape

    # KMeans cannot form more clusters than there are samples.
    n_clusters = min(6, n_samples)
    kmeans = KMeans(n_clusters=n_clusters, init="k-means++", random_state=42)
    kmeans.fit(latents)
    cluster_labels = kmeans.labels_

    if method == "pca+tsne":
        # PCA needs n_components <= min(n_samples, n_features).
        n_components = min(50, max_num, n_samples, n_features)
        pca = PCA(random_state=123, n_components=n_components)
        X_embedded = pca.fit_transform(latents)
        print(
            "Variance explained per principal component: {}".format(
                pca.explained_variance_ratio_[:5]
            )
        )
        print("PCA data shape:", X_embedded.shape)
        # t-SNE requires perplexity < n_samples.
        X_embedded = TSNE(
            random_state=123, perplexity=min(20, n_samples - 1)
        ).fit_transform(X_embedded)
    elif method == "pca":
        # NOTE(review): only the first 5 embedding dimensions are used here;
        # kept as in the original -- confirm this is intentional.
        pca = KernelPCA(random_state=123, n_components=2)
        X_embedded = pca.fit_transform(latents[:, :5])
    else:  # method == "tsne"
        # 30.0 is scikit-learn's default perplexity; clamp it for small inputs.
        X_embedded = TSNE(
            random_state=123, perplexity=min(30.0, n_samples - 1)
        ).fit_transform(latents)

    fashion_scatter(X_embedded, class_labels, fig_name, label_sets)
    fashion_scatter(X_embedded, cluster_labels, fig_name + "_cluster", label_sets)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the output file stem from the command line and
    # run the embedding plot, writing results under output/output_embedding/.
    cli = argparse.ArgumentParser(description="Generate chat-gpt embeddings")
    cli.add_argument("--file", type=str, default="task_embedding")
    args = cli.parse_args()

    output_stem = f'output/output_embedding/{args.file}'
    draw_latent_plot(fig_name=output_stem)