dylanebert committed
Commit 0353592
1 Parent(s): ab3a2b5

initial commit

.gitignore ADDED
@@ -0,0 +1,2 @@
+ venv/
+ output*
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: Multi View Diffusion
- emoji: 📚
- colorFrom: blue
- colorTo: pink
+ emoji: 🧊
+ colorFrom: indigo
+ colorTo: blue
 sdk: gradio
 sdk_version: 4.24.0
 app_file: app.py
__pycache__/app.cpython-310.pyc ADDED
Binary file (2.36 kB)
 
__pycache__/text_pipeline.cpython-310.pyc ADDED
Binary file (2.38 kB)
 
app.py ADDED
@@ -0,0 +1,84 @@
+ import gradio as gr
+
+ import torch
+ from diffusers import DiffusionPipeline
+ from PIL import Image
+
+ # Text-to-Multi-View Diffusion pipeline
+ text_pipeline = DiffusionPipeline.from_pretrained(
+     "ashawkey/mvdream-sd2.1-diffusers",
+     custom_pipeline="dylanebert/multi_view_diffusion",
+     torch_dtype=torch.float16,
+     trust_remote_code=True,
+ ).to("cuda")
+
+ # Image-to-Multi-View Diffusion pipeline
+ image_pipeline = DiffusionPipeline.from_pretrained(
+     "ashawkey/imagedream-ipmv-diffusers",
+     custom_pipeline="dylanebert/multi_view_diffusion",
+     torch_dtype=torch.float16,
+     trust_remote_code=True,
+ ).to("cuda")
+
+
+ def create_image_grid(images):
+     images = [Image.fromarray((img * 255).astype("uint8")) for img in images]
+
+     width, height = images[0].size
+     grid_img = Image.new("RGB", (2 * width, 2 * height))
+
+     grid_img.paste(images[0], (0, 0))
+     grid_img.paste(images[1], (width, 0))
+     grid_img.paste(images[2], (0, height))
+     grid_img.paste(images[3], (width, height))
+
+     return grid_img
+
+
+ def text_to_mv(prompt):
+     images = text_pipeline(
+         prompt, guidance_scale=5, num_inference_steps=30, elevation=0
+     )
+     return create_image_grid(images)
+
+
+ def image_to_mv(image, prompt):
+     image = image.astype("float32") / 255.0
+     images = image_pipeline(
+         prompt, image, guidance_scale=5, num_inference_steps=30, elevation=0
+     )
+     return create_image_grid(images)
+
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             with gr.Tab("Text Input"):
+                 text_input = gr.Textbox(
+                     lines=2,
+                     show_label=False,
+                     placeholder="Enter a prompt here (e.g. 'a cat statue')",
+                 )
+                 text_btn = gr.Button("Generate Multi-View Images")
+             with gr.Tab("Image Input"):
+                 image_input = gr.Image(
+                     label="Image Input",
+                     type="numpy",
+                 )
+                 optional_text_input = gr.Textbox(
+                     lines=2,
+                     show_label=False,
+                     placeholder="Enter an optional prompt here",
+                 )
+                 image_btn = gr.Button("Generate Multi-View Images")
+         with gr.Column():
+             output = gr.Image(label="Generated Images")
+
+     text_btn.click(fn=text_to_mv, inputs=text_input, outputs=output)
+     image_btn.click(
+         fn=image_to_mv, inputs=[image_input, optional_text_input], outputs=output
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue().launch()
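
For context, the pipelines loaded in app.py can also be driven outside the Gradio UI. The snippet below is a minimal sketch, not part of the commit: it reuses the model ID and sampling settings from app.py and assumes, as create_image_grid implies, that the pipeline returns four H x W x 3 float arrays in [0, 1] and that a CUDA device is available; the output filenames are illustrative.

import numpy as np
import torch
from diffusers import DiffusionPipeline
from PIL import Image

# Load the text-to-multi-view pipeline exactly as app.py does (assumes CUDA).
pipeline = DiffusionPipeline.from_pretrained(
    "ashawkey/mvdream-sd2.1-diffusers",
    custom_pipeline="dylanebert/multi_view_diffusion",
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")

# Same sampling settings as text_to_mv(); assumed to return four float views in [0, 1].
views = pipeline("a cat statue", guidance_scale=5, num_inference_steps=30, elevation=0)

for i, view in enumerate(views):
    # Convert each view to an 8-bit image; filenames are illustrative.
    Image.fromarray((np.asarray(view) * 255).astype("uint8")).save(f"view_{i}.png")
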
data/cat_statue.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,103 @@
+ accelerate==0.28.0
+ aiofiles==23.2.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==4.3.0
+ attrs==23.2.0
+ certifi==2024.2.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.0
+ cycler==0.12.1
+ diffusers==0.27.2
+ einops==0.7.0
+ exceptiongroup==1.2.0
+ executing==2.0.1
+ fastapi==0.110.0
+ ffmpy==0.3.2
+ filelock==3.13.3
+ fonttools==4.50.0
+ fsspec==2024.3.1
+ gradio==4.24.0
+ gradio_client==0.14.0
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.22.2
+ idna==3.6
+ importlib_metadata==7.1.0
+ importlib_resources==6.4.0
+ Jinja2==3.1.3
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiui==0.2.7
+ kiwisolver==1.4.5
+ lazy_loader==0.3
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.8.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.19.3
+ nvidia-nvjitlink-cu12==12.4.99
+ nvidia-nvtx-cu12==12.1.105
+ objprint==0.2.3
+ opencv-python==4.9.0.80
+ orjson==3.10.0
+ packaging==24.0
+ pandas==2.2.1
+ pillow==10.2.0
+ psutil==5.9.8
+ pydantic==2.6.4
+ pydantic_core==2.16.3
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.34.0
+ regex==2023.12.25
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ ruff==0.3.4
+ safetensors==0.4.2
+ scipy==1.12.0
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ starlette==0.36.3
+ sympy==1.12
+ tokenizers==0.15.2
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.2.2
+ torchaudio==2.2.2
+ torchvision==0.17.2
+ tqdm==4.66.2
+ transformers==4.39.2
+ triton==2.2.0
+ typer==0.11.1
+ typing_extensions==4.10.0
+ tzdata==2024.1
+ urllib3==2.2.1
+ uvicorn==0.29.0
+ varname==0.13.0
+ websockets==11.0.3
+ xformers==0.0.25.post1
+ zipp==3.18.1