Spaces:
Configuration error
Configuration error
| #!/usr/bin/env python | |
| import re | |
| from argparse import ArgumentParser | |
| from functools import lru_cache | |
| from importlib.resources import files | |
| from inspect import signature | |
| from multiprocessing.pool import ThreadPool | |
| from tempfile import NamedTemporaryFile | |
| from textwrap import dedent | |
| from typing import Optional | |
| from PIL import Image | |
| import fitz | |
| import gradio as gr | |
| from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline, AutoModelForPreTraining, AutoProcessor | |
| import os | |
| # from pix2tex.cli import LatexOCR | |
| from munch import Munch | |
| import spaces | |
| from infer import TikzDocument, TikzGenerator | |
| # assets = files(__package__) / "assets" if __package__ else files("assets") / "." | |
# Dropdown label -> model spec consumed by cached_load().
# A spec is a Hugging Face Hub repo id, optionally followed by a single space
# and the revision to load (e.g. "waleko/TikZ-llava-1.5-7b v2"); cached_load()
# splits the spec on the space and falls back to revision "main".
# Insertion order matters: build_ui() preselects the first entry by default.
models = {
    "llava-1.5-7b-hf": "waleko/TikZ-llava-1.5-7b",
    "new llava-1.5-7b-hf": "waleko/TikZ-llava-1.5-7b v2",
}
def is_quantization(model_name):
    """Tell whether *model_name* should be loaded with 4-bit quantization.

    Only the fine-tuned ``waleko/TikZ-llava`` checkpoints take the quantized
    loading path in ``cached_load``; every other repo id returns False.
    """
    quantized_prefix = "waleko/TikZ-llava"
    return model_name.find(quantized_prefix) >= 0
@lru_cache(maxsize=1)
def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
    """Load an image-to-text pipeline for *model_name*, memoizing the last one.

    Args:
        model_name: Hugging Face Hub repo id, optionally followed by a single
            space and a revision (e.g. ``"waleko/TikZ-llava-1.5-7b v2"``).
            Without a revision, ``"main"`` is used.
        **kwargs: forwarded to ``pipeline`` / ``from_pretrained`` for the
            weights (e.g. ``device_map``). Values must be hashable because the
            call is cached with ``functools.lru_cache``.

    Returns:
        An ``ImageToTextPipeline``. Repos matched by ``is_quantization`` are
        loaded with ``load_in_4bit=True`` and wrapped together with their
        processor's tokenizer and image processor.

    The single-entry cache means repeated requests for the same model reuse
    the already-instantiated pipeline instead of reloading the weights on
    every inference call.
    """
    # partition() tolerates a missing or empty revision part, unlike a bare
    # two-way unpack of split(" ").
    model_name, _, revision = model_name.partition(" ")
    revision = revision or "main"
    # Only shown on a cache miss, i.e. when weights are actually being loaded.
    gr.Info("Instantiating model. Could take a while...")  # type: ignore
    if not is_quantization(model_name):
        return pipeline("image-to-text", model=model_name, revision=revision, **kwargs)
    model = AutoModelForPreTraining.from_pretrained(model_name, load_in_4bit=True, revision=revision, **kwargs)
    # Load the processor from the same revision as the weights so that the
    # tokenizer / image preprocessing match the checkpoint being used.
    processor = AutoProcessor.from_pretrained(model_name, revision=revision)
    return pipeline(task="image-to-text", model=model, tokenizer=processor.tokenizer, image_processor=processor.image_processor)
def convert_to_svg(pdf):
    """Render the first page of a compiled PDF as an SVG string.

    Args:
        pdf: the object returned by ``TikzDocument.pdf``; only its ``.raw``
            attribute is used (presumably the raw PDF bytes, which is what
            ``fitz.open("pdf", ...)`` expects; confirm against ``infer``).

    Returns:
        The SVG markup of page 0 as a ``str``.
    """
    # Use the document as a context manager so it is closed deterministically
    # instead of being left open until garbage collection.
    with fitz.open("pdf", pdf.raw) as doc:  # type: ignore
        return doc[0].get_svg_image()
