File size: 2,216 Bytes
bfed184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import random 
import numpy as np
import torch 
import matplotlib.pyplot as plt
from PIL import Image 

def merge_lists_by_index(list1, list2):
    """Join the entries of two equally long lists position by position.

    Each output element is ``list1[i] + '. ' + list2[i]``.

    Args:
        list1: Sequence of strings supplying the left-hand parts.
        list2: Sequence of strings supplying the right-hand parts.

    Returns:
        A new list with one joined string per index.

    Raises:
        ValueError: If the two inputs differ in length.
    """
    if len(list1) == len(list2):
        # Walk both lists in lockstep and glue each pair with ". ".
        return [left + '. ' + right for left, right in zip(list1, list2)]
    raise ValueError("Both lists should have the same number of elements.")

def plot_x_y(x, y, x_label, y_label, save_path, **kwargs):
    """Plot ``y`` against ``x`` on the current pyplot axes and save the figure.

    The function draws on whatever axes pyplot currently considers active; it
    neither creates a new figure nor clears an existing one.

    Args:
        x: Values for the horizontal axis.
        y: Values for the vertical axis.
        x_label: Text for the x-axis label.
        y_label: Text for the y-axis label.
        save_path: Destination passed to ``plt.savefig``.
        **kwargs: Forwarded unchanged to the plot call (e.g. ``label``).
    """
    axes = plt.gca()
    axes.plot(x, y, **kwargs)
    axes.set_xlabel(x_label)
    axes.set_ylabel(y_label)
    axes.legend()
    plt.savefig(save_path)


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def count_param(module_list):
    """Return the total parameter count of the given modules, in millions.

    Args:
        module_list: Iterable of objects exposing ``parameters()``, each
            parameter exposing ``numel()``.

    Returns:
        The summed element count of every parameter divided by ``10**6``.
    """
    total = 0
    for module in module_list:
        for param in module.parameters():
            total += param.numel()
    return total / 10**6
    
def print_peak_memory(prefix, device):
    """Print the peak CUDA memory allocated on ``device``, for device 0 only.

    Nothing is printed when ``device`` does not compare equal to ``0``.
    The byte count is floor-divided by ``1e6``, so the printed figure is a
    float holding a whole number of megabytes.

    Args:
        prefix: Text placed before the memory figure.
        device: Device index; also passed to
            ``torch.cuda.max_memory_allocated``.
    """
    if device != 0:
        return
    peak_mb = torch.cuda.max_memory_allocated(device) // 1e6
    print(f"{prefix}: {peak_mb}MB ")

def anal_tensor(tensor, name):
    """Print the mean, std, min and max of ``tensor`` on one line.

    Args:
        tensor: Object exposing ``mean()``, ``std()``, ``min()`` and ``max()``,
            each returning something with ``item()``.
        name: Label printed in front of the statistics.
    """
    mean_value = tensor.mean().item()
    std_value = tensor.std().item()
    min_value = tensor.min().item()
    max_value = tensor.max().item()
    print(f" name: {name} mean: {mean_value}  std: {std_value}  min: {min_value}  max: {max_value}")


def split(lst, split_nbr):
    """Split ``lst`` into contiguous slices of near-equal length.

    The length is divided by ``split_nbr``; the first ``len(lst) % split_nbr``
    slices receive one extra element. When ``split_nbr`` exceeds ``len(lst)``
    the result holds ``len(lst)`` single-element slices (no empty slices are
    produced), and an empty ``lst`` yields an empty list.

    Args:
        lst: Sliceable sequence to divide.
        split_nbr: Requested number of slices; must be positive.

    Returns:
        A list of slices of ``lst`` in original order.

    Raises:
        ValueError: If ``split_nbr`` is negative. Without this check the
            slicing loop never terminates.
        ZeroDivisionError: If ``split_nbr`` is zero.
    """
    if split_nbr < 0:
        raise ValueError(f"split_nbr must be positive, got {split_nbr}.")

    total = len(lst)
    # divmod raises ZeroDivisionError for split_nbr == 0.
    base_size, extra = divmod(total, split_nbr)

    results = []
    start = 0
    while start < total:
        end = start + base_size
        if extra >= 1:
            # Hand out the remainder one element at a time to the first slices.
            end += 1
            extra -= 1
        results.append(lst[start:end])
        start = end
    return results

def chunk(iterable, chunk_size):
    """Yield successive lists of ``chunk_size`` items taken from ``iterable``.

    A final, shorter list is yielded when items are left over. Each yielded
    list is a fresh object.

    Args:
        iterable: Source of items.
        chunk_size: Number of items that completes a batch.

    Yields:
        Lists of consecutive items in their original order.
    """
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) != chunk_size:
            continue
        # Batch is full: emit it and start collecting a new one.
        yield batch
        batch = []
    if batch:
        yield batch

def image_concat_h(im1, im2):
    """Return a new RGB image with ``im2`` pasted to the right of ``im1``.

    The canvas is ``im1.width + im2.width`` wide; its height is taken from
    ``im1`` alone.

    Args:
        im1: Left image.
        im2: Right image.

    Returns:
        The newly created image.
    """
    canvas_size = (im1.width + im2.width, im1.height)
    canvas = Image.new('RGB', canvas_size)
    placements = ((im1, (0, 0)), (im2, (im1.width, 0)))
    for picture, corner in placements:
        canvas.paste(picture, corner)
    return canvas

def image_concat_v(im1, im2):
    """Return a new RGB image with ``im2`` pasted below ``im1``.

    The canvas is ``im1.height + im2.height`` tall; its width is taken from
    ``im1`` alone.

    Args:
        im1: Top image.
        im2: Bottom image.

    Returns:
        The newly created image.
    """
    canvas_size = (im1.width, im1.height + im2.height)
    canvas = Image.new('RGB', canvas_size)
    placements = ((im1, (0, 0)), (im2, (0, im1.height)))
    for picture, corner in placements:
        canvas.paste(picture, corner)
    return canvas