Spaces:
Running
on
Zero
Running
on
Zero
| from transformers import AutoTokenizer, TextIteratorStreamer | |
| import difflib | |
| import torch | |
| import numpy as np | |
| import re | |
| from models.model_manager import ModelManager | |
| from PIL import Image | |
# Named HTML/CSS web colours -> (red, green, blue), each channel in 0..255.
# Keys are lowercase; model-supplied colour names are fuzzy-matched against
# them and the chosen triple is turned into a uint8 numpy colour. Both the
# "gray" and "grey" spellings are present, and "navyblue" is an extra alias
# with the same value as "navy".
valid_colors = dict(
    aliceblue=(240, 248, 255),
    antiquewhite=(250, 235, 215),
    aqua=(0, 255, 255),
    aquamarine=(127, 255, 212),
    azure=(240, 255, 255),
    beige=(245, 245, 220),
    bisque=(255, 228, 196),
    black=(0, 0, 0),
    blanchedalmond=(255, 235, 205),
    blue=(0, 0, 255),
    blueviolet=(138, 43, 226),
    brown=(165, 42, 42),
    burlywood=(222, 184, 135),
    cadetblue=(95, 158, 160),
    chartreuse=(127, 255, 0),
    chocolate=(210, 105, 30),
    coral=(255, 127, 80),
    cornflowerblue=(100, 149, 237),
    cornsilk=(255, 248, 220),
    crimson=(220, 20, 60),
    cyan=(0, 255, 255),
    darkblue=(0, 0, 139),
    darkcyan=(0, 139, 139),
    darkgoldenrod=(184, 134, 11),
    darkgray=(169, 169, 169),
    darkgrey=(169, 169, 169),
    darkgreen=(0, 100, 0),
    darkkhaki=(189, 183, 107),
    darkmagenta=(139, 0, 139),
    darkolivegreen=(85, 107, 47),
    darkorange=(255, 140, 0),
    darkorchid=(153, 50, 204),
    darkred=(139, 0, 0),
    darksalmon=(233, 150, 122),
    darkseagreen=(143, 188, 143),
    darkslateblue=(72, 61, 139),
    darkslategray=(47, 79, 79),
    darkslategrey=(47, 79, 79),
    darkturquoise=(0, 206, 209),
    darkviolet=(148, 0, 211),
    deeppink=(255, 20, 147),
    deepskyblue=(0, 191, 255),
    dimgray=(105, 105, 105),
    dimgrey=(105, 105, 105),
    dodgerblue=(30, 144, 255),
    firebrick=(178, 34, 34),
    floralwhite=(255, 250, 240),
    forestgreen=(34, 139, 34),
    fuchsia=(255, 0, 255),
    gainsboro=(220, 220, 220),
    ghostwhite=(248, 248, 255),
    gold=(255, 215, 0),
    goldenrod=(218, 165, 32),
    gray=(128, 128, 128),
    grey=(128, 128, 128),
    green=(0, 128, 0),
    greenyellow=(173, 255, 47),
    honeydew=(240, 255, 240),
    hotpink=(255, 105, 180),
    indianred=(205, 92, 92),
    indigo=(75, 0, 130),
    ivory=(255, 255, 240),
    khaki=(240, 230, 140),
    lavender=(230, 230, 250),
    lavenderblush=(255, 240, 245),
    lawngreen=(124, 252, 0),
    lemonchiffon=(255, 250, 205),
    lightblue=(173, 216, 230),
    lightcoral=(240, 128, 128),
    lightcyan=(224, 255, 255),
    lightgoldenrodyellow=(250, 250, 210),
    lightgray=(211, 211, 211),
    lightgrey=(211, 211, 211),
    lightgreen=(144, 238, 144),
    lightpink=(255, 182, 193),
    lightsalmon=(255, 160, 122),
    lightseagreen=(32, 178, 170),
    lightskyblue=(135, 206, 250),
    lightslategray=(119, 136, 153),
    lightslategrey=(119, 136, 153),
    lightsteelblue=(176, 196, 222),
    lightyellow=(255, 255, 224),
    lime=(0, 255, 0),
    limegreen=(50, 205, 50),
    linen=(250, 240, 230),
    magenta=(255, 0, 255),
    maroon=(128, 0, 0),
    mediumaquamarine=(102, 205, 170),
    mediumblue=(0, 0, 205),
    mediumorchid=(186, 85, 211),
    mediumpurple=(147, 112, 219),
    mediumseagreen=(60, 179, 113),
    mediumslateblue=(123, 104, 238),
    mediumspringgreen=(0, 250, 154),
    mediumturquoise=(72, 209, 204),
    mediumvioletred=(199, 21, 133),
    midnightblue=(25, 25, 112),
    mintcream=(245, 255, 250),
    mistyrose=(255, 228, 225),
    moccasin=(255, 228, 181),
    navajowhite=(255, 222, 173),
    navy=(0, 0, 128),
    navyblue=(0, 0, 128),
    oldlace=(253, 245, 230),
    olive=(128, 128, 0),
    olivedrab=(107, 142, 35),
    orange=(255, 165, 0),
    orangered=(255, 69, 0),
    orchid=(218, 112, 214),
    palegoldenrod=(238, 232, 170),
    palegreen=(152, 251, 152),
    paleturquoise=(175, 238, 238),
    palevioletred=(219, 112, 147),
    papayawhip=(255, 239, 213),
    peachpuff=(255, 218, 185),
    peru=(205, 133, 63),
    pink=(255, 192, 203),
    plum=(221, 160, 221),
    powderblue=(176, 224, 230),
    purple=(128, 0, 128),
    rebeccapurple=(102, 51, 153),
    red=(255, 0, 0),
    rosybrown=(188, 143, 143),
    royalblue=(65, 105, 225),
    saddlebrown=(139, 69, 19),
    salmon=(250, 128, 114),
    sandybrown=(244, 164, 96),
    seagreen=(46, 139, 87),
    seashell=(255, 245, 238),
    sienna=(160, 82, 45),
    silver=(192, 192, 192),
    skyblue=(135, 206, 235),
    slateblue=(106, 90, 205),
    slategray=(112, 128, 144),
    slategrey=(112, 128, 144),
    snow=(255, 250, 250),
    springgreen=(0, 255, 127),
    steelblue=(70, 130, 180),
    tan=(210, 180, 140),
    teal=(0, 128, 128),
    thistle=(216, 191, 216),
    tomato=(255, 99, 71),
    turquoise=(64, 224, 208),
    violet=(238, 130, 238),
    wheat=(245, 222, 179),
    white=(255, 255, 255),
    whitesmoke=(245, 245, 245),
    yellow=(255, 255, 0),
    yellowgreen=(154, 205, 50),
)
# Anchor point (x, y) of a local region on the 90x90 layout grid. The values
# 15 / 45 / 75 are the centres of the three equal columns (x) and rows (y).
valid_locations = dict(
    [
        ("in the center", (45, 45)),
        ("on the left", (15, 45)),
        ("on the right", (75, 45)),
        ("on the top", (45, 15)),
        ("on the bottom", (45, 75)),
        ("on the top-left", (15, 15)),
        ("on the top-right", (75, 15)),
        ("on the bottom-left", (15, 75)),
        ("on the bottom-right", (75, 75)),
    ]
)
# Nudge (dx, dy) applied to a location anchor on the 90x90 grid. Negative dx
# moves left, negative dy moves up; every non-zero step is 10 cells.
valid_offsets = dict(
    [
        ("no offset", (0, 0)),
        ("slightly to the left", (-10, 0)),
        ("slightly to the right", (10, 0)),
        ("slightly to the upper", (0, -10)),
        ("slightly to the lower", (0, 10)),
        ("slightly to the upper-left", (-10, -10)),
        ("slightly to the upper-right", (10, -10)),
        ("slightly to the lower-left", (-10, 10)),
        ("slightly to the lower-right", (10, 10)),
    ]
)
# Region size (width, height) on the 90x90 grid. "vertical" areas are taller
# than wide, "horizontal" ones wider than tall; the large variants span the
# full grid along their long side.
valid_areas = dict(
    [
        ("a small square area", (50, 50)),
        ("a small vertical area", (40, 60)),
        ("a small horizontal area", (60, 40)),
        ("a medium-sized square area", (60, 60)),
        ("a medium-sized vertical area", (50, 80)),
        ("a medium-sized horizontal area", (80, 50)),
        ("a large square area", (70, 70)),
        ("a large vertical area", (60, 90)),
        ("a large horizontal area", (90, 60)),
    ]
)
def safe_str(x):
    """Normalise a text fragment so it ends with exactly one period.

    Any leading/trailing commas, periods and spaces are trimmed first, so
    "foo, " and "foo." both become "foo.". An empty fragment becomes ".".
    """
    core = x.strip(", .")
    return f"{core}."