| # def pix2tikz( | |
| # checkpoint: str, | |
| # image: Image.Image, | |
| # temperature: float, | |
| # _: float, | |
| # __: int, | |
| # ___: bool, | |
| # ): | |
| # cur_pwd = os.path.dirname(os.path.abspath(__file__)) | |
| # config_path = os.path.join(cur_pwd, 'pix2tikz/config.yaml') | |
| # model_path = os.path.join(cur_pwd, checkpoint) | |
| # | |
| # print(cur_pwd, config_path, model_path, os.path.exists(config_path), os.path.exists(model_path)) | |
| # | |
| # args = Munch({'config': config_path, | |
| # 'checkpoint': model_path, | |
| # 'no_resize': False, | |
| # 'no_cuda': False, | |
| # 'temperature': temperature}) | |
| # model = LatexOCR(args) | |
| # res = model(image) | |
| # text = re.sub(r'\\n(?=\W)', '\n', res) | |
| # return text, None, True | |
def _stream_until_done(streamer, async_result):
    """Yield text chunks from *streamer* until the background generation job ends.

    *streamer* must be created with a finite ``timeout`` so that iterating it
    raises ``queue.Empty`` when no new text arrives in time. On each timeout,
    check whether the job behind *async_result* has already finished. If it
    did (e.g. it raised before signalling end-of-stream), ``async_result.get()``
    re-raises its exception here instead of leaving the caller blocked forever.
    """
    from queue import Empty

    chunks = iter(streamer)
    while True:
        try:
            chunk = next(chunks)
        except StopIteration:
            return
        except Empty:
            if async_result.ready():
                async_result.get()  # re-raises the worker's exception, if any
                return
            continue
        yield chunk


def inference(
    model_name: str,
    image_dict: dict,
    temperature: float,
    top_p: float,
    top_k: int,
    expand_to_square: bool,
):
    """Generate TikZ code for a sketch, streaming partial output to the UI.

    Args:
        model_name: model spec passed to ``cached_load`` (a value of ``models``).
        image_dict: value of the ``gr.ImageEditor`` input; only its
            ``'composite'`` entry is used as the input image.
        temperature, top_p, top_k, expand_to_square: forwarded to ``TikzGenerator``.

    Yields:
        ``(code, None, finished)`` tuples for the ``[tikz_code, result_image,
        finished]`` outputs: the text generated so far with ``finished=False``
        while tokens stream in, then the final code with ``finished=True``.

    Raises:
        gr.Error: wrapping any exception raised while loading or generating.
    """
    try:
        image = image_dict['composite']
        if "pix2tikz" in model_name:
            # The pix2tikz backend is currently disabled; produce no output for it.
            return
        generate = TikzGenerator(
            cached_load(model_name, device_map="auto"),
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            expand_to_square=expand_to_square,
        )
        streamer = TextIteratorStreamer(
            generate.pipeline.tokenizer,  # type: ignore
            skip_prompt=True,
            skip_special_tokens=True,
            # Finite queue timeout so a generation that fails before ending the
            # stream cannot block the relay loop below forever.
            timeout=1.0,
        )
        # Generation runs on a worker thread (wrapped with spaces.GPU) and feeds
        # the streamer, while this generator relays partial text to the UI.
        pool = ThreadPool(processes=1)
        try:
            async_result = pool.apply_async(spaces.GPU(generate), kwds=dict(image=image, streamer=streamer))
            generated_text = ""
            for new_text in _stream_until_done(streamer, async_result):
                generated_text += new_text
                yield generated_text, None, False
            yield async_result.get().code, None, True
        finally:
            # Accept no further tasks so the worker thread exits once its job is
            # done, instead of leaking one pool per request. Also runs when the
            # generator is closed early (e.g. the UI cancels the event).
            pool.close()
    except Exception as e:
        raise gr.Error(f"Internal Error! {e}") from e
def tex_compile(code: str, timeout: int, rasterize: bool):
    """Compile TikZ *code* and yield one preview value for the result image.

    Compilation is delegated to ``TikzDocument(code, timeout=timeout)``.

    Diagnostics:
      * no content and compile errors -> raise ``gr.Error``
      * no content, no compile errors -> show a ``gr.Warning`` and continue
      * content despite compile errors -> only print a message to stdout

    Output (exactly one value is yielded):
      * ``rasterize=True``  -> whatever ``TikzDocument.rasterize()`` returns
      * ``rasterize=False`` -> path of a temporary ``.svg`` file rendered from
        the PDF, or ``None`` when there is no PDF. The file is deleted as soon
        as the generator is resumed or closed, so the consumer must read it
        right after it is yielded.
    """
    doc = TikzDocument(code, timeout=timeout)

    if doc.has_content:
        if doc.compiled_with_errors:
            # Non-fatal: something rendered, so only log the errors.
            print("TikZ code compiled with errors!")
    elif doc.compiled_with_errors:
        raise gr.Error("TikZ code did not compile!")  # type: ignore
    else:
        gr.Warning("TikZ code compiled to an empty image!")  # type: ignore

    if rasterize:
        yield doc.rasterize()
        return

    # Unbuffered, so the SVG bytes reach the file before its path is handed out.
    with NamedTemporaryFile(suffix=".svg", buffering=0) as svg_file:
        pdf = doc.pdf
        if not pdf:
            yield None
            return
        svg_file.write(convert_to_svg(pdf).encode())
        yield svg_file.name
