Spaces:
Runtime error
Runtime error
| import pandas | |
| import numpy | |
| import pandas.io.formats.style | |
| import random | |
| import functools | |
| from typing import Callable, Literal | |
# Folder that the .feather data files are read from.
DATA_FOLDER = "."

# Numeric ids of the tag categories used in the "category" column.
CAT_GENERAL = 0
CAT_ARTIST = 1
CAT_UNUSED = 2
CAT_COPYRIGHT = 3
CAT_CHARACTER = 4
CAT_SPECIES = 5
CAT_INVALID = 6
CAT_META = 7
CAT_LORE = 8

# CSS colour used when rendering a tag name of each category.
CATEGORY_COLORS = {
    CAT_GENERAL: "#808080",
    CAT_ARTIST: "#f2ac08",
    CAT_UNUSED: "#ff3d3d",
    CAT_COPYRIGHT: "#d0d",
    CAT_CHARACTER: "#0a0",
    CAT_SPECIES: "#ed5d1f",
    CAT_INVALID: "#ff3d3d",
    CAT_META: "#04f",
    CAT_LORE: "#282",
}
def get_feather(filename: str) -> pandas.DataFrame:
    """Read ``<DATA_FOLDER>/<filename>.feather`` and return it as a DataFrame."""
    path = f"{DATA_FOLDER}/{filename}.feather"
    return pandas.read_feather(path)
# Load every table once, at import time.
tags = get_feather("tags")
posts_by_tag = get_feather("posts_by_tag").set_index("tag_id")
tags_by_post = get_feather("tags_by_post").set_index("post_id")
tag_ratings = get_feather("tag_ratings")
implications = get_feather("implications")

# Two views of the tag table: one keyed by name (tag_id stays available as a
# column), one keyed by tag_id. The by-name copy must be taken before the
# tag_id column is turned into the index.
tags_by_name = tags.copy(deep=True).set_index("name")
tags = tags.set_index("tag_id")
def get_related_tags(targets: tuple[str, ...], exclude: tuple[str, ...] = (), samples: int = 100_000) -> pandas.DataFrame:
    """Score every tag by how strongly it co-occurs with a tag query.

    The query matches posts that carry ALL of ``targets`` and NONE of
    ``exclude``. At most ``samples`` matching posts are inspected; when the
    match set is larger, a random sample is taken and the counts are scaled
    back up to estimate the full overlap.

    Args:
        targets: Tag names a post must all have. Unknown names raise
            ``KeyError`` from the ``tags_by_name`` lookup.
        exclude: Tag names a post must not have.
        samples: Maximum number of matching posts to inspect.

    Returns:
        A DataFrame with one row per tag that has ``post_count > 0``,
        carrying the index of ``tags``, with columns ``category``, ``name``,
        ``overlap`` (estimated number of matching posts that have the tag,
        capped at ``post_count``), ``post_count`` and ``correlation`` (phi
        coefficient between "post matches the query" and "post has this
        tag"). When no post matches the query, ``overlap`` is 0 everywhere
        and ``correlation`` is NaN.
    """
    these_tags = tags_by_name.loc[list(targets)]
    # Each posts_by_tag row holds the post ids for one tag; turn them into sets and intersect them.
    posts_with_these_tags = posts_by_tag.loc[these_tags["tag_id"]].map(set).groupby(lambda x: True).agg(lambda x: set.intersection(*x))["post_id"][True]
    if exclude:
        excluded_tags = tags_by_name.loc[list(exclude)]
        posts_with_excluded_tags = posts_by_tag.loc[excluded_tags["tag_id"]].map(set).groupby(lambda x: True).agg(lambda x: set.union(*x))["post_id"][True]
        posts_with_these_tags = posts_with_these_tags - posts_with_excluded_tags
    total_post_count_together = len(posts_with_these_tags)
    sample_posts = random.sample(list(posts_with_these_tags), samples) if total_post_count_together > samples else list(posts_with_these_tags)
    post_count_together = len(sample_posts)
    # An empty match set used to raise ZeroDivisionError here; with nothing sampled there is
    # nothing to scale, so any non-zero ratio works.
    sample_ratio = post_count_together / total_post_count_together if total_post_count_together else 1.0
    tags_in_these_posts = tags_by_post.loc[sample_posts]
    counts_in_these_posts = tags_in_these_posts["tag_id"].explode().value_counts().rename("overlap")
    # Right join keeps every tag with posts; tags never seen in the sample get overlap 0.
    summaries = pandas.DataFrame(counts_in_these_posts).join(tags[tags["post_count"] > 0], how="right").fillna(0)
    # Scale the sampled count up to the full match set, but never above the tag's own post count.
    summaries["overlap"] = numpy.minimum(summaries["overlap"] / sample_ratio, summaries["post_count"])
    summaries = summaries[["category", "name", "overlap", "post_count"]]
    # Old "interestingness" value, didn't give as good results as an actual statistical technique, go figure. Code kept for curiosity's sake.
    #summaries["interestingness"] = summaries["overlap"].pow(2) / (total_post_count_together * summaries["post_count"])
    # Phi coefficient: n = all posts, n11 = posts matching the query that have the tag,
    # n1x = posts matching the query, nx1 = posts having the tag.
    n = float(len(tags_by_post))
    n11 = summaries["overlap"]
    n1x = float(total_post_count_together)
    nx1 = summaries["post_count"].astype("float64")
    summaries["correlation"] = (n * n11 - n1x * nx1) / numpy.sqrt(n1x * nx1 * (n - n1x) * (n - nx1))
    return summaries