def closest_name(input_str, options):
    """Return the key of ``options`` that best matches ``input_str``.

    The input is lower-cased and then fuzzily matched against the keys with
    ``difflib.get_close_matches`` (single best hit, similarity cutoff 0.5), so
    keys should be lowercase. When the chosen key differs from the lower-cased
    input, the automatic correction is printed.

    Args:
        input_str: Candidate name, e.g. a location or colour name emitted by
            the language model.
        options: Mapping whose keys are the accepted names.

    Returns:
        The best-matching key of ``options``.

    Raises:
        AssertionError: If no key reaches the similarity cutoff. It is raised
            explicitly rather than via ``assert`` so the check still fires
            under ``python -O`` (where a bare assert would be stripped and the
            failure would surface as an unhelpful ``IndexError``).
    """
    input_str = input_str.lower()
    matches = difflib.get_close_matches(input_str, list(options), n=1, cutoff=0.5)
    if not matches:
        raise AssertionError(f"The value [{input_str}] is not valid!")
    result = matches[0]
    if result != input_str:
        print(f"Automatically corrected [{input_str}] -> [{result}].")
    return result
class Canvas:
    """Image layout described by the Omost language model.

    The model answers with a Python snippet that creates ``canvas = Canvas()``
    and calls :meth:`set_global_description` and :meth:`add_local_description`
    on it. :meth:`process` then turns the recorded description into a 90x90
    RGB colour sketch plus one mask/prompt condition per region.
    """

    @staticmethod
    def from_bot_response(response: str):
        """Run the first fenced ```python block of ``response``; return its canvas.

        Args:
            response: Raw text produced by the language model.

        Returns:
            The ``Canvas`` instance the snippet bound to the name ``canvas``.

        Raises:
            AssertionError: If no fenced python block is found, the block lacks
                ``canvas = Canvas()``, or running it does not leave a
                ``Canvas`` instance in ``canvas``.
        """
        matched = re.search(r"```python\n(.*?)\n```", response, re.DOTALL)
        assert matched, "Response does not contain codes!"
        code_content = matched.group(1)
        assert "canvas = Canvas()" in code_content, (
            "Code block must include valid canvas var!"
        )
        # SECURITY NOTE(review): this executes model-generated code. The
        # globals dict is empty, so Python injects the real __builtins__ and
        # the snippet can do anything the process can. Acceptable only while
        # the LLM output is trusted; otherwise a sandbox or an AST-based
        # interpreter of the Canvas calls is needed.
        local_vars = {"Canvas": Canvas}
        exec(code_content, {}, local_vars)
        canvas = local_vars.get("canvas", None)
        assert isinstance(canvas, Canvas), "Code block must produce valid canvas var!"
        return canvas

    def __init__(self):
        self.components = []  # one dict per add_local_description() call
        self.color = None  # background colour, (1, 1, 3) uint8 array once set
        # When True, tags (and for regions also atmosphere/style/quality_meta)
        # are appended to the suffix sentences.
        self.record_tags = True
        self.prefixes = []  # global lead sentence(s), normalised by safe_str
        self.suffixes = []  # global detail sentences, normalised by safe_str

    def set_global_description(
        self,
        description: str,
        detailed_descriptions: list,
        tags: str,
        HTML_web_color_name: str,
    ):
        """Record the whole-image description and background colour.

        Args:
            description: Main sentence; becomes the global prefix.
            detailed_descriptions: Extra sentences; become the global suffixes.
            tags: Tag string appended to the suffixes when ``record_tags``.
            HTML_web_color_name: Background colour name, fuzzily matched
                against ``valid_colors``.

        Raises:
            AssertionError: On wrongly typed arguments or an unknown colour.
        """
        assert isinstance(description, str), "Global description is not valid!"
        assert isinstance(detailed_descriptions, list) and all(
            isinstance(item, str) for item in detailed_descriptions
        ), "Global detailed_descriptions is not valid!"
        assert isinstance(tags, str), "Global tags is not valid!"
        HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
        self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
        prefixes = [description]
        suffixes = detailed_descriptions
        if self.record_tags:
            suffixes = suffixes + [tags]
        # Comprehensions build new lists, so the caller's list is never aliased.
        self.prefixes = [safe_str(x) for x in prefixes]
        self.suffixes = [safe_str(x) for x in suffixes]

    def add_local_description(
        self,
        location: str,
        offset: str,
        area: str,
        distance_to_viewer: float,
        description: str,
        detailed_descriptions: list,
        tags: str,
        atmosphere: str,
        style: str,
        quality_meta: str,
        HTML_web_color_name: str,
    ):
        """Add one rectangular region to the layout.

        ``location``, ``offset``, ``area`` and ``HTML_web_color_name`` are
        fuzzily matched against ``valid_locations``, ``valid_offsets``,
        ``valid_areas`` and ``valid_colors``. The region's prefixes are the
        current global prefixes followed by ``description``; its suffixes are
        ``detailed_descriptions`` plus, when ``record_tags``, the tags,
        atmosphere, style and quality_meta strings.

        Raises:
            AssertionError: On wrongly typed arguments, a non-positive
                ``distance_to_viewer``, or an unmatchable name.
        """
        assert isinstance(description, str), "Local description is wrong!"
        assert (
            isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0
        ), f"The distance_to_viewer for [{description}] is not positive float number!"
        assert isinstance(detailed_descriptions, list) and all(
            isinstance(item, str) for item in detailed_descriptions
        ), f"The detailed_descriptions for [{description}] is not valid!"
        assert isinstance(tags, str), f"The tags for [{description}] is not valid!"
        assert isinstance(atmosphere, str), (
            f"The atmosphere for [{description}] is not valid!"
        )
        assert isinstance(style, str), f"The style for [{description}] is not valid!"
        assert isinstance(quality_meta, str), (
            f"The quality_meta for [{description}] is not valid!"
        )
        location = closest_name(location, valid_locations)
        offset = closest_name(offset, valid_offsets)
        area = closest_name(area, valid_areas)
        HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)

        xb, yb = valid_locations[location]
        xo, yo = valid_offsets[offset]
        w, h = valid_areas[area]
        # Slice bounds (top, bottom, left, right) on the 90x90 grid, clamped to
        # the grid; process() uses them as latent[top:bottom, left:right].
        rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
        rect = [max(0, min(90, v)) for v in rect]
        color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)

        prefixes = self.prefixes + [description]
        suffixes = detailed_descriptions
        if self.record_tags:
            suffixes = suffixes + [tags, atmosphere, style, quality_meta]
        prefixes = [safe_str(x) for x in prefixes]
        suffixes = [safe_str(x) for x in suffixes]

        self.components.append(
            dict(
                rect=rect,
                distance_to_viewer=distance_to_viewer,
                color=color,
                prefixes=prefixes,
                suffixes=suffixes,
                location=location,
            )
        )

    def process(self):
        """Render the layout into an initial colour sketch and conditions.

        ``set_global_description`` must have been called first; otherwise
        ``self.color`` is still None and the background fill fails.

        Side effect: ``self.components`` is re-sorted in place, farthest
        region first.

        Returns:
            dict with
              * ``initial_latent``: (90, 90, 3) uint8 image. It starts filled
                with the background colour; each region, farthest first, is
                blended in as 0.7 * region colour + 0.3 * what is underneath,
                so nearer regions end up on top.
              * ``bag_of_conditions``: list of dicts with ``mask`` ((90, 90)
                float32, 1.0 inside the region), ``prefixes``, ``suffixes`` and
                ``location``. The first entry is the global condition (all-ones
                mask, ``location="full"``), followed by one entry per region in
                the same farthest-first order.
        """
        # sorted() is stable, so equal distances keep their insertion order.
        self.components = sorted(
            self.components, key=lambda comp: comp["distance_to_viewer"], reverse=True
        )

        initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
        for component in self.components:
            top, bottom, left, right = component["rect"]
            initial_latent[top:bottom, left:right] = (
                0.7 * component["color"] + 0.3 * initial_latent[top:bottom, left:right]
            )
        initial_latent = initial_latent.clip(0, 255).astype(np.uint8)

        bag_of_conditions = [
            dict(
                mask=np.ones(shape=(90, 90), dtype=np.float32),
                prefixes=self.prefixes,
                suffixes=self.suffixes,
                location="full",
            )
        ]
        for component in self.components:
            top, bottom, left, right = component["rect"]
            mask = np.zeros(shape=(90, 90), dtype=np.float32)
            mask[top:bottom, left:right] = 1.0
            bag_of_conditions.append(
                dict(
                    mask=mask,
                    prefixes=component["prefixes"],
                    suffixes=component["suffixes"],
                    location=component["location"],
                )
            )

        return dict(
            initial_latent=initial_latent,
            bag_of_conditions=bag_of_conditions,
        )
