ychenhq commited on
Commit
0bab51a
·
verified ·
1 Parent(s): 04fbff5

Initial test

Browse files
Files changed (1) hide show
  1. app.py +322 -4
app.py CHANGED
@@ -1,7 +1,325 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
1
+ import os
2
+ import sys
3
  import gradio as gr
4
+ import math
5
+ import matplotlib.pyplot as plt
6
+
7
+ import requests
8
+ import fileinput
9
+ import firebase_admin
10
+ from firebase_admin import credentials
11
+ from firebase_admin import firestore
12
+ import gradio as gr
13
+ import json
14
+ import math
15
+ import requests
16
+
17
+
18
+ basedir = "/content/drive/MyDrive/FYP/Code/VideoCrafter"
19
+ # if you change the variables here, remember to change the "name" in .sh file
20
+ vidOut = "results/10videos"
21
+ uvqOut = "results/modified_prompts_eval"
22
+ evalOut = "evaluation_results"
23
+ num_of_vid = 3
24
+ vid_length = 2
25
+ uvq_threshold = 3.8
26
+ fps = 24
27
+
28
+
29
+ # Generate the scores in csv files
30
+ def genScore():
31
+ for i in range(1, num_of_vid+1):
32
+ fileindex = f"{i:04d}"
33
+ os.system(
34
+ f'python3 ./uvq/uvq_main.py --input_files="{fileindex},2,{basedir}/{vidOut}/{fileindex}.mp4" --output_dir {uvqOut} --model_dir ./uvq/models'
35
+ )
36
+
37
+
38
+ def getScore(filename):
39
+ # MOS_score defines the output of the uvq score
40
+ lines = str(filename).split('\n')
41
+ last_line = lines[-1]
42
+ MOS_score = last_line.split(',')[-1]
43
+ MOS_score = MOS_score[:-2]
44
+
45
+ return MOS_score
46
+
47
+ # MOS_score defines the Mean Opinion Score of prediction, if the video's MOS exceeds the threshold then we directly use this video
48
+
49
+
50
+ def chooseBestVideo():
51
+ MOS_score_high = 0
52
+ preferred_output = ""
53
+ chosen_idx = 0
54
+
55
+ for i in range(1, num_of_vid+1):
56
+ '''We loop thru this current processed video'''
57
+ filedir = f"{i:04d}"
58
+ filename = f"{i:04d}_uvq.csv"
59
+ with open(os.path.join(basedir, uvqOut, filedir, filename), 'r') as file:
60
+ MOS = file.read().strip()
61
+
62
+ MOS_score = getScore(MOS)
63
+ print("Video Index:", f"{i:04d}", "Score:", MOS_score)
64
+
65
+ # if the MOS_score is higher than the previous video, we choose this video as our preferred video output
66
+ if float(MOS_score) > float(MOS_score_high) or float(MOS_score) > uvq_threshold:
67
+ MOS_score_high = MOS_score
68
+ preferred_output = filename
69
+ chosen_idx = i
70
+
71
+ if float(MOS_score) > uvq_threshold:
72
+ break
73
+ return chosen_idx
74
+ # print(MOS_score_high)
75
+ # print(preferred_output)
76
+
77
+
78
+ def extract_scores_from_json(json_path):
79
+ with open(json_path) as file:
80
+ data = json.load(file)
81
+
82
+ for key, value in data.items():
83
+ if isinstance(value, list) and len(value) > 1 and isinstance(value[0], float):
84
+ motion_score = value[0]
85
+
86
+ return motion_score
87
+
88
+
89
+ def VBench_eval(vid_filename):
90
+ # vid_filename: video filename without .mp4
91
+ os.system(
92
+ f'python VBench/evaluate.py --dimension "motion_smoothness" --videos_path {os.path.join(basedir, vidOut, vid_filename)}.mp4 --custom_input --output_filename {vid_filename}'
93
+ )
94
+ eval_file_path = os.path.join(
95
+ basedir, evalOut, f"{vid_filename}_eval_results.json")
96
+ motion_score = extract_scores_from_json(eval_file_path)
97
+
98
+ return motion_score
99
+
100
+
101
+ def interpolation(chosen_idx, fps):
102
+ vid_filename = f"{chosen_idx:04d}.mp4"
103
+ os.chdir("/content/drive/MyDrive/FYP/Code/VideoCrafter/ECCV2022-RIFE")
104
+ os.system(
105
+ f'python3 inference_video.py --exp=2 --video={os.path.join(basedir, vidOut, vid_filename)} --fps {fps}'
106
+ )
107
+ os.chdir("/content/drive/MyDrive/FYP/Code/VideoCrafter")
108
+ out_name = f"{chosen_idx:04d}_4X_{fps}fps.mp4"
109
+ return out_name
110
+
111
+ # call the GPT API here
112
+
113
+
114
+ def call_gpt_api(prompt, isSentence=False):
115
+ api_key = "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"
116
+
117
+ response = requests.post(
118
+ 'https://api.openai.com/v1/chat/completions',
119
+ headers={
120
+ 'Content-Type': 'application/json',
121
+ 'Authorization': f'Bearer {api_key}'
122
+ },
123
+ json={
124
+ 'messages': [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': prompt}],
125
+ 'model': 'gpt-3.5-turbo',
126
+ # 'prompt': prompt,
127
+ 'temperature': 0.4,
128
+ 'max_tokens': 200
129
+ })
130
+ response_json = response.json()
131
+ choices = response_json['choices']
132
+ contents = [choice['message']['content'] for choice in choices]
133
+ contents = [
134
+ sentence for sublist in contents for sentence in sublist.split('\n')]
135
+ # Remove the leading number and dot from each sentence
136
+ sentences = [content.lstrip('1234567890.- ') for content in contents]
137
+ if len(sentences) > 2 and isSentence:
138
+ sentences = sentences[1:]
139
+ return sentences
140
+
141
+
142
+ # Initialize Firebase Admin SDK
143
+ cred = credentials.Certificate(
144
+ f"{basedir}/final-year-project-443dd-df6f48af0796.json")
145
+ firebase_admin.initialize_app(cred)
146
+ # Initialize Firestore client
147
+ db = firestore.client()
148
+
149
+
150
+ def retrieve_user_feedback():
151
+ # Retrieve user feedback from Firestore
152
+ feedback_collection = db.collection("user_feedbacks")
153
+ feedback_docs = feedback_collection.get()
154
+
155
+ feedback_text = []
156
+ experience = []
157
+ for doc in feedback_docs:
158
+ data = doc.to_dict()
159
+ feedback_text.append(data.get('feedback_text', None))
160
+ experience.append(data.get('experience', None))
161
+
162
+ return feedback_text, experience
163
+
164
+
165
+ feedback_text, experience = retrieve_user_feedback()
166
+ # print("Feedback Text:", feedback_text)
167
+ # print("Experience:", experience)
168
+
169
+
170
+ def store_user_feedback(feedback_text, experience):
171
+ # Get a reference to the Firestore collection
172
+ feedback_collection = db.collection("user_feedbacks")
173
+
174
+ # Create a new document with feedback_text and experience fields
175
+ feedback_collection.add({
176
+ 'feedback_text': feedback_text,
177
+ 'experience': experience
178
+ })
179
+ return
180
+
181
+
182
+ t2v_examples = [
183
+ ['A tiger walks in the forest, photorealistic, 4k, high definition'],
184
+ ['an elephant is walking under the sea, 4K, high definition'],
185
+ ['an astronaut riding a horse in outer space'],
186
+ ['a monkey is playing a piano'],
187
+ ['A fire is burning on a candle'],
188
+ ['a horse is drinking in the river'],
189
+ ['Robot dancing in times square'],
190
+ ]
191
+
192
+
193
+ def generate_output(input_text, output_video_1, fps, examples):
194
+ def generate_output_fn(input_text, output_video_1, fps, examples):
195
+ if input_text == "":
196
+ return input_text, output_video_1, examples
197
+ output = call_gpt_api(
198
+ prompt=f"Generate 2 similar prompts and add some reasonable words to the given prompt and not change the meaning, each within 30 words: {input_text}", isSentence=True)
199
+ output.append(input_text)
200
+ with open(f"{basedir}/prompts/test_prompts.txt", 'w') as file:
201
+ for i, sentence in enumerate(output):
202
+ if i < len(output) - 1:
203
+ file.write(sentence + '\n')
204
+ else:
205
+ file.write(sentence)
206
+ os.system(
207
+ f'sh {os.path.join(basedir, "scripts", "run_text2video.sh")}')
208
+ # Connect the video output and return the video corresponding link
209
+ genScore()
210
+ chosen_idx = chooseBestVideo()
211
+ chosen_vid_path = interpolation(chosen_idx, fps)
212
+ chosen_vid_path = f"{basedir}/{vidOut}/{chosen_vid_path}"
213
+ # chosen_vid_path = "/content/drive/MyDrive/FYP/Code/VideoCrafter/results/cat/0002_4X_16fps.mp4"
214
+ output_video_1 = gr.Video(
215
+ value=chosen_vid_path, show_download_button=True)
216
+
217
+ examples_list = call_gpt_api(
218
+ prompt=f"Generate 5 similar prompts that makes a storyline coming after the given input, each within 10 words: {input_text}")
219
+ examples = []
220
+ for prompt in examples_list:
221
+ examples.append([prompt])
222
+ input_text = ""
223
+
224
+ return input_text, output_video_1, examples
225
+
226
+ return generate_output_fn(input_text, output_video_1, fps, examples)
227
+
228
+
229
+ def t2v_demo(result_dir='./tmp/'):
230
+ with gr.Blocks() as videocrafter_iface:
231
+ gr.Markdown("<div align='center'> <h2> VideoCraftXtend: AI-Enhanced Text-to-Video Generation with Extended Length and Enhanced Motion Smoothness </span> </h2> </div>")
232
+
233
+ # Initialize values for video length and fps
234
+ video_len_value = 5.0
235
+
236
+ def update_fps(video_len, fps):
237
+ fps_value = 80 / video_len
238
+ return f"<div justify-content: 'center'; text-align='center'> <h6> FPS (frames per second) : {int(fps_value)} </span> </h6> </div>"
239
+
240
+ def load_example(example_id):
241
+ return example_id[0]
242
+
243
+ def update_feedback(value, text):
244
+ labels = ['Positive', 'Neutral', 'Negative']
245
+ colors = ['#66c2a5', '#fc8d62', '#8da0cb']
246
+ if value != '':
247
+ store_user_feedback(value, text)
248
+ user_satisfaction.append(value)
249
+ value = ''
250
+ if text != '':
251
+ user_feedback.append(text)
252
+ text = ''
253
+ user_feedback, user_satisfaction = retrieve_user_feedback()
254
+ sizes = [user_satisfaction.count('Positive'), user_satisfaction.count(
255
+ 'Neutral'), user_satisfaction.count('Negative')]
256
+ plt.pie(sizes, labels=labels, autopct='%1.1f%%',
257
+ startangle=140, colors=colors)
258
+ plt.axis('equal')
259
+ return plt
260
+
261
+ with gr.Tab(label="Text2Video"):
262
+ with gr.Column():
263
+ with gr.Row():
264
+ with gr.Column():
265
+ input_text = gr.Text(
266
+ placeholder=t2v_examples[2], label='Please input your prompt here.')
267
+ with gr.Row():
268
+ examples = gr.Dataset(samples=t2v_examples, components=[
269
+ input_text], label='Sample prompts that can be used to form a storyline.')
270
+ with gr.Column():
271
+ gr.Markdown(
272
+ "<div align='center'> <h4> Modify video length and the corresponding fps will be shown on the right. </span> </h4> </div>")
273
+ with gr.Row():
274
+ video_len = gr.Slider(minimum=4.0, maximum=10.0, step=1, label='Video Length',
275
+ value=video_len_value, elem_id="video_len", interactive=True)
276
+ fps = gr.Markdown(
277
+ elem_id="fps", value=f"<div> <h6> FPS (frames per second) : 16</span> </h6> </div>")
278
+ send_btn = gr.Button("Send")
279
+ with gr.Column():
280
+ with gr.Tab(label='Result'):
281
+ with gr.Row():
282
+ output_video_1 = gr.Video(
283
+ value="/content/drive/MyDrive/FYP/Code/VideoCrafter/results/10videos/0009.mp4", show_download_button=True)
284
+
285
+ video_len.change(update_fps, inputs=[video_len, fps], outputs=fps)
286
+ # fps.change(update_video_len_slider, inputs = fps, outputs = video_len)
287
+
288
+ examples.click(load_example, inputs=[
289
+ examples], outputs=[input_text])
290
+ send_btn.click(
291
+ fn=generate_output,
292
+ inputs=[input_text, output_video_1, fps, examples],
293
+ outputs=[input_text, output_video_1, examples],
294
+ )
295
+
296
+ with gr.Tab(label="Feedback"):
297
+ with gr.Column():
298
+ with gr.Column():
299
+ with gr.Row():
300
+ feedback_value = gr.Radio(
301
+ ['Positive', 'Neutral', 'Negative'], label="How is your experience?")
302
+ feedback_text = gr.Textbox(
303
+ placeholder="Enter feedback here", label="Feedback Text")
304
+ with gr.Row():
305
+ cancel_btn = gr.Button("Clear")
306
+ submit_btn = gr.Button("Submit")
307
+ with gr.Row():
308
+ pie_chart = gr.Plot(value=update_feedback(
309
+ '', ''), label="Feedback Pie Chart")
310
+ with gr.Column():
311
+ gr.Markdown(
312
+ "<div align='center'> <h4> Feedbacks from users: </span> </h4> </div>")
313
+ feedback_text_display = [gr.Markdown(
314
+ feedback, label="User Feedback") for feedback in retrieve_user_feedback()[0]]
315
+ submit_btn.click(fn=update_feedback, inputs=[
316
+ feedback_value, feedback_text], outputs=[pie_chart])
317
+
318
+ return videocrafter_iface
319
 
 
 
320
 
321
+ if __name__ == "__main__":
322
+ result_dir = os.path.join('./', 'results')
323
+ t2v_iface = t2v_demo(result_dir)
324
+ t2v_iface.queue(max_size=10)
325
+ t2v_iface.launch(debug=True)