import argparse
import logging
from pathlib import Path

import torch
import yaml

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_device(config_device):
    """Auto-detect available device"""
    if config_device == "auto":
        if torch.cuda.is_available():
            device = "cuda"
            logger.info("CUDA available, using GPU")
        else:
            device = "cpu"
            logger.info("CUDA not available, using CPU")
    else:
        device = config_device
        logger.info(f"Using configured device: {device}")
    
    return device

def parse_args():
    parser = argparse.ArgumentParser(description="OmniAvatar-14B Inference")
    parser.add_argument("--config", type=str, required=True, help="Path to config file")
    parser.add_argument("--input_file", type=str, required=True, help="Path to input samples file")
    parser.add_argument("--guidance_scale", type=float, default=5.0, help="Guidance scale")
    parser.add_argument("--audio_scale", type=float, default=3.0, help="Audio guidance scale")
    parser.add_argument("--num_steps", type=int, default=30, help="Number of inference steps")
    parser.add_argument("--sp_size", type=int, default=1, help="Multi-GPU size")
    parser.add_argument("--tea_cache_l1_thresh", type=float, default=None, help="TeaCache threshold")
    return parser.parse_args()

def load_config(config_path):
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)
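
# A minimal config sketch (illustrative; the repo's real schema may carry more
# fields). The only keys this script actually reads are
# config["hardware"]["device"] and config["output"]["output_dir"]:
#
#   hardware:
#     device: auto          # "auto", "cuda", or "cpu"
#   output:
#     output_dir: outputs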

def process_input_file(input_file):
    """Parse input file with format: prompt@@image_path@@audio_path"""
    samples = []
    with open(input_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split('@@')
            if len(parts) >= 3:
                samples.append({
                    'prompt': parts[0],
                    'image_path': parts[1] if parts[1] else None,
                    'audio_path': parts[2]
                })
            else:
                logger.warning(f"Skipping malformed line (expected prompt@@image@@audio): {line}")
    return samples

def create_placeholder_video(output_path, duration=5.0, fps=24):
    """Create a simple placeholder video"""
    # Imported lazily so numpy/opencv-python are only required when a
    # placeholder video is actually generated
    import numpy as np
    import cv2
    
    logger.info(f"Creating placeholder video: {output_path}")
    
    # Video properties
    width, height = 480, 480
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    
    # Create video writer and fail loudly if the backend/codec is unavailable
    out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
    if not out.isOpened():
        raise RuntimeError(f"Failed to open video writer for {output_path}")
    
    # Generate frames
    total_frames = int(duration * fps)
    for frame_idx in range(total_frames):
        # Create a simple animated frame
        frame = np.zeros((height, width, 3), dtype=np.uint8)
        
        # Add some animation - moving circle
        center_x = int(width/2 + 100 * np.sin(2 * np.pi * frame_idx / 60))
        center_y = int(height/2 + 50 * np.cos(2 * np.pi * frame_idx / 60))
        
        # Draw circle
        cv2.circle(frame, (center_x, center_y), 30, (0, 255, 0), -1)
        
        # Add text
        text = f"Avatar Placeholder Frame {frame_idx + 1}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(frame, text, (10, 30), font, 0.5, (255, 255, 255), 1)
        
        out.write(frame)
    
    out.release()
    logger.info(f"βœ… Placeholder video created: {output_path}")

def main():
    args = parse_args()
    
    logger.info("πŸš€ Starting OmniAvatar-14B Inference")
    logger.info(f"Arguments: {args}")
    
    # Load configuration
    config = load_config(args.config)
    
    # Auto-detect device
    device = get_device(config["hardware"]["device"])
    config["hardware"]["device"] = device
    
    # Process input samples
    samples = process_input_file(args.input_file)
    logger.info(f"Processing {len(samples)} samples")
    
    if not samples:
        logger.error("No valid samples found in input file")
        return
    
    # Create output directory
    output_dir = Path(config['output']['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Process each sample
    for i, sample in enumerate(samples):
        logger.info(f"Processing sample {i+1}/{len(samples)}: {sample['prompt'][:50]}...")
        
        # For now, create a placeholder video; guidance_scale, audio_scale,
        # num_steps, sp_size and tea_cache_l1_thresh are parsed but unused
        # until real model inference is wired in
        output_filename = f"avatar_output_{i:03d}.mp4"
        output_path = output_dir / output_filename
        
        try:
            # Create placeholder video (in the future, this would be actual avatar generation)
            create_placeholder_video(output_path, duration=5.0, fps=24)
            
            logger.info(f"βœ… Sample {i+1} completed: {output_path}")
            
        except Exception as e:
            logger.error(f"❌ Error processing sample {i+1}: {e}")
    
    logger.info("πŸŽ‰ Inference completed!")
    logger.info("πŸ“ Note: Currently generating placeholder videos.")
    logger.info("πŸ”œ Future updates will include actual OmniAvatar model inference.")

if __name__ == "__main__":
    main()
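
# Example invocation (script and file names are illustrative):
#   python inference.py --config configs/inference.yaml \
#       --input_file examples/infer_samples.txt --num_steps 30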