StupidGame's picture
Upload 1941 files
baa8e90
raw
history blame contribute delete
No virus
6.81 kB
# -*- coding: utf-8 -*-
import random
import time
from typing import List, Dict, Tuple
from .shared import DreamImage
class RGBPalette:
    """An ordered collection of colour tuples with simple statistics.

    Colours may come from an image (one entry per iterated pixel) and/or from
    an explicit list; image pixels are stored first. Tuples with fewer than
    three components are expanded to ``(v, v, v)`` using the first component.
    """

    ID = "RGB_PALETTE"

    def __init__(self, colors: List[tuple[int, int, int]] = None, image: "DreamImage" = None):
        """Build a palette.

        Args:
            colors: optional iterable of colour tuples, appended after any
                image pixels.
            image: optional image; iterating it must yield 3-tuples whose
                first element is the pixel (presumably ``(pixel, x, y)`` —
                TODO confirm against DreamImage).
        """
        self._colors = []

        def _fix_tuple(t):
            # Short tuples (e.g. single-channel grey pixels) become a triple.
            if len(t) < 3:
                return (t[0], t[0], t[0])
            return t

        if image:
            for p, _, _ in image:
                self._colors.append(_fix_tuple(p))
        if colors:
            for c in colors:
                self._colors.append(_fix_tuple(c))

    def _calculate_channel_contrast(self, c):
        """Return a normalised contrast score in [0, 1] for channel index *c*.

        Channel values are quantised into 16 buckets (``value // 16``). The
        raw score is the sum, over every pair of colours, of the distance
        between their buckets; it is divided by the largest value that sum
        can take for this palette size.

        Returns 0.0 for palettes with fewer than two colours (no pairs).
        """
        n = len(self._colors)
        # No pairs means no contrast; also avoids dividing by a zero maximum.
        if n < 2:
            return 0.0
        hist = [0] * 16
        for pixel in self._colors:
            hist[pixel[c] // 16] += 1
        s = 0
        for i in range(16):
            for j in range(i):
                s += (i - j) * hist[i] * hist[j]
        # The sum peaks when colours are split as evenly as possible between
        # bucket 0 and bucket 15: floor(n/2) * ceil(n/2) pairs at distance 15.
        # (Using (n // 2) ** 2 here would let odd-sized palettes exceed 1.0.)
        max_possible = 15 * (n // 2) * ((n + 1) // 2)
        return s / max_possible

    def _calculate_combined_contrast(self):
        """Return the mean of the R, G and B channel contrast scores."""
        return sum(self._calculate_channel_contrast(c) for c in range(3)) / 3

    def analyze(self):
        """Summarise the palette.

        Returns:
            ``(brightness, contrast, r, g, b)`` where r, g and b are the mean
            channel values divided by 255 (so in [0, 1] for 8-bit channels),
            brightness is the mean of r, g and b, and contrast is the combined
            channel contrast score.

        Raises:
            ZeroDivisionError: if the palette is empty.
        """
        total_red = 0
        total_green = 0
        total_blue = 0
        for pixel in self:
            total_red += pixel[0]
            total_green += pixel[1]
            total_blue += pixel[2]
        n = len(self._colors)
        r = float(total_red) / (255 * n)
        g = float(total_green) / (255 * n)
        b = float(total_blue) / (255 * n)
        return ((r + g + b) / 3.0, self._calculate_combined_contrast(), r, g, b)

    def __len__(self):
        return len(self._colors)

    def __iter__(self):
        return iter(self._colors)

    def random_iteration(self, seed=None):
        """Return an endless iterator of uniformly chosen palette colours.

        Args:
            seed: seed for a private ``random.Random``; when None, the current
                time in milliseconds is used. Equal seeds give equal sequences.

        The palette size is captured now; calling ``next`` on an iterator made
        from an empty palette raises ValueError (from ``randint(0, -1)``).
        """
        s = seed if seed is not None else int(time.time() * 1000)
        n = len(self._colors) - 1
        c = self._colors

        class _ColorIterator:
            def __init__(self):
                self._r = random.Random()
                self._r.seed(s)
                self._n = n
                self._c = c

            def __iter__(self):
                # Completes the iterator protocol so the object works with
                # for-loops, zip() and itertools, not only with next().
                return self

            def __next__(self):
                return self._c[self._r.randint(0, self._n)]

        return _ColorIterator()
class PartialPrompt:
    """A mapping from prompt fragments to weights.

    ``add`` and ``scaled_by`` return new instances and leave the receiver
    unchanged; ``finalize`` renders positive and negative prompt strings.
    """

    ID = "PARTIAL_PROMPT"

    def __init__(self):
        self._data = {}

    def add(self, text: str, weight: float):
        """Return a copy with each comma-separated fragment of *text* set to *weight*.

        Fragments are whitespace-stripped and empty fragments (from ``""``,
        ``"a,,b"`` or a trailing comma) are skipped. A fragment containing a
        space is stored wrapped in parentheses (presumably so the downstream
        prompt parser treats the phrase as one unit). An existing fragment's
        weight is replaced, not accumulated.
        """
        output = PartialPrompt()
        output._data = dict(self._data)
        for part in text.split(","):
            part = part.strip()
            if not part:
                continue
            key = "(" + part + ")" if " " in part else part
            output._data[key] = weight
        return output

    def is_empty(self):
        """Return True when no fragments are stored."""
        return not self._data

    def abs_sum(self):
        """Return the sum of absolute weights (0.0 when empty)."""
        if not self._data:
            return 0.0
        return sum(map(abs, self._data.values()))

    def abs_max(self):
        """Return the largest absolute weight (0.0 when empty)."""
        if not self._data:
            return 0.0
        return max(map(abs, self._data.values()))

    def scaled_by(self, f: float):
        """Return a copy with every weight multiplied by *f*."""
        new_data = PartialPrompt()
        new_data._data = {text: weight * f for text, weight in self._data.items()}
        return new_data

    def finalize(self, clamp: float):
        """Render the fragments as a ``(positive, negative)`` pair of strings.

        Each fragment becomes ``"(text:weight)"`` with the weight formatted to
        three decimals and capped at *clamp*. Weights >= 0.0001 go into the
        positive string, largest first. Weights <= -0.0001 go into the negative
        string as their magnitude, most negative first. Ties are ordered by
        text; weights closer to zero are dropped.
        """
        items = list(self._data.items())
        pos = [
            "({}:{:.3f})".format(text, min(clamp, w))
            for text, w in sorted(items, key=lambda pair: (-pair[1], pair[0]))
            if w >= 0.0001
        ]
        neg = [
            "({}:{:.3f})".format(text, min(clamp, -w))
            for text, w in sorted(items, key=lambda pair: (pair[1], pair[0]))
            if w <= -0.0001
        ]
        return ", ".join(pos), ", ".join(neg)
class LogEntry:
    """An append-only sequence of ``(timestamp, text)`` records.

    Methods that "modify" the log return a fresh ``LogEntry`` and leave the
    receiver untouched.
    """

    ID = "LOG_ENTRY"

    @classmethod
    def new(cls, text):
        """Create a log holding one record for *text*, stamped with ``time.time()``."""
        return LogEntry([(time.time(), text)])

    def __init__(self, data: List[Tuple[float, str]] = None):
        """Store a private copy of *data*, or start empty when it is None."""
        self._data = [] if data is None else list(data)

    def add(self, text: str):
        """Return a new log with *text* appended, stamped with ``time.time()``."""
        return LogEntry(self._data + [(time.time(), text)])

    def merge(self, log_entry):
        """Return a new log with this log's records followed by *log_entry*'s."""
        return LogEntry([*self._data, *log_entry._data])

    def get_filtered_entries(self, t: float):
        """Yield records whose timestamp is strictly greater than *t*.

        Records are produced in sorted order (by timestamp, then text).
        """
        yield from (entry for entry in sorted(self._data) if entry[0] > t)
class FrameCounter:
    """Position within a fixed-length animation.

    Constructor values are clamped: ``current_frame`` to at least 0,
    ``total_frames`` to at least 1 and ``frames_per_second`` to at least 1.0
    (always stored as a float).
    """

    ID = "FRAME_COUNTER"

    def __init__(self, current_frame=0, total_frames=1, frames_per_second=25.0):
        self.current_frame = max(0, current_frame)
        self.total_frames = max(total_frames, 1)
        self.frames_per_second = float(max(1.0, frames_per_second))

    def incremented(self, amount: int):
        """Return a new counter advanced by *amount* frames (result clamped at 0)."""
        return FrameCounter(
            current_frame=self.current_frame + amount,
            total_frames=self.total_frames,
            frames_per_second=self.frames_per_second,
        )

    def _seconds(self, frame_count):
        """Convert a frame count to seconds at this counter's frame rate."""
        return float(frame_count) / self.frames_per_second

    @property
    def is_first_frame(self):
        """True when positioned on frame 0."""
        return self.current_frame == 0

    @property
    def is_final_frame(self):
        """True when positioned on the last frame (index ``total_frames - 1``)."""
        return self.total_frames == self.current_frame + 1

    @property
    def is_after_last_frame(self):
        """True once the current frame index has reached ``total_frames``."""
        return self.total_frames <= self.current_frame

    @property
    def current_time_in_seconds(self):
        """Elapsed time at the current frame."""
        return self._seconds(self.current_frame)

    @property
    def total_time_in_seconds(self):
        """Duration of all frames."""
        return self._seconds(self.total_frames)

    @property
    def remaining_time_in_seconds(self):
        """Total duration minus the elapsed time."""
        return self._seconds(self.total_frames) - self._seconds(self.current_frame)

    @property
    def progress(self):
        """Current frame divided by the last frame index (at least 1)."""
        last_index = max(2, self.total_frames) - 1
        return float(self.current_frame) / last_index
class AnimationSequence:
    """Per-frame file lists organised into parallel batches.

    ``frames`` maps a frame key to a list of strings, one per batch
    (presumably image file paths — TODO confirm with callers). The usable
    batch count is the length of the shortest list.
    """

    ID = "ANIMATION_SEQUENCE"

    def __init__(self, frame_counter: FrameCounter, frames: Dict[int, List[str]] = None):
        self.frames = frames
        self.fps = frame_counter.frames_per_second
        self.frame_counter = frame_counter
        if not self.is_defined:
            # No frame data: expose an empty, zero-batch sequence.
            self.keys_in_order = []
            self.num_batches = 0
            return
        self.keys_in_order = sorted(frames)
        self.num_batches = min(len(files) for files in frames.values())

    @property
    def batches(self):
        """Range over valid batch indices."""
        return range(self.num_batches)

    def get_image_files_of_batch(self, batch_num):
        """Lazily yield entry *batch_num* of each frame's list, in sorted key order."""
        yield from (self.frames[key][batch_num] for key in self.keys_in_order)

    @property
    def is_defined(self):
        """True when ``frames`` is set and non-empty."""
        return bool(self.frames)
class SharedTypes:
    """Ready-made ``{input_name: (TYPE_ID,)}`` dicts for the shared custom types.

    NOTE(review): presumably merged into node input declarations elsewhere —
    confirm at the call sites.
    """

    frame_counter = dict(frame_counter=(FrameCounter.ID,))
    sequence = dict(sequence=(AnimationSequence.ID,))
    palette = dict(palette=(RGBPalette.ID,))