def format_tags(styler: pandas.io.formats.style.Styler):
    """Apply the standard tag-table look to ``styler`` and return it.

    The tag name is coloured by its category, the index and the ``category``
    column are hidden, ``overlap`` is shown without decimals, and
    ``correlation`` / ``score`` are shown with two decimals on a red-to-green
    gradient spanning -1..1. Columns that are absent are skipped.
    """
    def colour_name_cell(row):
        # Only the "name" cell gets a style; every other cell gets an empty one.
        return numpy.where(row.index == "name", "color:" + CATEGORY_COLORS[row["category"]], "")

    styler.apply(colour_name_cell, axis=1)
    styler.hide(level=0)
    styler.hide("category", axis=1)

    table = styler.data
    if 'overlap' in table:
        styler.format("{:.0f}".format, subset=["overlap"])
    for signed_column in ("correlation", "score"):
        if signed_column in table:
            styler.format("{:.2f}".format, subset=[signed_column])
            styler.background_gradient(vmin=-1.0, vmax=1.0, cmap="RdYlGn", subset=[signed_column])
    return styler
def related_tags(*targets: str, exclude: tuple[str, ...] = (), category: int | None = None, samples: int = 100_000, min_overlap: int = 5, min_posts: int = 20, top: int = 30, bottom: int = 0) -> "pandas.io.formats.style.Styler":
    """Build a styled table of the tags most (and least) correlated with a query.

    Args:
        *targets: Tag names that matching posts must all have.
        exclude: Tag names that matching posts must not have.
        category: If given, keep only tags of this category id.
        samples: Maximum number of matching posts to inspect.
        min_overlap: Drop tags whose estimated overlap is below this.
        min_posts: Drop tags with fewer total posts than this.
        top: Number of highest-correlation rows to show.
        bottom: Number of lowest-correlation rows to append after them.

    Returns:
        A pandas ``Styler`` (not a plain DataFrame) formatted by
        ``format_tags``; rows are ordered by descending correlation within
        the top part and within the bottom part.
    """
    result = get_related_tags(targets, exclude=exclude, samples=samples)
    if category is not None:
        result = result[result["category"] == category]
    # The query tags themselves would trivially top the list, so leave them out.
    result = result[~result["name"].isin(targets)]
    result = result[result["overlap"] >= min_overlap]
    result = result[result["post_count"] >= min_posts]
    top_part = result.sort_values("correlation", ascending=False)[:top]
    # Take the `bottom` lowest rows, then flip them so the whole table reads high-to-low.
    bottom_part = result.sort_values("correlation", ascending=True)[:bottom].sort_values("correlation", ascending=False)
    return pandas.concat([top_part, bottom_part]).style.pipe(format_tags)