def check_inputs(image: Optional[dict]):
    """Reject a run that has no input image.

    Args:
        image: value of the ``gr.ImageEditor`` input, i.e. a dict whose
            ``'composite'`` entry holds the flattened image, or ``None``.
            A bare (non-dict) image object is also accepted.

    Raises:
        gr.Error: when there is no value or its ``'composite'`` is missing/None.
    """
    # A non-empty dict is always truthy, so `not image` alone would let an
    # editor value without a composite through to `inference`, which indexes
    # image_dict['composite'] directly.
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        raise gr.Error("Image is required")
def get_banner():
    """Build the Markdown/HTML header displayed at the top of the app.

    The literal is indented for readability. ``dedent`` strips the common
    leading whitespace, so the returned text starts at column 0 and ends
    with a single newline after ``</p>``.
    """
    banner = dedent('''\
        # Ti*k*Z Assistant: Sketches to Vector Graphics with Ti*k*Z
        <p>
        <!--<a style="display:inline-block" href="https://github.com/potamides/AutomaTikZ">
        <img src="https://img.shields.io/badge/View%20on%20GitHub-green?logo=github&labelColor=gray" alt="View on GitHub">
        </a>
        <a style="display:inline-block" href="https://arxiv.org/abs/2310.00367">
        <img src="https://img.shields.io/badge/View%20on%20arXiv-B31B1B?logo=arxiv&labelColor=gray" alt="View on arXiv">
        </a>
        <a style="display:inline-block" href="https://colab.research.google.com/drive/14S22x_8VohMr9pbnlkB4FqtF4n81khIh">
        <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
        </a>-->
        <a style="display:inline-block" href="https://huggingface.co/spaces/waleko/TikZ-Assistant">
        <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg" alt="Open in HF Spaces">
        </a>
        </p>
        ''')
    return banner
def remove_darkness(stylable):
    """
    Patch gradio to only contain light mode colors.

    Args:
        stylable: either a ``gr.themes.Base`` theme or a ``gr.Blocks`` app.

    Returns:
        * for a theme: the result of ``theme.set(...)``, with every setting
          accepted by ``set`` assigned its light value. ``foo_dark`` receives
          the value of ``foo``; light settings are re-assigned to themselves.
        * for a Blocks app: the same app, with a load-time JS hook registered
          that removes the ``dark`` CSS class from all elements. This also
          covers components which do not use the theme (e.g. modals).

    Raises:
        ValueError: if *stylable* is neither a theme nor a Blocks app.
    """
    if isinstance(stylable, gr.themes.Base):  # remove dark variants from the entire theme
        params = signature(stylable.set).parameters
        colors = {
            name: getattr(stylable, name.removesuffix("_dark"))
            for name in dir(stylable)
            if name in params
        }
        return stylable.set(**colors)
    if isinstance(stylable, gr.Blocks):  # also handle components which do not use the theme
        stylable.load(js="() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))")
        return stylable
    raise ValueError(
        f"remove_darkness expects a gradio theme or Blocks app, got {type(stylable).__name__}"
    )
