Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,75 +1,66 @@ | |
| 1 | 
            -
            import os
         | 
| 2 | 
            -
            import sys
         | 
| 3 | 
             
            import gradio as gr
         | 
|  | |
| 4 | 
             
            import tempfile
         | 
| 5 | 
             
            from huggingface_hub import snapshot_download
         | 
|  | |
| 6 | 
             
            import spaces
         | 
| 7 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 8 |  | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
            LORA_PATH = os.path.join(LORA_DIR, "pusa_v1.pt")
         | 
| 14 | 
            -
             | 
| 15 | 
            -
            @spaces.GPU
         | 
| 16 | 
            -
            def generate_video(prompt, lora_upload):
         | 
| 17 | 
            -
                # Download Wan2.1 model only if missing
         | 
| 18 | 
            -
                if not os.path.exists(WAN_MODEL_DIR):
         | 
| 19 | 
             
                    snapshot_download(
         | 
| 20 | 
            -
                        repo_id= | 
| 21 | 
            -
                         | 
| 22 | 
            -
                         | 
|  | |
| 23 | 
             
                        local_dir_use_symlinks=False,
         | 
| 24 | 
            -
                        resume_download=True,
         | 
| 25 | 
             
                    )
         | 
|  | |
| 26 |  | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
                 | 
| 31 | 
            -
                    if not os.path.exists(LORA_PATH):
         | 
| 32 | 
            -
                        os.makedirs(LORA_DIR, exist_ok=True)
         | 
| 33 | 
            -
                        snapshot_download(
         | 
| 34 | 
            -
                            repo_id="RaphaelLiu/PusaV1",
         | 
| 35 | 
            -
                            allow_patterns=["PusaV1/pusa_v1.pt.part*"],
         | 
| 36 | 
            -
                            local_dir=LORA_DIR,
         | 
| 37 | 
            -
                            local_dir_use_symlinks=False,
         | 
| 38 | 
            -
                        )
         | 
| 39 | 
            -
                        # Stitch parts
         | 
| 40 | 
            -
                        part_files = sorted(
         | 
| 41 | 
            -
                            f for f in os.listdir(LORA_DIR) if f.startswith("pusa_v1.pt.part")
         | 
| 42 | 
            -
                        )
         | 
| 43 | 
            -
                        with open(LORA_PATH, "wb") as wfd:
         | 
| 44 | 
            -
                            for part in part_files:
         | 
| 45 | 
            -
                                with open(os.path.join(LORA_DIR, part), "rb") as fd:
         | 
| 46 | 
            -
                                    wfd.write(fd.read())
         | 
| 47 |  | 
| 48 | 
            -
             | 
|  | |
|  | |
| 49 |  | 
| 50 | 
            -
                #  | 
| 51 | 
            -
                 | 
| 52 | 
            -
                 | 
| 53 | 
            -
                pipe.set_lora_adapters(lora_path)
         | 
| 54 |  | 
| 55 | 
            -
                #  | 
| 56 | 
            -
                result | 
| 57 |  | 
| 58 | 
            -
                # Save video | 
| 59 | 
             
                tmp_dir = tempfile.mkdtemp()
         | 
| 60 | 
            -
                 | 
