yerang committed
Commit 7fb1673 · verified · 1 Parent(s): 73f874e

Update app.py

Files changed (1)
  1. app.py +72 -1
app.py CHANGED
@@ -20,13 +20,81 @@ import cv2
 from elevenlabs_utils import ElevenLabsPipeline
 from setup_environment import initialize_environment
 from src.utils.video import extract_audio
-from flux_dev import create_flux_tab
+#from flux_dev import create_flux_tab
+from diffusers import FluxPipeline
 
 # import gdown
 # folder_url = f"https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib"
 # gdown.download_folder(url=folder_url, output="pretrained_weights", quiet=False)
 
 
+
+#========================= # FLUX model loading setup
+flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+flux_pipe.enable_sequential_cpu_offload()
+flux_pipe.vae.enable_slicing()
+flux_pipe.vae.enable_tiling()
+flux_pipe.to(torch.float16)
+
+def generate_image(prompt, guidance_scale, width, height):
+    # Generate an image from the prompt
+    output_image = flux_pipe(
+        prompt=prompt,
+        guidance_scale=guidance_scale,
+        height=height,
+        width=width,
+        num_inference_steps=4,
+        max_sequence_length=256,
+    ).images[0]
+
+    # Create the results folder
+    result_folder = "result_flux"
+    os.makedirs(result_folder, exist_ok=True)
+
+    # Build the output filename
+    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+    #filename = f"{prompt.replace(' ', '_')}_{timestamp}.png"
+    filename = f"{'_'.join(prompt.split()[:3])}_{timestamp}.png"
+    output_path = os.path.join(result_folder, filename)
+
+    # Save the image
+    output_image.save(output_path)
+
+    return output_image, output_path  # return both outputs
+
+def flux_tab():  #image_input):  # receives image_input as an argument
+    with gr.Tab("FLUX Image Generation"):
+        with gr.Row():
+            with gr.Column():
+                # User input controls
+                prompt = gr.Textbox(label="Prompt", value="A cat holding a sign that says hello world")
+                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, value=3.5, step=0.1)
+                width = gr.Slider(label="Width", minimum=256, maximum=2048, value=512, step=64)
+                height = gr.Slider(label="Height", minimum=256, maximum=2048, value=512, step=64)
+
+            with gr.Column():
+                # Output image and download button
+                output_image = gr.Image(type="pil", label="Output")
+                download_button = gr.File(label="Download")
+                generate_button = gr.Button("Generate Image")
+                use_in_text2lipsync_button = gr.Button("Use this image in Text2Lipsync")  # new button
+
+        # Wire up the generate button
+        generate_button.click(
+            fn=generate_image,
+            inputs=[prompt, guidance_scale, width, height],
+            outputs=[output_image, download_button]
+        )
+
+        # Click event for the new button
+        use_in_text2lipsync_button.click(
+            fn=lambda img: img,  # simple lambda that passes the image through unchanged
+            inputs=[output_image],  # use the generated image as the input
+            outputs=[image_input]  # update image_input in the Text to LipSync tab
+        )
+
+#========================= # FLUX model loading setup
+
 initialize_environment()
 
 import sys
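For context, the new module-level block above follows the standard diffusers FLUX workflow. The sketch below is a minimal standalone version of that path (not part of the commit), assuming `torch` and `diffusers` are installed; the prompt and sampling values simply mirror the defaults used in the new Gradio tab.

```python
# Minimal standalone sketch of the FLUX.1-schnell path added in this hunk.
# The calls mirror the diff; the fixed prompt and output path are illustrative only.
import os
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload()  # keep weights on CPU, move submodules to GPU only when used
pipe.vae.enable_slicing()             # decode the batch one image at a time
pipe.vae.enable_tiling()              # decode large latents in tiles to reduce peak VRAM

image = pipe(
    prompt="A cat holding a sign that says hello world",
    guidance_scale=3.5,
    height=512,
    width=512,
    num_inference_steps=4,            # schnell is a few-step distilled model
    max_sequence_length=256,
).images[0]

os.makedirs("result_flux", exist_ok=True)
image.save(os.path.join("result_flux", "example.png"))
```

With sequential CPU offload enabled, the pipeline manages device placement itself, so no explicit `.to("cuda")` call is needed.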
 
@@ -86,6 +154,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Row():
         driving_video_path.render()
 
+    with gr.Row():
+        flux_tab()  # pass image_input to flux_tab
+
 
 
 
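As committed, `flux_tab()` takes no arguments (the `image_input` parameter is commented out), so the `outputs=[image_input]` binding only resolves if `image_input` already exists as a module-level Gradio component when `flux_tab()` runs. The commented-out signature points at an explicit-parameter variant; the sketch below shows that hypothetical form, with `image_input` standing in for the app's existing Text-to-LipSync component.

```python
import gradio as gr

def flux_tab(image_input: gr.Image):
    # Hypothetical parameterized form suggested by the commented-out signature.
    with gr.Tab("FLUX Image Generation"):
        output_image = gr.Image(type="pil", label="Output")
        use_in_text2lipsync_button = gr.Button("Use this image in Text2Lipsync")
        use_in_text2lipsync_button.click(
            fn=lambda img: img,     # forward the generated image unchanged
            inputs=[output_image],
            outputs=[image_input],  # the component passed in, not a global
        )

with gr.Blocks() as demo:
    image_input = gr.Image(label="Source Image")  # stand-in for the app's existing component
    with gr.Row():
        flux_tab(image_input)  # pass the target component explicitly
```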