zz912 committed on
Commit ca1b139 · 0 Parent(s):

Initial commit

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
+ .DS_Store
+ .specstory/
+ .venv/
+ .idea/
+ .vscode/
+ .pytest_cache/
+ .ruff_cache/
+ models/
+
README.md ADDED
@@ -0,0 +1,85 @@
+ ---
+ title: Pixel GArt
+ emoji: 🌖
+ colorFrom: red
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 5.35.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: AI tool for turning sketches into pixel art.
+ ---
+
+ ## Introduction
+
+ This project is a pixel art generator that combines Stable Diffusion 1.5 with the PixelArtRedmond LoRA to turn hand-drawn sketches into retro-style pixel art. It gives users an easy way to create high-quality pixel art inspired by classic video games.
+
+ ## 📦 Installation
+
+ ### 1. Clone the project
+
+ ```
+ git clone git@github.com:sosiki1997/Pixel_GArt.git
+
+ cd Pixel_GArt
+ ```
+
+ ### 2. Create and activate the conda environment
+
+ ```
+ conda create -n pixel_venv python=3.12
+
+ conda activate pixel_venv
+ ```
+
+ ### 3. Install dependencies
+
+ ```
+ pip install gradio==3.44.4
+
+ pip install -r requirements.txt
+
+ pip install torch torchvision opencv-python pillow
+
+ pip install git+https://github.com/facebookresearch/segment-anything.git
+ ```
+
+ ### 4. Run the server
+
+ ```
+ python app.py
+ ```
+
+ ### 5. Open in browser
+
+ ```
+ http://127.0.0.1:7860
+ ```
+
+ <details>
+ <summary>📖 中文说明(点击展开)</summary>
+
+
+ 本项目是一个像素画生成器,结合了 Stable Diffusion 1.5 和 PixelArtRedmond LoRA,能将手绘草图转化为令人惊艳的复古风格像素艺术。旨在释放创意,让用户轻松创作出高品质的像素艺术作品,灵感源自经典电子游戏。
+
+ </details>
+
+ <details>
+ <summary>📖 日本語の説明(クリックで展開)</summary>
+
+
+ 本プロジェクトは、Stable Diffusion 1.5 と PixelArtRedmond LoRA を組み合わせて、手描きのスケッチを圧巻のレトロ風ピクセルアートに変換するジェネレーターです。クラシックゲームにインスパイアされた高品質なピクセルアートを手軽に制作できることを目的としています。
+
+ </details>
+
+ ---
+
+ ## 🖼️ Example Outputs (Generated Images)
+
+ <p align="center">
+ <img src="./readme_img/output_1_snapshot.png" width="600"/>
+ <img src="./readme_img/output_2_snapshot.png" width="600"/>
+ <img src="./readme_img/output_3_snapshot.png" width="600"/>
+ <img src="./readme_img/output_4_snapshot.png" width="600"/>
+ </p>
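
Note: `models/` is listed in `.gitignore`, so the `PixelArtGenerator` that `app/main.py` imports is not part of this commit. Below is a minimal, hedged sketch of how such a generator could be assembled with diffusers so that it matches the `generate(prompt=..., guidance_scale=...)` call in `app/interface/gradio_ui.py`; the checkpoint ID, LoRA repo, and class name are assumptions, not taken from the repository.

```python
# Hypothetical sketch of a PixelArtRedmond-flavored generator (not the project's actual models/generator.py).
import torch
from diffusers import StableDiffusionPipeline


class SketchPixelArtGenerator:
    def __init__(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Assumed Stable Diffusion 1.5 checkpoint; the README only says "Stable Diffusion 1.5".
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)
        # Assumed LoRA repo for "PixelArtRedmond"; pass weight_name=... if the repo holds several files.
        self.pipe.load_lora_weights("artificialguybr/PixelArtRedmond")

    def generate(self, prompt, guidance_scale=7.5):
        # gradio_ui.py calls generator.generate(prompt=..., guidance_scale=...) and expects a PIL image.
        return self.pipe(prompt, guidance_scale=guidance_scale).images[0]
```

The default prompt in the UI, `Pixel Art, PixArFK`, suggests the LoRA's trigger token, so prompts passed to `generate` should likely keep it.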
__pycache__/app.cpython-312.pyc ADDED
Binary file (560 Bytes).
 
app.py ADDED
@@ -0,0 +1,12 @@
+ # app.py - entry point that works both on Hugging Face Spaces and locally
+
+ import sys
+ import os
+
+ # Add the app directory to the path (so main can be imported)
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "app")))
+
+ from main import launch_app
+
+ # Call launch_app() whether or not this is __main__ (so it also runs on HF Spaces)
+ launch_app()
app/__pycache__/main.cpython-312.pyc ADDED
Binary file (512 Bytes).
 
app/__pycache__/main.cpython-39.pyc ADDED
Binary file (2.05 kB).
 
app/interface/__pycache__/gradio_ui.cpython-312.pyc ADDED
Binary file (6.14 kB).
 
app/interface/gradio_ui.py ADDED
@@ -0,0 +1,146 @@
+ import gradio as gr
+ import os
+ import sys
+ from PIL import Image
+ import numpy as np
+ import io
+
+ # Add the parent directory to the path so the utils modules can be imported
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from utils.image_processing import ImageProcessor
+
+ # Import our processing function
+ from utils.process import process_image
+
+ def create_gradio_interface(generator):
+     """Create the Gradio interface"""
+
+     def handle_image_opencv(input_image, pixel_size=20):
+         """Process the uploaded image with OpenCV"""
+         if input_image is None:
+             return None, "Please upload or draw an image"
+
+         try:
+             # Call the OpenCV processing function
+             result_image = process_image(input_image, pixel_size=pixel_size)
+             return result_image, "処理が成功しました"
+         except Exception as e:
+             import traceback
+             traceback.print_exc()
+             return None, f"Processing failed: {str(e)}"
+
+     def handle_image_diffusion(input_image, prompt, guidance_scale=7.5):
+         """Process the uploaded image with Stable Diffusion"""
+         if input_image is None:
+             return None, "Please upload an image"
+
+         try:
+             # Convert the PIL image to bytes
+             img_byte_arr = io.BytesIO()
+             input_image.save(img_byte_arr, format='PNG')
+             img_byte_arr = img_byte_arr.getvalue()
+
+             # Call the generator's generate function
+             # result_image = generator.generate(img_byte_arr, prompt, guidance_scale=guidance_scale)
+             result_image = generator.generate(prompt=prompt, guidance_scale=guidance_scale)
+
+             return result_image, "処理が成功しました。"
+         except Exception as e:
+             import traceback
+             traceback.print_exc()
+             return None, f"Generation failed: {str(e)}"
+
+     # Build the Gradio interface
+     with gr.Blocks(title="ドット絵ピクセルアート") as demo:
+         gr.Markdown("# ドット絵ピクセルアート")
+
+         with gr.Tabs():
+             with gr.TabItem("OpenCV ピクセル化"):
+                 with gr.Row():
+                     with gr.Column():
+                         # Input area - fixed size
+                         input_image_opencv = gr.Image(
+                             label="スケッチをアップロードまたは描画",
+                             type="pil",
+                             height=512,  # fixed height
+                             width=512,  # fixed width
+                             container=True,  # wrap in a container
+                             show_download_button=False,  # hide the download button
+                             show_label=True,  # show the label
+                         )
+
+                         pixel_size = gr.Slider(minimum=5, maximum=50, value=20, step=1,
+                                                label="ピクセルサイズ")
+
+                         process_btn_opencv = gr.Button("ドット絵を生成 (OpenCV)")
+
+                     with gr.Column():
+                         # Output area - also fixed size
+                         output_image_opencv = gr.Image(
+                             label="生成結果",
+                             height=512,  # fixed height
+                             width=512,  # fixed width
+                             container=True,  # wrap in a container
+                             show_download_button=True,  # show the download button
+                         )
+                         output_message_opencv = gr.Textbox(label="ステータス")
+
+             with gr.TabItem("Stable Diffusionで生成"):
+                 with gr.Row():
+                     with gr.Column():
+                         # Input area
+                         input_image_diffusion = gr.Image(
+                             label="参考画像をアップロード",
+                             type="pil",
+                             height=512,
+                             width=512,
+                             container=True,
+                             show_download_button=False,
+                         )
+
+                         prompt = gr.Textbox(
+                             label="プロンプト",
+                             placeholder="希望するピクセルアートのスタイルを説明してください...",
+                             value="Pixel Art, PixArFK"
+                         )
+
+                         guidance_scale = gr.Slider(
+                             minimum=1.0,
+                             maximum=15.0,
+                             value=7.5,
+                             step=0.5,
+                             label="ガイダンスの強さ(数値が高いほど指定した内容に沿いやすくなります)"
+                         )
+
+                         process_btn_diffusion = gr.Button("ドット絵を作成 (Stable Diffusion)")
+
+                     with gr.Column():
+                         # Output area
+                         output_image_diffusion = gr.Image(
+                             label="生成結果",
+                             height=512,
+                             width=512,
+                             container=True,
+                             show_download_button=True,
+                         )
+                         output_message_diffusion = gr.Textbox(label="ステータス")
+
+         # Wire the OpenCV button to its handler
+         process_btn_opencv.click(
+             fn=handle_image_opencv,
+             inputs=[input_image_opencv, pixel_size],
+             outputs=[output_image_opencv, output_message_opencv]
+         )
+
+         # Wire the Stable Diffusion button to its handler
+         process_btn_diffusion.click(
+             fn=handle_image_diffusion,
+             inputs=[input_image_diffusion, prompt, guidance_scale],
+             outputs=[output_image_diffusion, output_message_diffusion]
+         )
+
+     return demo
+
+ # Launch the UI directly for a quick test (pass a real generator to use the Stable Diffusion tab)
+ if __name__ == "__main__":
+     create_gradio_interface(None).launch()
app/main.py ADDED
@@ -0,0 +1,11 @@
+ from models.generator import PixelArtGenerator
+ from interface.gradio_ui import create_gradio_interface
+
+ # Initialize the generator and the Gradio interface
+ generator = PixelArtGenerator()
+ interface = create_gradio_interface(generator)
+
+
+ def launch_app():
+     interface.launch()
+
app/utils/__pycache__/image.cpython-312.pyc ADDED
Binary file (1.26 kB).
 
app/utils/__pycache__/image_processing.cpython-312.pyc ADDED
Binary file (8.42 kB).
 
app/utils/__pycache__/process.cpython-312.pyc ADDED
Binary file (2.49 kB).
 
app/utils/extract_subject.py ADDED
@@ -0,0 +1,28 @@
+ #!/usr/bin/env python3
+ import sys
+ import json
+ from image_processing import ImageProcessor
+
+ if __name__ == "__main__":
+     if len(sys.argv) < 2:
+         print(json.dumps({"error": "No image path provided"}))
+         sys.exit(1)
+
+     image_path = sys.argv[1]
+     processor = ImageProcessor()
+
+     try:
+         subject_path, mask = processor.extract_subject(image_path)
+         # Save the mask as an image (next to the input, with a _mask suffix)
+         mask_path = image_path.rsplit('.', 1)[0] + '_mask.png'
+         import cv2
+         import numpy as np
+         cv2.imwrite(mask_path, (mask * 255).astype(np.uint8))
+
+         print(json.dumps({
+             "subjectPath": subject_path,
+             "maskPath": mask_path
+         }))
+     except Exception as e:
+         print(json.dumps({"error": str(e)}))
+         sys.exit(1)
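
`extract_subject.py` and `pixelate_subject.py` print their results as JSON on stdout, so they can be driven from another process. A hedged sketch of chaining them from Python (the input path is a placeholder, and the scripts are run with `cwd="app/utils"` because they use flat imports like `from image_processing import ...`):

```python
# Hypothetical driver for the two CLI helpers; "sketch.png" is a placeholder path.
import json
import subprocess
import sys

def run_helper(script, *args):
    # Run from app/utils so `from image_processing import ImageProcessor` resolves.
    proc = subprocess.run(
        [sys.executable, script, *args],
        cwd="app/utils", capture_output=True, text=True,
    )
    # The helpers also print progress messages, so the JSON payload is the last stdout line.
    return json.loads(proc.stdout.strip().splitlines()[-1])

extracted = run_helper("extract_subject.py", "sketch.png")            # {"subjectPath": ..., "maskPath": ...}
result = run_helper("pixelate_subject.py", extracted["subjectPath"],
                    extracted["maskPath"], "20")                       # {"resultPath": ...}
print(result["resultPath"])
```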
app/utils/image.py ADDED
@@ -0,0 +1,35 @@
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+ def process_sketch(
+     image: Image.Image,
+     low_threshold: int = 100,
+     high_threshold: int = 200,
+     bg_color: int = 255
+ ) -> Image.Image:
+     """Process a sketch and extract Canny edges"""
+     # Convert to a numpy array
+     img_array = np.array(image)
+
+     # Convert to grayscale
+     gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
+
+     # Apply a Gaussian blur to reduce noise
+     blurred = cv2.GaussianBlur(gray, (3, 3), 0)
+
+     # Apply Canny edge detection
+     edges = cv2.Canny(
+         blurred,
+         threshold1=low_threshold,
+         threshold2=high_threshold
+     )
+
+     # Create a white background
+     result = np.full_like(edges, bg_color, dtype=np.uint8)
+
+     # Draw the edges in black
+     result[edges > 0] = 0
+
+     # Convert back to a PIL Image
+     return Image.fromarray(result)
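
A short usage example for `process_sketch` (the file names are placeholders; the `sys.path` line mirrors what `app.py` does so the flat `utils` import resolves):

```python
import sys

sys.path.insert(0, "app")  # mirror app.py: put the app/ directory on sys.path
from PIL import Image
from utils.image import process_sketch

# "my_sketch.png" is a placeholder path; convert to RGB because the function
# feeds the array straight into cv2.COLOR_RGB2GRAY.
sketch = Image.open("my_sketch.png").convert("RGB")
edges = process_sketch(sketch, low_threshold=100, high_threshold=200)
edges.save("my_sketch_edges.png")  # black Canny edges on a white background
```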
app/utils/image_processing.py ADDED
@@ -0,0 +1,176 @@
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import torch
+ from torchvision import transforms
+ from segment_anything import SamPredictor, sam_model_registry
+ import os
+
+ class ImageProcessor:
+     def __init__(self):
+         # Load the SAM model used for image segmentation
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # Update this to the actual path of the downloaded checkpoint
+         model_path = "./models/sam_vit_h_4b8939.pth"  # change this to your actual path
+
+         # Check that the model file exists
+         if not os.path.exists(model_path):
+             raise FileNotFoundError(f"SAM model file not found: {model_path}")
+
+         print(f"Loading SAM model: {model_path}")
+         self.sam = sam_model_registry["vit_h"](checkpoint=model_path)
+         self.sam.to(self.device)
+         self.predictor = SamPredictor(self.sam)
+
+     def extract_subject(self, image_path):
+         """Extract the main subject of the image, using multi-point prompts and stronger fallback strategies"""
+         # Read the image
+         image = cv2.imread(image_path)
+         if image is None:
+             raise ValueError(f"Cannot read image: {image_path}")
+
+         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         h, w = image.shape[:2]
+
+         # Set up the SAM predictor
+         self.predictor.set_image(image_rgb)
+
+         # Use several points as prompts, covering different regions of the image
+         points = np.array([
+             [w//2, h//2],      # center
+             [w//4, h//4],      # top left
+             [3*w//4, h//4],    # top right
+             [w//4, 3*h//4],    # bottom left
+             [3*w//4, 3*h//4],  # bottom right
+         ])
+
+         print(f"Using multi-point prompts: {points}")
+
+         # Label every point as foreground
+         labels = np.ones(len(points))
+
+         # Predict masks
+         masks, scores, _ = self.predictor.predict(
+             point_coords=points,
+             point_labels=labels,
+             multimask_output=True  # return multiple candidate masks
+         )
+
+         # Pick the highest-scoring mask
+         best_mask_idx = np.argmax(scores)
+         mask = masks[best_mask_idx]
+
+         print(f"Selected the highest-scoring mask: {scores[best_mask_idx]}")
+
+         # Check how much of the image the mask covers
+         mask_area = np.sum(mask)
+         image_area = h * w
+         coverage = mask_area / image_area
+         print(f"Mask coverage: {coverage:.2%}")
+
+         # If the coverage is too small, fall back to a simpler method
+         if coverage < 0.05:  # less than 5% coverage
+             print("Mask coverage too small, trying color-threshold segmentation")
+
+             # Color-threshold segmentation
+             # Convert to HSV color space
+             hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+             # Define an orange range (the squirrel's main color in the test image)
+             lower_orange = np.array([10, 100, 100])
+             upper_orange = np.array([25, 255, 255])
+
+             # Build the mask
+             color_mask = cv2.inRange(hsv, lower_orange, upper_orange)
+
+             # Clean up the mask with morphological operations
+             kernel = np.ones((5,5), np.uint8)
+             color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel)
+             color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, kernel)
+
+             # Convert to a boolean mask
+             mask = color_mask > 0
+
+             # Check the coverage again
+             mask_area = np.sum(mask)
+             coverage = mask_area / image_area
+             print(f"Mask coverage after color-threshold segmentation: {coverage:.2%}")
+
+             # If it is still too small, use a simple rectangular region
+             if coverage < 0.05:
+                 print("Color-threshold segmentation is still unsatisfactory, using the center region")
+                 mask = np.zeros((h, w), dtype=bool)
+                 # Use the central 60% of the image
+                 h_start, h_end = int(h*0.2), int(h*0.8)
+                 w_start, w_end = int(w*0.2), int(w*0.8)
+                 mask[h_start:h_end, w_start:w_end] = True
+
+         # Save the mask as an image
+         mask_path = os.path.splitext(image_path)[0] + '_mask.png'
+         cv2.imwrite(mask_path, (mask * 255).astype(np.uint8))
+
+         return image_path, mask
+
+     def pixelate_subject(self, image_path, mask, pixel_size=20):
+         """Pixelate the subject, making sure its edges are fully pixelated as well"""
+         print("Starting pixelation...")
+         # Read the original image
+         image = cv2.imread(image_path)
+         if image is None:
+             raise ValueError(f"Cannot read image: {image_path}")
+
+         print("Image read successfully, starting processing...")
+         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         h, w = image.shape[:2]
+
+         # Create an RGBA image with a transparent background
+         result = np.zeros((h, w, 4), dtype=np.uint8)
+
+         # Get the bounding box of the subject region
+         y_indices, x_indices = np.where(mask)
+         if len(y_indices) == 0 or len(x_indices) == 0:
+             print("No subject detected, returning the original image")
+             return image_path  # if no subject was detected, return the original image
+
+         y_min, y_max = np.min(y_indices), np.max(y_indices)
+         x_min, x_max = np.min(x_indices), np.max(x_indices)
+         print(f"Subject bounding box: ({x_min}, {y_min}) - ({x_max}, {y_max})")
+
+         # Extract the subject region
+         subject = image_rgb[y_min:y_max, x_min:x_max]
+         subject_mask = mask[y_min:y_max, x_min:x_max]
+
+         # Build a full RGBA image containing the subject over a transparent background
+         subject_rgba = np.zeros((subject.shape[0], subject.shape[1], 4), dtype=np.uint8)
+         subject_rgba[:,:,:3] = subject
+         subject_rgba[:,:,3] = (subject_mask * 255).astype(np.uint8)
+
+         print("Pixelating the whole subject (including its edges)...")
+         # Pixelate the entire RGBA image
+         h_sub, w_sub = subject_rgba.shape[:2]
+         print(f"Subject size: {w_sub}x{h_sub}, pixel size: {pixel_size}")
+
+         # Make sure the target size is at least 1x1
+         target_w = max(1, w_sub // pixel_size)
+         target_h = max(1, h_sub // pixel_size)
+         print(f"Target size: {target_w}x{target_h}")
+
+         # Downscale first
+         temp = cv2.resize(subject_rgba, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
+         # Then upscale with nearest-neighbor interpolation to keep the pixelated look
+         pixelated_rgba = cv2.resize(temp, (w_sub, h_sub), interpolation=cv2.INTER_NEAREST)
+
+         print("Placing the processed subject back in its original position...")
+         # Place the processed subject back in its original position
+         result[y_min:y_max, x_min:x_max] = pixelated_rgba
+
+         # Save the result as PNG (supports transparency)
+         result_path = os.path.splitext(image_path)[0] + '_pixelated.png'
+
+         print(f"Saving result to: {result_path}")
+         # Save the RGBA image with PIL
+         pil_image = Image.fromarray(result)
+         pil_image.save(result_path, format='PNG')
+
+         print("Pixelation finished!")
+         return result_path
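
`ImageProcessor` expects the SAM ViT-H checkpoint at `./models/sam_vit_h_4b8939.pth`, and `models/` is excluded by `.gitignore`, so the file must be fetched separately. A hedged helper for that; the URL below is the one commonly documented in the segment-anything README, so verify it before relying on it:

```python
# One-off helper to place the SAM ViT-H checkpoint where ImageProcessor looks for it.
import os
import urllib.request

CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"  # assumed official URL
CHECKPOINT_PATH = "./models/sam_vit_h_4b8939.pth"

os.makedirs("./models", exist_ok=True)
if not os.path.exists(CHECKPOINT_PATH):
    print(f"Downloading SAM ViT-H checkpoint to {CHECKPOINT_PATH} (this is a multi-GB file)...")
    urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)
```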
app/utils/pixelate_subject.py ADDED
@@ -0,0 +1,30 @@
+ #!/usr/bin/env python3
+ import sys
+ import json
+ import cv2
+ import numpy as np
+ from image_processing import ImageProcessor
+
+ if __name__ == "__main__":
+     if len(sys.argv) < 3:
+         print(json.dumps({"error": "Missing arguments"}))
+         sys.exit(1)
+
+     subject_path = sys.argv[1]
+     mask_path = sys.argv[2]
+     pixel_size = int(sys.argv[3]) if len(sys.argv) > 3 else 20
+
+     # Read the mask
+     mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+     mask = mask > 0  # binarize
+
+     processor = ImageProcessor()
+
+     try:
+         result_path = processor.pixelate_subject(subject_path, mask, pixel_size)
+         print(json.dumps({
+             "resultPath": result_path
+         }))
+     except Exception as e:
+         print(json.dumps({"error": str(e)}))
+         sys.exit(1)
app/utils/process.py ADDED
@@ -0,0 +1,61 @@
+ from PIL import Image
+ # Make sure the import path is correct
+ from utils.image_processing import ImageProcessor
+ import time
+ import gradio as gr
+ import traceback
+
+ # Initialize the image processor
+ try:
+     print("Initializing ImageProcessor...")
+     image_processor = ImageProcessor()
+     print("ImageProcessor initialized successfully")
+ except Exception as e:
+     print(f"Failed to initialize ImageProcessor: {str(e)}")
+     traceback.print_exc()
+
+ def process_image(input_image, *args, **kwargs):
+     progress = gr.Progress()
+
+     print("Processing the input image...")
+     progress(0, desc="Starting...")
+
+     # Save the input image to a temporary file
+     temp_input_path = "temp_input.png"
+     input_image.save(temp_input_path)
+
+     try:
+         print("Extracting the subject with SAM...")
+         progress(0.3, desc="Extracting subject...")
+         # Extract the subject with SAM
+         _, mask = image_processor.extract_subject(temp_input_path)
+
+         print("Pixelating the subject...")
+         progress(0.6, desc="Pixelating...")
+         # Pixelate the subject
+         pixel_size = kwargs.get('pixel_size', 20)
+         result_path = image_processor.pixelate_subject(temp_input_path, mask, pixel_size)
+
+         # Read the result image
+         progress(0.9, desc="Generating final image...")
+         result_image = Image.open(result_path)
+
+         print("Subject extraction and pixelation finished!")
+         progress(1.0, desc="Done!")
+         return result_image
+     except Exception as e:
+         traceback.print_exc()
+         print(f"Processing failed: {str(e)}")
+         progress(1.0, desc="Failed")
+
+         # On failure, fall back to the original processing pipeline
+         print("Falling back to the original pipeline...")
+
+         print("Extracting edges...")
+         # original edge-extraction code
+
+         print("Generating the image...")
+         # original image-generation code
+
+         # ... rest of the original processing code ...
+         return None
requirements.txt ADDED
@@ -0,0 +1,36 @@
+ # Core AI and image processing
+ torch
+ torchvision
+ torchaudio
+ transformers
+ diffusers
+ accelerate
+ safetensors
+ scipy
+ tqdm
+ einops
+ opencv-python
+ Pillow
+ numpy
+ kornia
+
+ # Gradio UI
+ gradio
+
+ # CLIP interrogator
+ clip-interrogator
+
+ # FastAPI and web API (optional)
+ fastapi
+ uvicorn
+ python-multipart
+
+ # General utilities
+ requests
+ pydantic
+ packaging
+
+ openai
+ huggingface-hub
+
+ git+https://github.com/facebookresearch/segment-anything.git