File size: 2,064 Bytes
7ee7b26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 10 21:13:04 2023

@author: zhihuang
"""

import pickle
import os
import pandas as pd
import numpy as np
import umap
import seaborn as sns
import matplotlib.pyplot as plt
opj = os.path.join

# Default location of the pickled asset; override with the first CLI argument.
DATA_DIR = '/home/zhihuang/Desktop/webplip/data'


def make_reducer(n_neighbors=15, min_dist=0.1, random_state=0):
    """Return a fresh 2-D UMAP reducer using the Euclidean metric.

    A new instance is built per embedding space so that fitting one space
    never overwrites the model fitted on the other.
    """
    return umap.UMAP(n_components=2,
                     n_neighbors=n_neighbors,
                     min_dist=min_dist,
                     metric='euclidean',
                     random_state=random_state)


def embedding_frame(coords, meta):
    """Return a DataFrame of 2-D coordinates followed by the metadata columns.

    Parameters
    ----------
    coords : array-like of shape (n_samples, 2)
        Low-dimensional coordinates, one row per sample.
    meta : pandas.DataFrame
        Per-sample metadata; row i is matched positionally with coords[i].

    Returns
    -------
    pandas.DataFrame
        Columns ``UMAP_1``, ``UMAP_2`` and then every column of ``meta``, with a
        fresh 0..n-1 index. Unlike ``np.c_[coords, meta.values]``, each column
        keeps its own dtype instead of everything being coerced to ``object``.

    Raises
    ------
    ValueError
        If ``coords`` is not 2-column, or its row count differs from ``meta``.
    """
    coords = np.asarray(coords)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f'expected coords of shape (n, 2), got {coords.shape}')
    if len(coords) != len(meta):
        raise ValueError(
            f'coords has {len(coords)} rows but meta has {len(meta)}')
    xy = pd.DataFrame(coords, columns=['UMAP_1', 'UMAP_2'])
    # Positional alignment: drop meta's own index so concat pairs row i with row i.
    return pd.concat([xy, meta.reset_index(drop=True)], axis=1)


def main(dd=DATA_DIR, n_neighbors=15, random_state=0):
    """Project image and text embeddings to 2-D, save them as CSV and plot them.

    Reads ``twitter.asset`` from ``dd``. The unpickled object must provide
    ``'image_embedding'``, ``'text_embedding'`` and ``'meta'`` entries, and the
    meta frame needs a ``'tag'`` column (used as the plot hue). Writes
    ``img_2d_embedding.csv`` and ``txt_2d_embedding.csv`` back into ``dd``.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_img, df_txt)``: the 2-D coordinates joined with ``meta``.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # at trusted, locally produced assets.
    with open(opj(dd, 'twitter.asset'), 'rb') as f:
        data = pickle.load(f)
    meta = data['meta']

    # fit() returns the fitted model, not coordinates; fit_transform() yields
    # the embedding of the training points directly, same as the text branch.
    img_2d = make_reducer(n_neighbors, random_state=random_state).fit_transform(
        data['image_embedding'])
    df_img = embedding_frame(img_2d, meta)
    df_img.to_csv(opj(dd, 'img_2d_embedding.csv'))

    # Separate reducer so the image model is not silently refit on text.
    txt_2d = make_reducer(n_neighbors, random_state=random_state).fit_transform(
        data['text_embedding'])
    df_txt = embedding_frame(txt_2d, meta)
    df_txt.to_csv(opj(dd, 'txt_2d_embedding.csv'))

    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    panels = ((df_img, 'Image embedding'), (df_txt, 'Text embedding'))
    for axis, (df, title) in zip(ax, panels):
        sns.scatterplot(data=df,
                        x='UMAP_1',
                        y='UMAP_2',
                        alpha=0.2,
                        ax=axis,
                        hue='tag')
        axis.set_title(title)
    fig.tight_layout()
    # Without an explicit show the figure is never displayed in a plain
    # (non-interactive) script run.
    plt.show()
    return df_img, df_txt


if __name__ == '__main__':
    import sys
    # Keep the frames as module globals for interactive (e.g. Spyder) sessions.
    df_img, df_txt = main(sys.argv[1] if len(sys.argv) > 1 else DATA_DIR)