#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Project Twitter image/text embeddings to 2-D with UMAP and plot them.

Created on Fri Mar 10 21:13:04 2023

@author: zhihuang

Reads ``twitter.asset`` (a pickle) from the data directory, fits one UMAP
reducer per modality on ``data['image_embedding']`` and
``data['text_embedding']``, writes the 2-D coordinates joined with the
``data['meta']`` table to ``img_2d_embedding.csv`` / ``txt_2d_embedding.csv``
in the same directory, and draws side-by-side scatter plots coloured by the
``tag`` meta column.

Usage:
    python <this script> [DATA_DIR]
"""
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap

opj = os.path.join

DEFAULT_DATA_DIR = '/home/zhihuang/Desktop/webplip/data'
N_NEIGHBORS = 15
MIN_DIST = 0.1
RANDOM_STATE = 0


def _fit_umap_2d(embedding, n_neighbors=N_NEIGHBORS, min_dist=MIN_DIST,
                 random_state=RANDOM_STATE):
    """Fit a fresh 2-D UMAP reducer on *embedding* and return its coordinates.

    ``fit_transform`` returns the embedding learned while fitting. Calling
    ``fit`` and then ``transform`` on the same data instead re-projects the
    training points as if they were unseen, which only approximates that
    embedding (and does the expensive work twice).
    """
    reducer = umap.UMAP(n_components=2,
                        n_neighbors=n_neighbors,
                        min_dist=min_dist,
                        metric='euclidean',
                        random_state=random_state)
    return reducer.fit_transform(embedding)


def _embedding_frame(coords, meta):
    """Return a DataFrame of ``UMAP_1``/``UMAP_2`` followed by *meta*'s columns.

    ``pd.concat`` is used instead of ``np.c_[coords, meta.values]`` so the
    coordinate columns stay float and each meta column keeps its own dtype
    (``np.c_`` coerces everything to one common, typically object, dtype).

    Raises:
        ValueError: if *coords* and *meta* have different numbers of rows.
    """
    coords = np.asarray(coords)
    if coords.shape[0] != len(meta):
        raise ValueError(
            f'embedding has {coords.shape[0]} rows but meta has {len(meta)}')
    umap_cols = pd.DataFrame(coords, columns=['UMAP_1', 'UMAP_2'])
    # Rows are aligned by position (as np.c_ did); resetting meta's index
    # also keeps the written CSV index as 0..n-1.
    return pd.concat([umap_cols, meta.reset_index(drop=True)], axis=1)


def main(dd=DEFAULT_DATA_DIR):
    """Run the projection for the data directory *dd*.

    Returns:
        (df_img, df_txt): the image and text 2-D embedding tables that were
        written to CSV.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(opj(dd, 'twitter.asset'), 'rb') as f:
        data = pickle.load(f)

    meta = data['meta']

    # One independent reducer per modality, so neither fit overwrites the other.
    df_img = _embedding_frame(_fit_umap_2d(data['image_embedding']), meta)
    df_img.to_csv(opj(dd, 'img_2d_embedding.csv'))

    df_txt = _embedding_frame(_fit_umap_2d(data['text_embedding']), meta)
    df_txt.to_csv(opj(dd, 'txt_2d_embedding.csv'))

    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    sns.scatterplot(data=df_img, x='UMAP_1', y='UMAP_2', alpha=0.2,
                    ax=ax[0], hue='tag')
    ax[0].set_title('Image embedding')
    sns.scatterplot(data=df_txt, x='UMAP_1', y='UMAP_2', alpha=0.2,
                    ax=ax[1], hue='tag')
    ax[1].set_title('Text embedding')
    fig.tight_layout()
    plt.show()
    return df_img, df_txt


if __name__ == '__main__':
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_DATA_DIR)