File size: 4,106 Bytes
836cc89
d3dd837
 
836cc89
 
 
d3dd837
 
eb3165e
 
836cc89
 
 
 
 
d3dd837
836cc89
 
 
 
3755f04
836cc89
 
3755f04
836cc89
 
 
 
 
 
 
 
 
 
d3dd837
836cc89
 
 
 
 
 
 
 
 
3755f04
836cc89
 
d3dd837
 
836cc89
eb3165e
836cc89
 
d3dd837
836cc89
 
 
d3dd837
836cc89
 
 
d3dd837
836cc89
 
d3dd837
836cc89
d3dd837
836cc89
 
 
 
 
 
d3dd837
836cc89
d3dd837
836cc89
 
 
 
 
 
 
 
 
d3dd837
836cc89
 
 
 
 
 
 
 
 
 
d3dd837
836cc89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb3165e
 
 
836cc89
 
 
 
 
d3dd837
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import { AutoProcessor, RawImage, AutoModel } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers';

const status = document.getElementById('status');
const fileSelect = document.getElementById('file-select');
const imageContainer = document.getElementById('image-container');
const outputContainer = document.getElementById('output-container');

status.textContent = 'Loading model...';
// Wall-clock timestamps of the most recent processing run; assigned by start().
let startTime = null;
let endTime = null;

/**
 * Load the MODNet matting model and its processor.
 *
 * The two downloads are independent, so they are fetched concurrently rather
 * than one after the other. If either fails, the failure is shown in the
 * status line (instead of leaving "Loading model..." up forever) and the
 * error is rethrown so module evaluation still fails loudly.
 *
 * @returns {Promise<[any, any]>} Resolves to `[model, processor]`.
 */
async function loadModelAndProcessor() {
  try {
    return await Promise.all([
      // `quantized: false` requests the full-precision ONNX weights.
      AutoModel.from_pretrained('Xenova/modnet-onnx', { quantized: false }),
      AutoProcessor.from_pretrained('Xenova/modnet-onnx'),
    ]);
  } catch (error) {
    status.textContent = 'Failed to load model';
    throw error;
  }
}

const [model, processor] = await loadModelAndProcessor();

status.textContent = 'Ready';
// Demo image that is displayed and processed as soon as the model is ready.
const url = 'https://images.pexels.com/photos/5965592/pexels-photo-5965592.jpeg?auto=compress&cs=tinysrgb&w=1024';

/**
 * Show a remote image in the input container and run background removal on it.
 *
 * The URL itself (not the <img> element) is handed to start(), which reads it
 * again on its own.
 *
 * @param {string} url - Address of the image to display and process.
 */
function useRemoteImage(url) {
  const image = document.createElement('img');
  image.crossOrigin = 'anonymous';
  image.src = url;
  imageContainer.appendChild(image);
  // Defer processing to a later task (presumably so the preview can render
  // first). start() is async: catch its rejection so a failure is reported
  // rather than becoming an unhandled rejection with the status stuck.
  setTimeout(() => {
    start(url).catch((error) => {
      console.error('Processing failed: ', error);
      status.textContent = `Error: ${error?.message ?? error}`;
    });
  }, 0);
}
useRemoteImage(url);

// When the user picks a local file, show it and run background removal on it.
fileSelect.addEventListener('change', (event) => {
  const file = event.target.files[0];
  if (!file) {
    // Selection was cleared / cancelled.
    return;
  }

  const reader = new FileReader();

  // Fires once the whole file has been read as a data URL.
  reader.addEventListener('load', () => {
    status.textContent = 'Image loaded';

    // Drop the previous input preview and any previous results.
    imageContainer.innerHTML = '';
    outputContainer.innerHTML = '';

    const dataUrl = reader.result;
    const image = document.createElement('img');
    image.src = dataUrl;
    imageContainer.appendChild(image);

    // Defer processing to a later task; start() is async, so catch its
    // rejection instead of leaving it unhandled with the status stuck.
    setTimeout(() => {
      start(dataUrl).catch((error) => {
        console.error('Processing failed: ', error);
        status.textContent = `Error: ${error?.message ?? error}`;
      });
    }, 0);
  });

  // Without this, an unreadable file would fail silently.
  reader.addEventListener('error', () => {
    console.error('File read error: ', reader.error);
    status.textContent = 'Could not read the selected file';
  });

  reader.readAsDataURL(file);
});

