import copy
from collections import OrderedDict

import numpy as np
import torch
from scipy.ndimage import gaussian_filter1d
from transformers import PreTrainedModel

from in2in.models.in2in import in2IN
from in2in.utils.configs import get_config
from in2in.utils.preprocess import MotionNormalizer

from .config import in2INConfig

class in2INModel(PreTrainedModel):
    """Hugging Face `PreTrainedModel` wrapper around the in2IN interaction motion generator."""

    config_class = in2INConfig

    def __init__(self, config):
        super().__init__(config)
        self.mode = config.MODE
        self.model = in2IN(config, mode=config.MODE)
        # Used to undo the dataset normalization applied during training.
        self.normalizer = MotionNormalizer()

    def forward(self, prompt_interaction, prompt_individual1, prompt_individual2):
        self.model.eval()
        batch = OrderedDict({})

        batch["motion_lens"] = torch.zeros(1, 1).long()
        batch["prompt_interaction"] = prompt_interaction

        # Per-person prompts are only attached when the model is not in "individual" mode.
        if self.mode != "individual":
            batch["prompt_individual1"] = prompt_individual1
            batch["prompt_individual2"] = prompt_individual2

        # Generate a fixed-length sequence of 210 frames.
        window_size = 210
        motion_output = self.generate_loop(batch, window_size)
        return motion_output

    def generate_loop(self, batch, window_size):
        prompt_interaction = batch["prompt_interaction"]

        if self.mode != "individual":
            prompt_individual1 = batch["prompt_individual1"]
            prompt_individual2 = batch["prompt_individual2"]

        # Work on a copy so the caller's batch is left untouched.
        batch = copy.deepcopy(batch)
        batch["motion_lens"][:] = window_size

        # Wrap the prompts in single-element lists under the keys forward_test consumes.
        batch["text"] = [prompt_interaction]
        if self.mode != "individual":
            batch["text_individual1"] = [prompt_individual1]
            batch["text_individual2"] = [prompt_individual2]

        # Sample the motion, split the flat output into the two persons, and denormalize.
        batch = self.model.forward_test(batch)
        motion_output_both = batch["output"][0].reshape(batch["output"][0].shape[0], 2, -1)
        motion_output_both = self.normalizer.backward(motion_output_both.cpu().detach().numpy())

        # Keep the 3D joint positions (22 joints x 3 coords) and lightly smooth them over time.
        sequences = [[], []]
        for j in range(2):
            motion_output = motion_output_both[:, j]
            joints3d = motion_output[:, :22 * 3].reshape(-1, 22, 3)
            joints3d = gaussian_filter1d(joints3d, 1, axis=0, mode="nearest")
            sequences[j].append(joints3d)

        sequences[0] = np.concatenate(sequences[0], axis=0)
        sequences[1] = np.concatenate(sequences[1], axis=0)
        return sequences