|
import sys
|
|
sys.path.append('./rxn/')
|
|
import torch
|
|
from rxn.reaction import Reaction
|
|
import json
|
|
from matplotlib import pyplot as plt
|
|
import numpy as np
|
|
|
|
# Checkpoint path and inference device for the shared Reaction model, which is
# constructed once at import time and used by the functions below.
ckpt_path = "./rxn/model/model.ckpt"

# Define the device once and reuse it, so the global `device` and the model's
# device cannot drift apart.
device = torch.device('cpu')

model = Reaction(ckpt_path, device=device)
|
|
|
|
def get_reaction(image_path: str) -> str:
    """Extract reactions from an image and return them serialized as JSON.

    Calls ``model.predict_image_file`` on *image_path* with ``molscribe=True``
    and ``ocr=True``, then serializes the predictions with ``json.dumps``.

    Args:
        image_path: Path of the image file to analyze.

    Returns:
        A JSON string encoding the model's predictions. This is a ``str``, not
        a Python list; callers that need the structured data must
        ``json.loads`` it.
    """
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    return json.dumps(predictions)
|
|
|
|
|
|
|
|
def _to_rgb_uint8(img):
    """Normalize one drawn prediction image to an (H, W, 3) uint8 array.

    - 2-D and single-channel images are replicated to three channels.
    - Channels beyond the third (e.g. alpha) are dropped.
    - Floating-point images (any float dtype) are assumed to lie in [0, 1].
      They are clipped to that range before being scaled to 0-255, so
      out-of-range values cannot wrap around on the uint8 cast.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[2] == 1:
        # imshow rejects (H, W, 1); expand to RGB like the 2-D case.
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] > 3:
        img = img[:, :, :3]
    if np.issubdtype(img.dtype, np.floating):
        img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    return img


def _grid_shape(n_images):
    """Return (n_rows, n_cols) for n_images >= 1, using at most 3 columns."""
    n_cols = min(n_images, 3)
    n_rows = (n_images + n_cols - 1) // n_cols  # ceiling division
    return n_rows, n_cols


def generate_combined_image(predictions, image_file, output_path="combined_output.png"):
    """Draw the predictions and combine them into one grid image saved to disk.

    Each image returned by ``model.draw_predictions`` is placed in a grid of
    at most 3 columns and titled "Reaction <k>". Unused cells in the last row
    are removed.

    Args:
        predictions: Forwarded unchanged to ``model.draw_predictions``.
            Presumably this is the structure returned by
            ``model.predict_image_file``; TODO confirm.
        image_file: Source image path, forwarded to ``draw_predictions``.
        output_path: Destination of the combined PNG. Defaults to
            ``"combined_output.png"`` in the current working directory.

    Returns:
        ``output_path``, the path the combined image was written to.

    Raises:
        ValueError: If ``draw_predictions`` produced no images.
    """
    drawn = model.draw_predictions(predictions, image_file=image_file)
    images = [_to_rgb_uint8(img) for img in drawn]
    n_images = len(images)
    if n_images == 0:
        # Without this check, plt.subplots(0, ...) fails with an opaque error.
        raise ValueError("No prediction images to combine.")

    n_rows, n_cols = _grid_shape(n_images)
    # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid needs
    # no special case.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    try:
        axes = axes.ravel()
        for idx, (ax, img) in enumerate(zip(axes, images), start=1):
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(f"Reaction {idx}")
        # Delete leftover cells directly instead of drawing blank placeholders.
        for ax in axes[n_images:]:
            fig.delaxes(ax)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return output_path
|
|
|