/**
 * Copy the mask's grayscale value into the alpha channel of an RGBA buffer.
 *
 * Both buffers are interleaved RGBA (4 bytes per pixel) of the same size. The
 * mask is grayscale, so its R, G and B are equal and R is used.
 *
 * @param {Uint8ClampedArray} pixels - RGBA pixels; alpha is overwritten in place.
 * @param {Uint8ClampedArray} mask - RGBA pixels of the grayscale mask.
 * @returns {Uint8ClampedArray} The same `pixels` buffer, for convenience.
 */
function copyMaskIntoAlpha(pixels, mask) {
  for (let i = 0; i < pixels.length; i += 4) {
    pixels[i + 3] = mask[i];
  }
  return pixels;
}

/**
 * Cut the foreground out of `rawImage` using `maskImage` as its alpha channel.
 *
 * @param {RawImage} rawImage - Source image.
 * @param {RawImage} maskImage - Grayscale matte with the same width/height.
 * @returns The canvas produced by `rawImage.toCanvas()`, with the matte
 *   written into its alpha channel.
 */
function getForeground(rawImage, maskImage) {
  const rawCanvas = rawImage.toCanvas();
  const rawCtx = rawCanvas.getContext('2d');

  const maskCanvas = maskImage.toCanvas();
  const maskCtx = maskCanvas.getContext('2d');

  const rawImageData = rawCtx.getImageData(0, 0, rawCanvas.width, rawCanvas.height);
  const maskImageData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height);

  copyMaskIntoAlpha(rawImageData.data, maskImageData.data);

  rawCtx.putImageData(rawImageData, 0, 0);
  return rawCanvas;
}

/**
 * Run background removal on an image and append the cut-out to the output.
 *
 * Errors anywhere in the pipeline (reading, inference, encoding) are logged
 * and shown in the status line; the returned promise then still resolves.
 *
 * @param {string} source - Image URL or data URL, passed to RawImage.read().
 * @returns {Promise<void>}
 */
async function start(source) {
  // Time this run with a local timestamp so overlapping runs (e.g. the demo
  // image still processing when a file is picked) cannot skew each other's
  // reported duration. The module-level timestamps are still updated.
  const runStartedAt = new Date();
  startTime = runStartedAt;
  status.textContent = 'processing';
  console.log('start process');

  try {
    const image = await RawImage.read(source);

    // Preprocess the image into the model's input tensor.
    const { pixel_values: input } = await processor(image);

    // Predict the alpha matte.
    const { output } = await model({ input });
    console.log('image', image);

    // output[0] is the matte for the single input image; values assumed to be
    // in [0, 1] (TODO confirm), so scale to bytes, then resize the matte back
    // to the source image's dimensions.
    const matteImage = await RawImage.fromTensor(output[0].mul(255).to('uint8'))
      .resize(image.width, image.height);
    console.log('matteImage', matteImage, output);

    const foregroundCanvas = getForeground(image, matteImage);
    console.log('debug', foregroundCanvas);

    // convertToBlob() is an OffscreenCanvas method; this relies on
    // RawImage.toCanvas() returning one.
    const blob = await foregroundCanvas.convertToBlob();

    const img = new Image();
    const objectUrl = URL.createObjectURL(blob);
    // Release the blob once the browser has decoded it; otherwise every run
    // leaks one blob for the lifetime of the page.
    img.addEventListener('load', () => URL.revokeObjectURL(objectUrl), { once: true });
    img.src = objectUrl;
    outputContainer.appendChild(img);

    endTime = new Date();
    const diff = (endTime - runStartedAt) / 1000;
    status.textContent = `Finish: ${diff}s`;
  } catch (error) {
    console.error('Background removal failed: ', error);
    status.textContent = `Error: ${error?.message ?? error}`;
  }
}