def implications_for(*subjects: str, seen: set[str] = None):
    """Yield the names of all tags implied, directly or transitively, by ``subjects``.

    ``seen`` collects the names already yielded so that each implied tag is
    produced once; it is shared with the recursive calls. Direct implications
    of a subject are looked up in ``implications`` by ``antecedent_id``.
    """
    if seen is None:
        seen = set()
    for subject in subjects:
        subject_id = tags_by_name.loc[subject, "tag_id"]
        direct = implications[implications["antecedent_id"] == subject_id]
        consequent_ids = list(direct.loc[:, "consequent_id"])
        for implied in tags.loc[consequent_ids, "name"].values:
            if implied in seen:
                continue
            yield implied
            seen.add(implied)
            # Follow the chain: whatever the implied tag implies is implied too.
            yield from implications_for(implied, seen=seen)
def parse_tag(potential_tag: str):
    """Turn one piece of prompt text into a known tag name, or ``None``.

    The text is stripped, spaces become underscores and escaped parentheses
    are unescaped. A ``by_`` prefix is dropped when only the remainder is a
    known tag. Unknown tags are reported on stdout and yield ``None``; blank
    input yields ``None`` silently.
    """
    potential_tag = potential_tag.strip().replace(" ", "_")
    potential_tag = potential_tag.replace("\\(", "(").replace("\\)", ")")
    if not potential_tag:
        return None

    known_names = tags_by_name.index
    if potential_tag in known_names:
        return potential_tag

    unprefixed = potential_tag[3:]
    if potential_tag.startswith("by_") and unprefixed in known_names:
        return unprefixed

    print(f"Couldn't find tag '{potential_tag}', skipping it.")
    return None
def parse_tags(*parts: str):
    """Yield every recognised tag from comma-separated prompt strings.

    Each part is split on commas and every piece goes through ``parse_tag``;
    pieces it rejects (``None``) are dropped.
    """
    for part in parts:
        # map() is lazy, so each piece is parsed right before it is yielded.
        parsed = map(parse_tag, part.split(","))
        yield from (tag for tag in parsed if tag is not None)
def add_suggestions(suggestions: pandas.DataFrame | None, new_tags: str | list[str], multiplier: int, samples : int, min_posts: int, rating: Literal['s', 'q', 'e']):
    """Fold the correlations of one or more tags into a running suggestion table.

    Args:
        suggestions: Existing table with ``category``, ``name``,
            ``post_count`` and ``score`` columns, or ``None`` to start from
            the first tag's correlations.
        new_tags: A tag name or a list of tag names to fold in, in order.
        multiplier: Factor applied to each new correlation before combining
            (callers in this file pass 1 to attract and -1 to repel).
        samples: Passed through to ``get_related_tags``.
        min_posts: Tags with a (rating-filtered) post count below this are
            dropped from each new tag's table before combining.
        rating: ``'s'`` counts only the ``s`` column of ``tag_ratings``,
            ``'q'`` counts ``s`` + ``q``, ``'e'`` keeps the overall
            ``post_count``.

    Returns:
        A DataFrame with columns ``category``, ``name``, ``post_count`` and
        ``score``, or ``None`` when ``suggestions`` is ``None`` and
        ``new_tags`` is empty.
    """
    if isinstance(new_tags, str):
        new_tags = [new_tags]
    # Only the 's' and 'q' filters need per-rating counts; index them once, not once per tag.
    ratings_by_tag = tag_ratings.set_index("tag_id") if rating in ('s', 'q') else None
    for new_tag in new_tags:
        related = get_related_tags((new_tag,), samples=samples)
        # Implementing the rating filter this way is horribly inefficient, fix it later
        if ratings_by_tag is not None:
            related = related.join(ratings_by_tag, on="tag_id")
            related["post_count"] = related["s"] if rating == 's' else related["s"] + related["q"]
            related = related.drop(["s", "q", "e"], axis=1)
        related = related[related["post_count"] >= min_posts]
        if suggestions is None:
            suggestions = related.rename(columns={"correlation": "score"})
        else:
            # Join ONLY the new correlation onto the four running columns. Joining whole frames
            # left stale columns behind, so from the second joined tag onward the new
            # "correlation" was renamed away and the first tag's correlation was reused.
            combined = suggestions[["category", "name", "post_count", "score"]].join(related[["correlation"]])
            # This is a totally made up way to combine correlations. It keeps them from going outside the +/- 1 range, which is nice. It also makes older
            # tags less important every time newer ones are added. That could be considered a feature or not.
            combined["score"] = numpy.real(numpy.power((numpy.sqrt(combined["score"] + 0j) + numpy.sqrt(multiplier * combined["correlation"] + 0j)) / 2, 2))
            suggestions = combined.drop("correlation", axis=1)
    if suggestions is None:
        # Nothing to start from and nothing to add: stay in the "no suggestions yet" state.
        return None
    return suggestions[["category", "name", "post_count", "score"]]