| 61 | 
            -
                save_video(result.frames,  | 
| 62 | 
            -
             | 
| 63 | 
            -
                return video_path
         | 
| 64 |  | 
|  | |
| 65 |  | 
|  | |
| 66 | 
             
            with gr.Blocks() as demo:
         | 
| 67 | 
            -
                gr.Markdown(" | 
| 68 | 
            -
             | 
| 69 | 
            -
                 | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 72 |  | 
| 73 | 
            -
                generate_btn.click(fn=generate_video, inputs= | 
| 74 |  | 
| 75 | 
             
            demo.launch()
         | 
|  | |
|  | |
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
             
            import tempfile
         | 
| 4 | 
             
            from huggingface_hub import snapshot_download
         | 
| 5 | 
            +
            from diffsynth import ModelManager, WanVideoPusaPipeline, save_video
         | 
| 6 | 
             
            import spaces
         | 
| 7 |  | 
| 8 | 
            +
            # Constants
         | 
| 9 | 
            +
            WAN_SUBFOLDER = "Wan2.1-T2V-14B"
         | 
| 10 | 
            +
            MODEL_REPO_ID = "RaphaelLiu/PusaV1"
         | 
| 11 | 
            +
            MODEL_ZOO_DIR = "./model_zoo"
         | 
| 12 | 
            +
            WAN_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, WAN_SUBFOLDER)
         | 
| 13 | 
            +
            LORA_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
         | 
| 14 |  | 
| 15 | 
            +
            # Ensure model is downloaded
         | 
| 16 | 
            +
            def ensure_model_downloaded():
         | 
| 17 | 
            +
                if not os.path.exists(WAN_MODEL_PATH):
         | 
| 18 | 
            +
                    print("Downloading Wan2.1-T2V-14B from HuggingFace Hub...")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 19 | 
             
                    snapshot_download(
         | 
| 20 | 
            +
                        repo_id=MODEL_REPO_ID,
         | 
| 21 | 
            +
                        local_dir=MODEL_ZOO_DIR,
         | 
| 22 | 
            +
                        repo_type="model",
         | 
| 23 | 
            +
                        allow_patterns=[f"{WAN_SUBFOLDER}/**"],
         | 
| 24 | 
             
                        local_dir_use_symlinks=False,
         | 
|  | |
| 25 | 
             
                    )
         | 
| 26 | 
            +
                    print("Model downloaded.")
         | 
| 27 |  | 
| 28 | 
            +
            # Video generation logic
         | 
| 29 | 
            +
            @spaces.GPU
         | 
| 30 | 
            +
            def generate_video(prompt: str):
         | 
| 31 | 
            +
                ensure_model_downloaded()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 32 |  | 
| 33 | 
            +
                # Load model
         | 
| 34 | 
            +
                manager = ModelManager(pretrained_model_dir=WAN_MODEL_PATH)
         | 
| 35 | 
            +
                model = manager.load_model()
         | 
| 36 |  | 
| 37 | 
            +
                # Set up pipeline
         | 
| 38 | 
            +
                pipeline = WanVideoPusaPipeline(model=model)
         | 
| 39 | 
            +
                pipeline.set_lora_adapters(LORA_PATH)
         | 
|  | |
| 40 |  | 
| 41 | 
            +
                # Generate video
         | 
| 42 | 
            +
                result = pipeline(prompt)
         | 
| 43 |  | 
| 44 | 
            +
                # Save video
         | 
| 45 | 
             
                tmp_dir = tempfile.mkdtemp()
         | 
| 46 | 
            +
                output_path = os.path.join(tmp_dir, "video.mp4")
         | 
| 47 | 
            +
                save_video(result.frames, output_path, fps=8)
         | 
|  | |
|  | |
| 48 |  | 
| 49 | 
            +
                return output_path
         | 
| 50 |  | 
| 51 | 
            +
            # Gradio UI
         | 
| 52 | 
             
            with gr.Blocks() as demo:
         | 
| 53 | 
            +
                gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                prompt_input = gr.Textbox(
         | 
| 56 | 
            +
                    lines=4,
         | 
| 57 | 
            +
                    label="Prompt",
         | 
| 58 | 
            +
                    placeholder="Describe your video (e.g. A coral reef full of colorful fish...)"
         | 
| 59 | 
            +
                )
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                generate_btn = gr.Button("Generate Video")
         | 
| 62 | 
            +
                video_output = gr.Video(label="Output")
         | 
| 63 |  | 
| 64 | 
            +
                generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
         | 
| 65 |  | 
| 66 | 
             
            demo.launch()
         | 