class OmostPromter(torch.nn.Module):
    """Expand a text prompt into a regional layout using the Omost LLM.

    The wrapped causal LM is asked, via a system prompt describing the
    ``Canvas`` API, to write a Python snippet. The snippet is executed by
    ``Canvas.from_bot_response`` and the resulting layout is turned into a
    global prompt, per-region prompts and per-region PIL masks.
    """

    def __init__(self, model=None, tokenizer=None, template="", device="cpu"):
        """
        Args:
            model: Causal LM exposing ``generate`` (stored as ``self.model``).
            tokenizer: Matching tokenizer with a chat template.
            template: System prompt; the empty string selects the built-in
                Omost ``Canvas`` specification below.
            device: Device the prompt token ids are moved to.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        if template == "":
            # Indentation and blank lines matter: this is the Python-class
            # spec the Omost model is prompted with.
            template = r"""You are a helpful AI assistant to compose images using the below python class `Canvas`:

```python
class Canvas:
    def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
        pass

    def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
        assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
        assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
        assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
        assert distance_to_viewer > 0
        pass
```"""
        self.template = template

    @staticmethod
    def from_model_manager(model_manager: "ModelManager"):
        """Build an ``OmostPromter`` from the manager's ``omost_prompt`` model.

        The tokenizer is loaded with ``AutoTokenizer`` from the model's path,
        and the manager's device is reused.
        """
        model, model_path = model_manager.fetch_model(
            "omost_prompt", require_model_path=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return OmostPromter(
            model=model, tokenizer=tokenizer, device=model_manager.device
        )

    def __call__(self, prompt_dict: dict):
        """Generate a layout for ``prompt_dict["prompt"]`` and merge it in.

        ``prompt_dict`` is updated in place with the keys returned by
        ``Canvas.process`` (``initial_latent``, ``bag_of_conditions``) plus:
          * ``prompt``: global prompt text,
          * ``prompts``: one prompt text per region,
          * ``masks``: PIL RGB images, 255 inside each condition's region and
            0 elsewhere (the first mask is the all-white global one).
        Each prompt text is the condition's prefixes followed by at most its
        first two suffixes. The same dict is returned.
        """
        conversation = [
            {"role": "system", "content": self.template},
            {"role": "user", "content": prompt_dict["prompt"]},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            conversation, return_tensors="pt", add_generation_prompt=True
        ).to(self.device)

        # generate() below runs synchronously, so every decoded chunk is
        # already queued when the streamer is iterated; the streamer serves as
        # a prompt-skipping decoder here, not for live streaming.
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        # One unpadded sequence, so every position is attended. Integer mask
        # with the same dtype/device as input_ids, as tokenizers produce.
        attention_mask = torch.ones_like(input_ids)
        self.model.generate(
            input_ids=input_ids,
            streamer=streamer,
            do_sample=True,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        llm_outputs = "".join(streamer)

        canvas = Canvas.from_bot_response(llm_outputs)
        canvas_output = canvas.process()
        conditions = canvas_output["bag_of_conditions"]

        prompts = [
            " ".join(cond["prefixes"] + cond["suffixes"][:2]) for cond in conditions
        ]
        canvas_output["prompt"] = prompts[0]
        canvas_output["prompts"] = prompts[1:]

        # Build the 0/255 images from copies so the float masks kept in
        # bag_of_conditions (also returned via prompt_dict) stay unmodified.
        masks = []
        for cond in conditions:
            binary = np.where(cond["mask"] > 0.5, 255, 0).astype(np.uint8)
            masks.append(Image.fromarray(np.stack([binary] * 3, axis=-1)))
        canvas_output["masks"] = masks

        prompt_dict.update(canvas_output)

        print("Your prompt is extended by Omost:\n")
        for index, (cond, text) in enumerate(zip(conditions, prompts), start=1):
            loc = cond["location"]
            print(f"Component {index} - Location : {loc}\nPrompt:{text}\n")
        return prompt_dict