def pick_tags(suggestions: pandas.DataFrame, category: int, count: int, from_top: int, excluding: list[str], weighted: bool = True):
    """Randomly pick ``count`` distinct tag names from the best suggestions.

    Candidates have a positive score, are not in ``excluding`` and, when
    ``category`` is not ``None``, belong to that category; only the
    ``from_top`` highest-scoring candidates are considered. With ``weighted``
    the draw is proportional to score (without replacement), otherwise it is
    uniform.
    """
    eligible = (suggestions["score"] > 0) & ~suggestions["name"].isin(excluding)
    if category is not None:
        eligible = (suggestions["category"] == category) & eligible
    options = suggestions[eligible].sort_values("score", ascending=False)[:from_top]

    names = list(options["name"].values)
    if not weighted:
        return random.sample(names, count)

    scores = list(options["score"].values)
    picked = []
    for _ in range(count):
        drawn = random.choices(population=names, weights=scores, k=1)[0]
        # Remove the drawn name together with its weight so it cannot be drawn again.
        position = names.index(drawn)
        scores.pop(position)
        picked.append(names.pop(position))
    return picked
def tag_to_prompt(tag: str) -> str:
    """Render a tag name as prompt text.

    Artist tags get a "by " prefix; underscores become spaces and
    parentheses are backslash-escaped.
    """
    is_artist = tags_by_name.loc[tag]["category"] == CAT_ARTIST
    text = "by " + tag if is_artist else tag
    for old, new in (("_", " "), ("(", "\\("), (")", "\\)")):
        text = text.replace(old, new)
    return text
| # A lambda in a for loop doesn't capture variables the way I want it to, so this is a method now | |
| def add_suggestions_later(suggestions: pandas.DataFrame, new_tags: str | list[str], multiplier: int, samples: int, min_posts: int, rating: Literal['s', 'q', 'e']): | |
| return lambda: add_suggestions(suggestions, new_tags, multiplier, samples, min_posts, rating) | |
# One prompt under construction: (positive tags, negative tags, deferred suggestions).
# The third element is a zero-argument callable, so a suggestion table is only
# computed when a later step actually asks for it; it returns None until a tag
# has been folded in.
Prompt = tuple[list[str], list[str], Callable[[], pandas.DataFrame]]