def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=False, lock_reason="locked", timeout=120):
    """Build (but do not launch) the TikZ Assistant Gradio app.

    Args:
        model: key of ``models`` preselected in the model dropdown
            (defaults to the first entry).
        lock: if True, the model dropdown is non-interactive and its label
            is suffixed with ``lock_reason``.
        rasterize: forwarded to ``tex_compile`` (rasterized image vs. SVG file
            path) and used as ``show_share_button`` of the result image.
        force_light: apply ``remove_darkness`` to both the theme and the app.
        lock_reason: text shown in the dropdown label when ``lock`` is True.
        timeout: forwarded to ``tex_compile`` (and from there to
            ``TikzDocument``). Unit not visible here; presumably seconds.

    Returns:
        The ``gr.Blocks`` app.
    """
    theme = remove_darkness(gr.themes.Soft()) if force_light else gr.themes.Soft()
    with gr.Blocks(theme=theme, title="TikZ Assistant") as demo: # type: ignore
        if force_light: remove_darkness(demo)
        gr.Markdown(get_banner())
        with gr.Row(variant="panel"):
            # Left column: input sketch, model selection, sampling options, buttons.
            with gr.Column():
                # NOTE(review): `info` is never used (leftover from the disabled
                # caption textbox below) and is reassigned in the right column.
                info = (
                    "Describe what you want to generate. "
                    "Scientific graphics benefit from captions with at least 30 tokens (see examples below), "
                    "while simple objects work best with shorter, 2-3 word captions."
                )
                # caption = gr.Textbox(label="Caption", info=info, placeholder="Type a caption...")
                # image = gr.Image(label="Image Input", type="pil")
                # Editor starts with a blank 336x336 white RGB canvas to draw on.
                image = gr.ImageEditor(label="Image Input", type="pil", sources=['upload', 'clipboard'], value=Image.new('RGB', (336, 336), (255, 255, 255)))
                label = "Model" + (f" ({lock_reason})" if lock else "")
                # Choices are (label, spec) pairs; `model` is rebound from the
                # parameter (a models key) to the dropdown component here.
                model = gr.Dropdown(label=label, choices=list(models.items()), value=models[model], interactive=not lock) # type: ignore
                with gr.Accordion(label="Advanced Options", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2, step=0.05, value=0.8, label="Temperature")
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.95, label="Top-P")
                    top_k = gr.Slider(minimum=0, maximum=100, step=10, value=0, label="Top-K")
                    expand_to_square = gr.Checkbox(value=True, label="Expand image to square")
                with gr.Row():
                    run_btn = gr.Button("Run", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.ClearButton([image])
            # Right column: generated code tab and compiled image tab.
            with gr.Column():
                with gr.Tabs() as tabs:
                    # The walrus rebinds `label` to each tab's title so the same
                    # text is reused as the (hidden) label of the tab's component.
                    with gr.TabItem(label:="TikZ Code", id=0):
                        info = "Source code of the generated image."
                        tikz_code = gr.Code(label=label, show_label=False, interactive=False)
                    with gr.TabItem(label:="Compiled Image", id=1):
                        result_image = gr.Image(label=label, show_label=False, show_share_button=rasterize)
                    # The clear button also resets the outputs, not just the input.
                    clear_btn.add([tikz_code, result_image])
        # Example sketches, loaded from URLs into the image editor.
        gr.Examples(examples=[
            ["https://waleko.github.io/data/image.jpg"],
            ["https://waleko.github.io/data/image2.jpg"],
            ["https://waleko.github.io/data/image3.jpg"],
            ["https://waleko.github.io/data/image4.jpg"],
        ], inputs=[image])

        # Events that the Clear/Stop buttons cancel.
        events = list()
        # Hidden flag written by `inference` (third output): False while streaming,
        # True after the final code. Compilation only runs when it reads "True",
        # so an inference that never completes does not trigger a compile.
        # NOTE(review): the string comparison assumes Textbox stringifies the
        # yielded bool; confirm.
        finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference
        for listener in [run_btn.click]:
            # Chain: validate input -> switch to the code tab -> stream inference.
            # `.success` only continues if check_inputs did not raise.
            generate_event = listener(
                check_inputs,
                inputs=[image],
                queue=False
            ).success(
                lambda: gr.Tabs(selected=0),
                outputs=tabs, # type: ignore
                queue=False
            ).then(
                inference,
                inputs=[model, image, temperature, top_p, top_k, expand_to_square],
                outputs=[tikz_code, result_image, finished]
            )
            # `finished` here is the hidden Textbox's value, not the component.
            # Closes over `timeout` and `rasterize` from build_ui's arguments.
            def tex_compile_if_finished(finished, *args):
                yield from (tex_compile(*args, timeout=timeout, rasterize=rasterize) if finished == "True" else [])
            # After inference: switch to the image tab (only if finished), then
            # compile the generated code into the result image.
            compile_event = generate_event.then(
                lambda finished: gr.Tabs(selected=1) if finished == "True" else gr.Tabs(),
                inputs=finished,
                outputs=tabs, # type: ignore
                queue=False
            ).then(
                tex_compile_if_finished,
                inputs=[finished, tikz_code],
                outputs=result_image
            )
            events.extend([generate_event, compile_event])

        # model.select(lambda model_name: gr.Image(visible="clima" in model_name), inputs=model, outputs=image, queue=False)
        # Both Clear and Stop cancel any in-flight generation/compilation.
        for btn in [clear_btn, stop_btn]:
            btn.click(fn=None, cancels=events, queue=False)
    return demo