hpoghos committed
Commit 81022ab
1 Parent(s): c3c0523
app.py CHANGED
@@ -1,7 +1,189 @@
  import gradio as gr

- def greet(name):
- return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ # General
+ import os
+ from os.path import join as opj
+ import argparse
+ import datetime
+ from pathlib import Path
+ import torch
  import gradio as gr
+ import tempfile
+ import yaml
+ from t2v_enhanced.model.video_ldm import VideoLDM

+ # Utilities
+ from t2v_enhanced.inference_utils import *
+ from t2v_enhanced.model_init import *
+ from t2v_enhanced.model_func import *

+
+ on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--public_access', action='store_true', default=True)
+ parser.add_argument('--where_to_log', type=str, default="gradio_output")
+ parser.add_argument('--device', type=str, default="cuda")
+ args = parser.parse_args()
+
+
+ Path(args.where_to_log).mkdir(parents=True, exist_ok=True)
+ result_fol = Path(args.where_to_log).absolute()
+ device = args.device
+
+
+ # --------------------------
+ # ----- Configurations -----
+ # --------------------------
+ ckpt_file_streaming_t2v = Path("t2v_enhanced/checkpoints/streaming_t2v.ckpt").absolute()
+ cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}
+
+
+ # --------------------------
+ # ----- Initialization -----
+ # --------------------------
+ ms_model = init_modelscope(device)
+ # zs_model = init_zeroscope(device)
+ stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
+ msxl_model = init_v2v_model(cfg_v2v)
+
+ inference_generator = torch.Generator(device="cuda")
+
+
+ # -------------------------
+ # ----- Functionality -----
+ # -------------------------
+ def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance, where_to_log=result_fol):
+ now = datetime.datetime.now()
+ name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
+
+ if num_frames == [] or num_frames is None:
+ num_frames = 56
+ else:
+ num_frames = int(num_frames.split(" ")[0])
+
+ n_autoreg_gen = num_frames/8-8
+
+ inference_generator.manual_seed(seed)
+ short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
+ stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, name, stream_cli, stream_model)
+ video_path = opj(where_to_log, name+".mp4")
+ return video_path
+
+ def enhance(prompt, input_to_enhance):
+ encoded_video = video2video(prompt, input_to_enhance, result_fol, cfg_v2v, msxl_model)
+ return encoded_video
+
+
+ # --------------------------
+ # ----- Gradio-Demo UI -----
+ # --------------------------
+ with gr.Blocks() as demo:
+ gr.HTML(
+ """
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
+ <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
+ <a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">StreamingT2V</a>
+ </h1>
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
+ Roberto Henschel<sup>1*</sup>, Levon Khachatryan<sup>1*</sup>, Daniil Hayrapetyan<sup>1*</sup>, Hayk Poghosyan<sup>1</sup>, Vahram Tadevosyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>, Humphrey Shi<sup>1,3</sup>
+ </h2>
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
+ <sup>1</sup>Picsart AI Research (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>SHI Labs @ Georgia Tech, Oregon & UIUC
+ </h2>
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
+ *Equal Contribution
+ </h2>
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
+ [<a href="https://arxiv.org/abs/2403.14773" style="color:blue;">arXiv</a>]
+ [<a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">GitHub</a>]
+ </h2>
+ <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
+ <b>StreamingT2V</b> is an advanced autoregressive technique that enables the creation of long videos featuring rich motion dynamics without any stagnation.
+ It ensures temporal consistency throughout the video, aligns closely with the descriptive text, and maintains high frame-level image quality.
+ Our demonstrations include successful examples of videos up to <b>1200 frames, spanning 2 minutes</b>, and can be extended for even longer durations.
+ Importantly, the effectiveness of StreamingT2V is not limited by the specific Text2Video model used, indicating that improvements in base models could yield even higher-quality videos.
+ </h2>
+ </div>
+ """)
+
+ if on_huggingspace:
+ gr.HTML("""
+ <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
+ <br/>
+ <a href="https://huggingface.co/spaces/PAIR/StreamingT2V?duplicate=true">
+ <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+ </p>""")
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ num_frames = gr.Dropdown(["24", "32", "40", "48", "56", "80 - only on local", "240 - only on local", "600 - only on local", "1200 - only on local", "10000 - only on local"], label="Number of Video Frames: Default is 56", info="For >80 frames use local workstation!")
+ with gr.Row():
+ prompt_stage1 = gr.Textbox(label='Textual Prompt', placeholder="Ex: Dog running on the street.")
+ with gr.Row():
+ image_stage1 = gr.Image(label='Image Prompt (only required for I2V base models)', show_label=True, scale=1, show_download_button=True)
+ with gr.Column():
+ video_stage1 = gr.Video(label='Long Video Preview', show_label=True, interactive=False, scale=2, show_download_button=True)
+ with gr.Row():
+ run_button_stage1 = gr.Button("Long Video Preview Generation")
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Accordion('Advanced options', open=False):
+ model_name_stage1 = gr.Dropdown(
+ choices=["T2V: ModelScope", "T2V: ZeroScope", "I2V: AnimateDiff"],
+ label="Base Model. Default is ModelScope",
+ info="Currently supports only ModelScope. We will add more options later!",
+ )
+ model_name_stage2 = gr.Dropdown(
+ choices=["ModelScope-XL", "Another", "Another"],
+ label="Enhancement Model. Default is ModelScope-XL",
+ info="Currently supports only ModelScope-XL. We will add more options later!",
+ )
+ n_prompt = gr.Textbox(label="Optional Negative Prompt", value='')
+ seed = gr.Slider(label='Seed', minimum=0, maximum=65536, value=33,step=1,)
+
+ t = gr.Slider(label="Timesteps", minimum=0, maximum=100, value=50, step=1,)
+ image_guidance = gr.Slider(label='Image guidance scale', minimum=1, maximum=10, value=9.0, step=1.0)
+
+ with gr.Column():
+ with gr.Row():
+ video_stage2 = gr.Video(label='Enhanced Long Video', show_label=True, interactive=False, height=473, show_download_button=True)
+ with gr.Row():
+ run_button_stage2 = gr.Button("Long Video Enhancement")
+ '''
+ '''
+ gr.HTML(
+ """
+ <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
+ <b>Version: v1.0</b>
+ </h3>
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
+ <b>Caution</b>:
+ We would like to raise users' awareness of this demo's potential issues and concerns.
+ Like previous large foundation models, StreamingT2V can be problematic in some cases; in particular, because it builds on the pretrained ModelScope model, it can inherit that model's imperfections.
+ So far, we keep all features available for research testing, both to show the great potential of the StreamingT2V framework and to collect important feedback for improving the model in the future.
+ We welcome researchers and users to report issues via the HuggingFace community discussion feature or by emailing the authors.
+ </h3>
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
+ <b>Biases and content acknowledgement</b>:
+ Beware that StreamingT2V may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
+ StreamingT2V in this demo is meant only for research purposes.
+ </h3>
+ </div>
+ """)
+
+ inputs_t2v = [prompt_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance]
+ run_button_stage1.click(fn=generate, inputs=inputs_t2v, outputs=video_stage1,)
+
+ inputs_v2v = [prompt_stage1, video_stage1]
+ run_button_stage2.click(fn=enhance, inputs=inputs_v2v, outputs=video_stage2,)
+
+
+ if on_huggingspace:
+ demo.queue(max_size=20)
+ demo.launch(debug=True)
+ else:
+ _, _, link = demo.queue(api_open=False).launch(share=args.public_access)
+ print(link)
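For orientation, the sketch below restates the frame-count handling inside `generate` above: the dropdown supplies strings such as "56" or "80 - only on local", the leading integer is extracted, and the result drives the number of autoregressive generation steps passed to `stream_long_gen`. The helper name `parse_num_frames` is hypothetical; the arithmetic (`num_frames / 8 - 8`, which yields a float) is copied from the diff.

```python
# Hedged sketch of the dropdown handling in generate(); parse_num_frames is a
# hypothetical helper name, the arithmetic mirrors the committed code.
def parse_num_frames(choice, default=56):
    # Dropdown values look like "56" or "80 - only on local"; take the leading integer.
    if not choice:
        return default
    return int(str(choice).split(" ")[0])

num_frames = parse_num_frames("80 - only on local")  # -> 80
n_autoreg_gen = num_frames / 8 - 8                   # -> 2.0, passed to stream_long_gen()
```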
requirements.txt CHANGED
@@ -22,16 +22,17 @@ scikit-image==0.20.0
  scikit-learn==1.2.2
  scipy==1.9.1
  seaborn==0.12.2
- -e .
  torch==2.0.0
  torchdata==0.6.0
  torchvision==0.15.1
  tqdm==4.65.0
  xformers==0.0.19
  open-clip-torch==2.24.0
- jsonargparse==4.20.1
  fairscale==0.4.13
  rotary-embedding-torch==0.5.3
  easydict==1.13
  torchsde==0.2.6
- imageio[ffmpeg]==2.25.0

  scikit-learn==1.2.2
  scipy==1.9.1
  seaborn==0.12.2
  torch==2.0.0
  torchdata==0.6.0
  torchvision==0.15.1
+ modelscope==1.13.3
  tqdm==4.65.0
  xformers==0.0.19
  open-clip-torch==2.24.0
+ jsonargparse[signatures]==4.27.7
  fairscale==0.4.13
  rotary-embedding-torch==0.5.3
  easydict==1.13
  torchsde==0.2.6
+ imageio[ffmpeg]==2.25.0
+ kornia==0.7.2
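As a quick, optional sanity check that the environment matches the updated pins, the newly added packages can be imported and their reported versions printed. The snippet is not part of the commit; package names are taken from requirements.txt above.

```python
# Hypothetical sanity check: confirm the dependencies pinned above resolve in
# the current environment. Nothing here is part of the commit itself.
import importlib

for pkg in ("modelscope", "jsonargparse", "imageio", "kornia"):
    module = importlib.import_module(pkg)
    print(pkg, getattr(module, "__version__", "unknown"))
```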
t2v_enhanced/gradio_demo.py DELETED
@@ -1,189 +0,0 @@
- # General
- import os
- from os.path import join as opj
- import argparse
- import datetime
- from pathlib import Path
- import torch
- import gradio as gr
- import tempfile
- import yaml
- from t2v_enhanced.model.video_ldm import VideoLDM
-
- # Utilities
- from inference_utils import *
- from model_init import *
- from model_func import *
-
-
- on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
- parser = argparse.ArgumentParser()
- parser.add_argument('--public_access', action='store_true', default=True)
- parser.add_argument('--where_to_log', type=str, default="gradio_output")
- parser.add_argument('--device', type=str, default="cuda")
- args = parser.parse_args()
-
-
- Path(args.where_to_log).mkdir(parents=True, exist_ok=True)
- result_fol = Path(args.where_to_log).absolute()
- device = args.device
-
-
- # --------------------------
- # ----- Configurations -----
- # --------------------------
- ckpt_file_streaming_t2v = Path("checkpoints/streaming_t2v.ckpt").absolute()
- cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}
-
-
- # --------------------------
- # ----- Initialization -----
- # --------------------------
- ms_model = init_modelscope(device)
- # zs_model = init_zeroscope(device)
- stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
- msxl_model = init_v2v_model(cfg_v2v)
-
- inference_generator = torch.Generator(device="cuda")
-
-
- # -------------------------
- # ----- Functionality -----
- # -------------------------
- def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance, where_to_log=result_fol):
- now = datetime.datetime.now()
- name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
-
- if num_frames == [] or num_frames is None:
- num_frames = 56
- else:
- num_frames = int(num_frames.split(" ")[0])
-
- n_autoreg_gen = num_frames/8-8
-
- inference_generator.manual_seed(seed)
- short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
- stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, name, stream_cli, stream_model)
- video_path = opj(where_to_log, name+".mp4")
- return video_path
-
- def enhance(prompt, input_to_enhance):
- encoded_video = video2video(prompt, input_to_enhance, result_fol, cfg_v2v, msxl_model)
- return encoded_video
-
-
- # --------------------------
- # ----- Gradio-Demo UI -----
- # --------------------------
- with gr.Blocks() as demo:
- gr.HTML(
- """
- <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
- <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
- <a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">StreamingT2V</a>
- </h1>
- <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
- Roberto Henschel<sup>1*</sup>, Levon Khachatryan<sup>1*</sup>, Daniil Hayrapetyan<sup>1*</sup>, Hayk Poghosyan<sup>1</sup>, Vahram Tadevosyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>, Humphrey Shi<sup>1,3</sup>
- </h2>
- <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
- <sup>1</sup>Picsart AI Resarch (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>SHI Labs @ Georgia Tech, Oregon & UIUC
- </h2>
- <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
- *Equal Contribution
- </h2>
- <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
- [<a href="https://arxiv.org/abs/2403.14773" style="color:blue;">arXiv</a>]
- [<a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">GitHub</a>]
- </h2>
- <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
- <b>StreamingT2V</b> is an advanced autoregressive technique that enables the creation of long videos featuring rich motion dynamics without any stagnation.
- It ensures temporal consistency throughout the video, aligns closely with the descriptive text, and maintains high frame-level image quality.
- Our demonstrations include successful examples of videos up to <b>1200 frames, spanning 2 minutes</b>, and can be extended for even longer durations.
- Importantly, the effectiveness of StreamingT2V is not limited by the specific Text2Video model used, indicating that improvements in base models could yield even higher-quality videos.
- </h2>
- </div>
- """)
-
- if on_huggingspace:
- gr.HTML("""
- <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
- <br/>
- <a href="https://huggingface.co/spaces/PAIR/StreamingT2V?duplicate=true">
- <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
- </p>""")
-
- with gr.Row():
- with gr.Column():
- with gr.Row():
- with gr.Column():
- with gr.Row():
- num_frames = gr.Dropdown(["24", "32", "40", "48", "56", "80 - only on local", "240 - only on local", "600 - only on local", "1200 - only on local", "10000 - only on local"], label="Number of Video Frames: Default is 56", info="For >80 frames use local workstation!")
- with gr.Row():
- prompt_stage1 = gr.Textbox(label='Textual Prompt', placeholder="Ex: Dog running on the street.")
- with gr.Row():
- image_stage1 = gr.Image(label='Image Prompt (only required for I2V base models)', show_label=True, scale=1, show_download_button=True)
- with gr.Column():
- video_stage1 = gr.Video(label='Long Video Preview', show_label=True, interactive=False, scale=2, show_download_button=True)
- with gr.Row():
- run_button_stage1 = gr.Button("Long Video Preview Generation")
-
- with gr.Row():
- with gr.Column():
- with gr.Accordion('Advanced options', open=False):
- model_name_stage1 = gr.Dropdown(
- choices=["T2V: ModelScope", "T2V: ZeroScope", "I2V: AnimateDiff"],
- label="Base Model. Default is ModelScope",
- info="Currently supports only ModelScope. We will add more options later!",
- )
- model_name_stage2 = gr.Dropdown(
- choices=["ModelScope-XL", "Another", "Another"],
- label="Enhancement Model. Default is ModelScope-XL",
- info="Currently supports only ModelScope-XL. We will add more options later!",
- )
- n_prompt = gr.Textbox(label="Optional Negative Prompt", value='')
- seed = gr.Slider(label='Seed', minimum=0, maximum=65536, value=33,step=1,)
-
- t = gr.Slider(label="Timesteps", minimum=0, maximum=100, value=50, step=1,)
- image_guidance = gr.Slider(label='Image guidance scale', minimum=1, maximum=10, value=9.0, step=1.0)
-
- with gr.Column():
- with gr.Row():
- video_stage2 = gr.Video(label='Enhanced Long Video', show_label=True, interactive=False, height=473, show_download_button=True)
- with gr.Row():
- run_button_stage2 = gr.Button("Long Video Enhancement")
- '''
- '''
- gr.HTML(
- """
- <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
- <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
- <b>Version: v1.0</b>
- </h3>
- <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
- <b>Caution</b>:
- We would like the raise the awareness of users of this demo of its potential issues and concerns.
- Like previous large foundation models, StreamingT2V could be problematic in some cases, partially we use pretrained ModelScope, therefore StreamingT2V can Inherit Its Imperfections.
- So far, we keep all features available for research testing both to show the great potential of the StreamingT2V framework and to collect important feedback to improve the model in the future.
- We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
- </h3>
- <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
- <b>Biases and content acknowledgement</b>:
- Beware that StreamingT2V may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
- StreamingT2V in this demo is meant only for research purposes.
- </h3>
- </div>
- """)
-
- inputs_t2v = [prompt_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance]
- run_button_stage1.click(fn=generate, inputs=inputs_t2v, outputs=video_stage1,)
-
- inputs_v2v = [prompt_stage1, video_stage1]
- run_button_stage2.click(fn=enhance, inputs=inputs_v2v, outputs=video_stage2,)
-
-
- if on_huggingspace:
- demo.queue(max_size=20)
- demo.launch(debug=True)
- else:
- _, _, link = demo.queue(api_open=False).launch(share=args.public_access)
- print(link)
t2v_enhanced/inference.py CHANGED
@@ -11,9 +11,9 @@ import yaml
  from t2v_enhanced.model.video_ldm import VideoLDM

  # Utilities
- from inference_utils import *
- from model_init import *
- from model_func import *


  if __name__ == "__main__":

  from t2v_enhanced.model.video_ldm import VideoLDM

  # Utilities
+ from t2v_enhanced.inference_utils import *
+ from t2v_enhanced.model_init import *
+ from t2v_enhanced.model_func import *


  if __name__ == "__main__":
t2v_enhanced/model/video_ldm.py CHANGED
@@ -8,7 +8,7 @@ from diffusers.utils.import_utils import is_xformers_available
  from einops import rearrange, repeat

  from transformers import CLIPTextModel, CLIPTokenizer
- from utils.video_utils import ResultProcessor, save_videos_grid, video_naming

  from t2v_enhanced.model import pl_module_params_controlnet


  from einops import rearrange, repeat

  from transformers import CLIPTextModel, CLIPTokenizer
+ from t2v_enhanced.utils.video_utils import ResultProcessor, save_videos_grid, video_naming

  from t2v_enhanced.model import pl_module_params_controlnet

t2v_enhanced/model_func.py CHANGED
@@ -6,7 +6,7 @@ import torch
  from einops import rearrange, repeat

  # Utilities
- from inference_utils import *

  from modelscope.outputs import OutputKeys
  import imageio

  from einops import rearrange, repeat

  # Utilities
+ from t2v_enhanced.inference_utils import *

  from modelscope.outputs import OutputKeys
  import imageio
t2v_enhanced/model_init.py CHANGED
@@ -13,8 +13,8 @@ from diffusers import StableVideoDiffusionPipeline, AutoPipelineForText2Image
  import tempfile
  import yaml
  from t2v_enhanced.model.video_ldm import VideoLDM
- from model.callbacks import SaveConfigCallback
- from inference_utils import legacy_transformation, remove_value, CustomCLI

  # For Stage-3
  from modelscope.pipelines import pipeline
@@ -67,7 +67,7 @@ def init_svd(device="cuda"):

  # Initialize StreamingT2V model.
  def init_streamingt2v_model(ckpt_file, result_fol):
- config_file = "configs/text_to_video/config.yaml"
  sys.argv = sys.argv[:1]
  with tempfile.TemporaryDirectory() as tmpdirname:
  storage_fol = Path(tmpdirname)
@@ -86,7 +86,7 @@ def init_streamingt2v_model(ckpt_file, result_fol):
  sys.argv.append("--result_fol")
  sys.argv.append(result_fol.as_posix())
  sys.argv.append("--config")
- sys.argv.append("configs/inference/inference_long_video.yaml")
  sys.argv.append("--data.prompt_cfg.type=prompt")
  sys.argv.append(f"--data.prompt_cfg.content='test prompt for initialization'")
  sys.argv.append("--trainer.devices=1")

  import tempfile
  import yaml
  from t2v_enhanced.model.video_ldm import VideoLDM
+ from t2v_enhanced.model.callbacks import SaveConfigCallback
+ from t2v_enhanced.inference_utils import legacy_transformation, remove_value, CustomCLI

  # For Stage-3
  from modelscope.pipelines import pipeline


  # Initialize StreamingT2V model.
  def init_streamingt2v_model(ckpt_file, result_fol):
+ config_file = "t2v_enhanced/configs/text_to_video/config.yaml"
  sys.argv = sys.argv[:1]
  with tempfile.TemporaryDirectory() as tmpdirname:
  storage_fol = Path(tmpdirname)

  sys.argv.append("--result_fol")
  sys.argv.append(result_fol.as_posix())
  sys.argv.append("--config")
+ sys.argv.append("t2v_enhanced/configs/inference/inference_long_video.yaml")
  sys.argv.append("--data.prompt_cfg.type=prompt")
  sys.argv.append(f"--data.prompt_cfg.content='test prompt for initialization'")
  sys.argv.append("--trainer.devices=1")