class PromptBuilder:
    """Fluent builder for one or more tag prompts.

    Every step returns a NEW builder and leaves the receiver untouched. Each
    prompt carries a positive tag list, a negative tag list and a deferred
    suggestion table that steers the random picks.
    """

    prompts: list[Prompt]
    samples: int
    min_posts: int
    rating: Literal['s', 'q', 'e']
    skip_list: list[str]

    def __init__(self, prompts=None, skip=None, samples=100_000, min_posts=20, rating: Literal['s', 'q', 'e'] = 'e'):
        """Create a builder.

        Args:
            prompts: Initial prompts; defaults to a single empty prompt.
            skip: Tag names that must never be picked; defaults to none.
            samples: Sample size forwarded to the suggestion computation.
            min_posts: Minimum post count forwarded to the suggestion computation.
            rating: Rating filter forwarded to the suggestion computation.
        """
        # Fresh containers per instance: list defaults in the signature would be
        # shared by every builder created without arguments.
        self.prompts = [([], [], lambda: None)] if prompts is None else prompts
        self.samples = samples
        self.min_posts = min_posts
        self.rating = rating
        self.skip_list = [] if skip is None else skip

    def _derive(self, prompts):
        """Return a builder with the same settings as this one but different prompts."""
        return PromptBuilder(prompts=prompts, samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)

    def _later(self, suggestions, new_tags, multiplier):
        """Defer folding ``new_tags`` into ``suggestions`` using this builder's settings."""
        return add_suggestions_later(suggestions, new_tags, multiplier, self.samples, self.min_posts, self.rating)

    def include(self, tag: str):
        """Add ``tag`` to every positive prompt and steer suggestions towards it."""
        return self._derive([
            (tag_list + [tag], negative_list, self._later(suggestions(), tag, 1))
            for (tag_list, negative_list, suggestions) in self.prompts
        ])

    def focus(self, tag: str):
        """Steer suggestions towards ``tag`` without adding it to the prompt."""
        return self._derive([
            (tag_list, negative_list, self._later(suggestions(), tag, 1))
            for (tag_list, negative_list, suggestions) in self.prompts
        ])

    def exclude(self, tag: str):
        """Add ``tag`` to every negative prompt and steer suggestions away from it."""
        return self._derive([
            (tag_list, negative_list + [tag], self._later(suggestions(), tag, -1))
            for (tag_list, negative_list, suggestions) in self.prompts
        ])

    def avoid(self, tag: str):
        """Steer suggestions away from ``tag`` without adding it to the negative prompt."""
        return self._derive([
            (tag_list, negative_list, self._later(suggestions(), tag, -1))
            for (tag_list, negative_list, suggestions) in self.prompts
        ])

    def pick(self, category: int, count: int, from_top: int):
        """Add ``count`` picked tags to each prompt, one at a time.

        Each pick is folded into the suggestions before the next one is
        drawn, so later picks are influenced by earlier ones.
        """
        current = self.prompts
        for _ in range(count):
            next_round = []
            for (tag_list, negative_list, suggestions) in current:
                s = suggestions()
                taken = tag_list + negative_list + self.skip_list
                for tag in pick_tags(s, category, 1, from_top, taken):
                    next_round.append((tag_list + [tag], negative_list, self._later(s, tag, 1)))
            current = next_round
        return self._derive(current)

    def foreach_pick(self, category: int, count: int, from_top: int):
        """Fork each prompt into ``count`` prompts, each extended with a different picked tag."""
        forked = []
        for (tag_list, negative_list, suggestions) in self.prompts:
            s = suggestions()
            taken = tag_list + negative_list + self.skip_list
            for tag in pick_tags(s, category, count, from_top, taken):
                forked.append((tag_list + [tag], negative_list, self._later(s, tag, 1)))
        return self._derive(forked)

    def pick_fast(self, category: int, count: int, from_top: int):
        """Add ``count`` picked tags to each prompt in one draw.

        All tags are drawn from the current suggestions and folded in
        together afterwards, so the picks do not influence each other.
        """
        prompts = []
        for (tag_list, negative_list, suggestions) in self.prompts:
            s = suggestions()
            new_tags = pick_tags(s, category, count, from_top, tag_list + negative_list + self.skip_list)
            prompts.append((tag_list + new_tags, negative_list, self._later(s, new_tags, 1)))
        return self._derive(prompts)

    def branch(self, count: int):
        """Repeat every prompt ``count`` times so later random steps can diverge."""
        return self._derive([prompt for prompt in self.prompts for _ in range(count)])

    def build(self):
        """Yield each prompt as text, with a "Negative prompt:" line when it has negative tags."""
        for (tag_list, negative_list, _) in self.prompts:
            positive_prompt = ", ".join(tag_to_prompt(tag) for tag in tag_list)
            negative_prompt = ", ".join(tag_to_prompt(tag) for tag in negative_list)
            if negative_prompt:
                yield f"{positive_prompt}\nNegative prompt: {negative_prompt}"
            else:
                yield positive_prompt

    def print(self):
        """Print every built prompt to stdout."""
        for prompt in self.build():
            print(prompt)

    def get_one(self):
        """Return the first built prompt, or ``None`` when there are no prompts."""
        return next(self.build(), None)