File size: 2,183 Bytes
492f6af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Utility functions"""
import importlib
import random
import re
import torch
import numpy as np
from PIL import Image

    
def normalize(image, rescale=True):
    """Shift image values into the [-1, 1] range.

    Args:
        image: A tensor-like object. When ``rescale`` is True it must provide
            ``.float()`` (e.g. a ``torch.Tensor``) and is expected to hold
            values in [0, 255].
        rescale: If True, cast to float and divide by 255 first so values lie
            in [0, 1]. If False, ``image`` is used as-is and should already be
            in [0, 1].

    Returns:
        ``2 * x - 1`` where ``x`` is the (optionally rescaled) image.
    """
    unit_range = image.float() / 255.0 if rescale else image
    return 2 * unit_range - 1



def process_caption(caption):
    """Trim a trailing sentence fragment and drop repeated sentences.

    If the caption does not end with a sentence terminator (``.``, ``!`` or
    ``?``), everything after the last terminator is removed. A caption with no
    terminator at all is left untouched. The caption is then split into
    sentences at the whitespace that follows a terminator. Exact duplicate
    sentences are removed, keeping the first occurrence and the original order.

    Args:
        caption: A string containing the caption text.

    Returns:
        processed_caption: The cleaned caption, with sentences joined by a
        single space.
    """
    terminators = '.!?'
    # Use the same terminators as the split below, so a complete sentence
    # ending in '!' or '?' is not discarded by searching only for '.'.
    if not caption.endswith(tuple(terminators)):
        last_end = max(caption.rfind(ch) for ch in terminators)
        if last_end != -1:
            caption = caption[:last_end + 1]

    sentences = re.split(r'(?<=[.!?])\s+', caption)

    # dict.fromkeys de-duplicates in O(n) while preserving first-seen order.
    unique_sentences = dict.fromkeys(s for s in sentences if s)
    processed_caption = ' '.join(unique_sentences)
    return processed_caption


def initiate_time_steps(step, total_timestep, batch_size, config):
    """A helper function to initiate time steps for the diffusion model.

    Three modes are selected by ``config`` flags, checked in this order:

    * ``config.rand_timestep_equal_int``: one random offset in
      ``[0, total_timestep // batch_size)`` is drawn with Python's ``random``.
      The batch then gets evenly spaced timesteps ``offset, offset + k,
      offset + 2k, ...`` with ``k = total_timestep // batch_size``, so every
      value lies in ``[0, total_timestep)``.
    * ``config.random_timestep_per_iteration``: an independent uniform timestep
      in ``[0, total_timestep)`` is drawn for each element with ``torch.randint``.
    * otherwise: every element gets the constant ``step``.

    Args:
        step: An integer of the constant step (used only by the fallback mode)
        total_timestep: An integer of the total timesteps of the diffusion model
        batch_size: An integer of the batch size
        config: A config object exposing ``rand_timestep_equal_int`` and
            ``random_timestep_per_iteration``

    Returns:
        timesteps: A long tensor of shape [batch_size,] of the time steps

    Raises:
        ValueError: In ``rand_timestep_equal_int`` mode when
            ``total_timestep < batch_size`` (the spacing would be zero).
    """
    if config.rand_timestep_equal_int:
        interval_val = total_timestep // batch_size
        if interval_val < 1:
            raise ValueError(
                f"total_timestep ({total_timestep}) must be >= batch_size "
                f"({batch_size}) when rand_timestep_equal_int is set"
            )
        start_point = random.randint(0, interval_val - 1)
        # Slice to batch_size: when total_timestep is not a multiple of
        # batch_size the range would otherwise yield extra elements. The last
        # kept value is at most interval_val * batch_size - 1 < total_timestep.
        timesteps = torch.tensor(
            list(range(start_point, total_timestep, interval_val)[:batch_size])
        ).long()
        return timesteps
    if config.random_timestep_per_iteration:
        # Independent random timestep for each element (marked as the default
        # mode in the original code).
        return torch.randint(0, total_timestep, (batch_size,)).long()
    # Fixed mode: the same caller-supplied step for every element.
    return torch.tensor([step] * batch_size).long()