# Author: Jinglong Xiong
# Commit: 6642f4e ("add models")
import os
import re
from xml.sax.saxutils import escape as xml_escape

import matplotlib.colors as mcolors
import numpy as np
from bs4 import BeautifulSoup
from noise import pnoise1
from svgpathtools import (
    Path, Arc, CubicBezier, QuadraticBezier,
    svgstr2paths)

from starvector.data.util import rasterize_svg


class SVGTransforms:
    """Random geometric and colour augmentation for SVG documents.

    ``transformations`` is a config dict. Range entries are dicts with
    ``'from'`` and ``'to'`` bounds that are passed to ``np.random.uniform``:

    * ``rotate`` -- rotation angle range, passed to ``Path.rotated``.
    * ``shift_re`` / ``shift_im`` -- horizontal / vertical translation ranges.
      Either one may be omitted; the missing axis is then not shifted.
    * ``scale`` -- scale-factor range, passed to ``Path.scaled``.
    * ``noise_std`` -- range for the per-segment noise magnitude. Requires
      ``noise_type`` to be ``'gaussian'`` or ``'perlin'``.
    * ``color_noise`` -- range for the std-dev of per-channel RGB fill noise.

    Other keys:

    * ``color_change`` -- when truthy, replace fills with a random entry of
      ``colors`` (a list of colour strings).
    * ``p`` -- default 0.5. It is stored but not read by this class;
      presumably it is an augmentation probability used by callers
      (TODO confirm).
    """

    # Elements whose content is never rendered directly. Shapes inside them
    # (clip paths, symbols, gradients, ...) must not be emitted as visible paths.
    _NON_RENDERED_TAGS = frozenset({
        'defs', 'metadata', 'title', 'desc', 'style', 'script', 'symbol',
        'clipPath', 'mask', 'pattern', 'marker', 'linearGradient',
        'radialGradient', 'filter',
    })
    # Used when neither a viewBox nor numeric width/height can be found.
    _DEFAULT_VIEWBOX = (0.0, 0.0, 256.0, 256.0)
    _NUMBER = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'

    def __init__(self, transformations):
        self.transformations = transformations
        self.noise_std = self.transformations.get('noise_std', False)
        self.noise_type = self.transformations.get('noise_type', False)
        self.rotate = self.transformations.get('rotate', False)
        self.shift_re = self.transformations.get('shift_re', False)
        self.shift_im = self.transformations.get('shift_im', False)
        self.scale = self.transformations.get('scale', False)
        self.color_noise = self.transformations.get('color_noise', False)
        self.p = self.transformations.get('p', 0.5)
        self.color_change = self.transformations.get('color_change', False)
        self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000'])
        # Per-sample values drawn by sample_transformations(). They stay
        # neutral until the first draw, so the do_* helpers are always safe.
        self.rotation_angle = 0.0
        self.shift_real = 0.0
        self.shift_imag = 0.0
        self.scale_factor = 1.0
        self.color_noise_std = 0.0

    @staticmethod
    def _sample_range(bounds, default=0.0):
        """Draw uniformly from ``bounds['from']..bounds['to']``; return ``default`` if unset."""
        if not bounds:
            return default
        return np.random.uniform(bounds['from'], bounds['to'])

    def sample_transformations(self):
        """Draw this sample's random parameters from the configured ranges.

        Sampled values go to dedicated attributes (``rotation_angle``,
        ``shift_real``/``shift_imag``, ``scale_factor``, ``color_noise_std``).
        The configuration dicts are left untouched, so repeated calls keep working.
        """
        if self.rotate:
            self.rotation_angle = self._sample_range(self.rotate)
        if self.shift_re or self.shift_im:
            # Either axis may be configured alone; the other one stays at 0.
            self.shift_real = self._sample_range(self.shift_re)
            self.shift_imag = self._sample_range(self.shift_im)
        if self.scale:
            self.scale_factor = self._sample_range(self.scale, 1.0)
        if self.color_noise:
            self.color_noise_std = self._sample_range(self.color_noise)

    @staticmethod
    def _format_attrs(attrs, exclude=frozenset()):
        """Render ``attrs`` as `` key="value"`` pairs with XML-escaped values."""
        # Parsed attribute values are unescaped; writing them back raw would
        # break the markup for values containing '"', '&' or '<'.
        return ''.join(
            ' {}="{}"'.format(key, xml_escape(str(value), {'"': '&quot;'}))
            for key, value in attrs.items()
            if key not in exclude
        )

    def paths2str(self, groupped_paths, svg_opening_tag='<svg xmlns="http://www.w3.org/2000/svg" version="1.1">'):
        """Serialise grouped paths back into an SVG document string.

        Args:
            groupped_paths: mapping ``key -> {'attrs': dict, 'paths': [(path, attrs), ...]}``.
                Keys containing ``'no_group'`` are emitted without a ``<g>``
                wrapper; every other entry becomes ``<g attrs>...</g>``.
                Each ``path`` must provide ``d()``.
            svg_opening_tag: the ``<svg ...>`` tag placed at the top of the output.

        Returns:
            The SVG markup as a string.
        """
        # Geometry attributes already encoded in the path data are dropped.
        keys_to_exclude = frozenset({'d', 'cx', 'cy', 'rx', 'ry'})
        group_strings = []
        for group, elements in groupped_paths.items():
            group_attributes = elements.get('attrs', {})
            paths_and_attributes = elements.get('paths', [])
            path_str = '\n'.join(
                f'<path d="{path.d()}"{self._format_attrs(attributes, keys_to_exclude)} />'
                for path, attributes in paths_and_attributes
            )
            if 'no_group' in group:
                if path_str:
                    group_strings.append(f'{path_str}\n')
            else:
                group_attr_str = self._format_attrs(group_attributes)
                group_strings.append(f'<g{group_attr_str}>\n{path_str}\n</g>\n')
        body = ''.join(group_strings)
        return f'{svg_opening_tag}\n{body}</svg>'

    def add_noise(self, seg):
        """Perturb a segment's control point(s) or radius in place and return it.

        One complex offset is drawn per call. It is added to both control
        points of a CubicBezier, to the control point of a QuadraticBezier, or
        to the radius of an Arc. Other segment types (e.g. lines) are returned
        unchanged.

        Raises:
            ValueError: if ``noise_type`` is not ``'gaussian'`` or ``'perlin'``.
        """
        noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to'])
        if self.noise_type == 'gaussian':
            noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \
                1j * np.random.normal(loc=0.0, scale=noise_scale)
        elif self.noise_type == 'perlin':
            noise_sample = complex(pnoise1(np.random.random(), octaves=2),
                                   pnoise1(np.random.random(), octaves=2)) * noise_scale
        else:
            raise ValueError(
                f"unsupported noise_type {self.noise_type!r}; expected 'gaussian' or 'perlin'")
        if isinstance(seg, CubicBezier):
            seg.control1 = seg.control1 + noise_sample
            seg.control2 = seg.control2 + noise_sample
        elif isinstance(seg, QuadraticBezier):
            seg.control = seg.control + noise_sample
        elif isinstance(seg, Arc):
            seg.radius = seg.radius + noise_sample
        return seg

    def do_rotate(self, path, viewbox_width, viewbox_height, min_x=0.0, min_y=0.0):
        """Rotate ``path`` by the sampled angle around the viewBox centre.

        ``min_x``/``min_y`` are the viewBox origin. They default to 0, which
        reproduces the centre ``(w/2, h/2)`` used for ``0 0 w h`` viewBoxes.
        """
        if not self.rotate:
            return path
        centre = complex(min_x + viewbox_width / 2, min_y + viewbox_height / 2)
        return path.rotated(self.rotation_angle, centre)

    def do_shift(self, path):
        """Translate ``path`` by the sampled ``(shift_real, shift_imag)`` offset."""
        if self.shift_re or self.shift_im:
            return path.translated(complex(self.shift_real, self.shift_imag))
        return path

    def do_scale(self, path):
        """Scale ``path`` by the sampled ``scale_factor``."""
        if self.scale:
            return path.scaled(self.scale_factor)
        return path

    def add_color_noise(self, source_color):
        """Return ``source_color`` with Gaussian noise added per RGB channel, as hex.

        Accepts hex colours or CSS4 colour names (case-insensitive). Any other
        value (e.g. ``url(#gradient)``) is treated as white.
        """
        color = source_color.strip()
        if color.startswith("#"):
            base_color = mcolors.hex2color(color)
        else:
            base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(color.lower(), '#FFFFFF'))
        noise = np.random.normal(0, self.color_noise_std, 3)
        noisy_color = np.clip(np.array(base_color) + noise, 0, 1)
        return mcolors.rgb2hex(noisy_color)

    def do_color_change(self, attr):
        """Apply colour noise or colour replacement to ``attr['fill']`` in place.

        Colour noise takes precedence over colour replacement. A fill of
        ``'none'`` is left untouched. Returns the same dict.
        """
        if 'fill' not in attr or not (self.color_noise or self.color_change):
            return attr
        fill_value = attr['fill']
        if fill_value != 'none':
            if self.color_noise:
                attr['fill'] = self.add_color_noise(fill_value)
            else:
                attr['fill'] = np.random.choice(self.colors)
        return attr

    def clean_attributes(self, attr):
        """Normalise where a path's fill is stored.

        If there is no ``fill`` attribute but the inline ``style`` declares
        ``fill:``, a new dict holding only that fill is returned. Otherwise
        ``attr`` itself is returned.
        """
        # NOTE(review): the style branch discards every other attribute
        # (stroke, transform, ...). Confirm this simplification is intended.
        if 'fill' in attr or 'style' not in attr:
            return attr
        fill_values = re.findall(r'fill:[^;]+', attr['style'])
        if not fill_values:
            return attr
        return {'fill': fill_values[0].replace('fill:', '').strip()}

    def _parse_viewbox(self, svg):
        """Return ``(min_x, min_y, width, height)`` read from the root ``<svg>`` tag.

        The viewBox may be separated by whitespace and/or commas. Without a
        usable viewBox, numeric ``width``/``height`` attributes (an optional
        ``px`` suffix is allowed) are used with origin 0. Otherwise the result
        is ``_DEFAULT_VIEWBOX``.
        """
        root = re.search(r'<svg\b[^>]*>', svg)
        scope = root.group(0) if root else svg
        match = re.search(r'(?<![\w:-])viewBox\s*=\s*["\']([^"\']*)["\']', scope)
        if match:
            parts = [p for p in re.split(r'[\s,]+', match.group(1).strip()) if p]
            if len(parts) == 4:
                try:
                    return tuple(float(p) for p in parts)
                except ValueError:
                    pass  # malformed viewBox: fall back to width/height
        dims = []
        for name in ('width', 'height'):
            found = re.search(
                rf'(?<![\w:-]){name}\s*=\s*["\']\s*({self._NUMBER})\s*(?:px)?\s*["\']', scope)
            dims.append(float(found.group(1)) if found else None)
        if None not in dims:
            return 0.0, 0.0, dims[0], dims[1]
        return self._DEFAULT_VIEWBOX

    def get_viewbox_size(self, svg):
        """Return ``(width, height)`` of the SVG's coordinate system (256x256 by default)."""
        _, _, viewbox_width, viewbox_height = self._parse_viewbox(svg)
        return viewbox_width, viewbox_height

    def augment(self, svg):
        """Apply one randomly sampled augmentation to an SVG.

        Args:
            svg: SVG markup, or a path to a file containing it.

        Returns:
            ``(svg_string, image)``. ``image`` is whatever ``rasterize_svg``
            returns for the output markup. If svgpathtools cannot parse a
            top-level element, the original markup and its rasterisation are
            returned instead.

        Raises:
            ValueError: if the input contains no ``<svg>`` opening tag.

        Only direct children of the root ``<svg>`` are processed, so each
        shape is emitted once. A ``<g>`` keeps its own attributes, but nested
        groups are flattened into it (their attributes are not reproduced).
        Anything svgpathtools does not turn into a path is not reproduced.
        """
        if os.path.isfile(svg):
            with open(svg, 'r', encoding='utf-8') as f:
                svg = f.read()

        self.sample_transformations()

        opening_match = re.search(r'<svg\b[^>]*>', svg)
        if opening_match is None:
            raise ValueError('input does not contain an <svg> opening tag')
        svg_opening_tag = opening_match.group(0)
        min_x, min_y, viewbox_width, viewbox_height = self._parse_viewbox(svg)

        soup = BeautifulSoup(svg, 'xml')
        root = soup.find('svg')
        children = root.find_all(recursive=False) if root is not None else []

        grouped_paths = {}
        for i, element in enumerate(children):
            if element.name in self._NON_RENDERED_TAGS:
                continue
            if element.name == 'g':
                # Index-based key: group ids may repeat or be missing.
                group_id = f'group_{i}'
                group_attrs = element.attrs
            else:
                group_id = f'no_group_{i}'
                group_attrs = {}
            group_svg_string = f'{svg_opening_tag}{element}</svg>'
            try:
                paths, attributes = svgstr2paths(group_svg_string)
            except Exception:
                # Best effort: fall back to the untouched input.
                return svg, rasterize_svg(svg)
            if not paths:
                continue

            paths_and_attributes = []
            for path, attribute in zip(paths, attributes):
                attr = self.clean_attributes(attribute)
                new_path = self.do_rotate(path, viewbox_width, viewbox_height,
                                          min_x=min_x, min_y=min_y)
                new_path = self.do_shift(new_path)
                new_path = self.do_scale(new_path)
                if self.noise_std:
                    new_path = Path(*(self.add_noise(seg) for seg in new_path))
                attr = self.do_color_change(attr)
                paths_and_attributes.append((new_path, attr))
            grouped_paths[group_id] = {
                'paths': paths_and_attributes,
                'attrs': group_attrs,
            }

        svg = self.paths2str(grouped_paths, svg_opening_tag)
        return svg, rasterize_svg(svg)