{ "cells": [ { "cell_type": "markdown", "source": [ "# nougat-small-onnx example\n", "\n", "\n" ], "metadata": { "id": "WzFn5paySP9y" } }, { "cell_type": "code", "source": [ "!free -h\n", "!nproc\n", "!date" ], "metadata": { "id": "-l9Xxxfud9Db", "outputId": "69360f08-7389-4ce2-b6f9-5b48ac656f7d", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " total used free shared buff/cache available\n", "Mem: 12Gi 5.2Gi 4.2Gi 8.0Mi 3.3Gi 7.2Gi\n", "Swap: 0B 0B 0B\n", "2\n", "Wed Nov 22 11:57:53 PM UTC 2023\n" ] } ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lIYdn1woOS1n", "outputId": "daf86843-5b7b-421a-9379-21512bcde2ed" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m247.1/247.1 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for l (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for betterpath (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "!pip install transformers accelerate optimum[onnxruntime] -U -q\n", "!pip install onnx onnxruntime==1.15.1 nougat-ocr l python-Levenshtein -q" ] }, { "cell_type": "code", "source": [ "!pip install -U -q git+https://github.com/huggingface/optimum.git" ], "metadata": { "id": "zG3HjgwxShob", "outputId": "0d8bc751-1d66-41a8-96e9-c0e4fa88842c", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for optimum (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n" ] } ] }, { "cell_type": "code", "source": [ "from optimum.onnxruntime import ORTModelForVision2Seq\n", "from transformers import NougatProcessor\n", "\n", "model_name = 'pszemraj/nougat-small-onnx'\n", "processor = NougatProcessor.from_pretrained(model_name)\n", "model = ORTModelForVision2Seq.from_pretrained(\n", " model_name,\n", " provider=\"CPUExecutionProvider\", # 'CUDAExecutionProvider' for gpu\n", " use_merged=False,\n", " use_io_binding=True\n", ")" ], "metadata": { "id": "N4Xr6XygRsi0" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "!pip list | grep cu" ], "metadata": { "id": "Dxcm8_HRRuGs", "outputId": "4232cb3f-75d8-45be-da72-1776b724b24b", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "cufflinks 0.17.3\n", "cupy-cuda11x 11.0.0\n", "docutils 0.18.1\n", "jaxlib 0.4.20+cuda11.cudnn86\n", "torch 2.1.0+cu118\n", "torchaudio 2.1.0+cu118\n", "torchvision 0.16.0+cu118\n" ] } ] }, { "cell_type": "code", "source": [ "!pip uninstall cupy-cuda11x -y" ], "metadata": { "id": "wiEMgl1NU0o4", "outputId": "0a6fbe76-2aac-47d2-f098-ace548570752", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Found existing installation: cupy-cuda11x 11.0.0\n", "Uninstalling cupy-cuda11x-11.0.0:\n", " Successfully uninstalled cupy-cuda11x-11.0.0\n" ] } ] }, { "cell_type": "markdown", "source": [ "# New Section" ], "metadata": { "id": "cXIncEKDWKNW" } }, { "cell_type": "code", "source": [ "from huggingface_hub import hf_hub_download\n", "import re\n", "from PIL import Image\n", "import requests\n", "from nougat.dataset.rasterize import rasterize_paper\n", "from transformers import NougatProcessor, VisionEncoderDecoderModel\n", "import torch\n", "import uuid\n", "import os\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig" ], "metadata": { "id": "jfOGuPZ3WN-H" }, "execution_count": 4, "outputs": [] }, { "cell_type": "code", "source": [ "# %%\n", "from transformers import StoppingCriteria, StoppingCriteriaList\n", "from collections import defaultdict\n", "\n", "class RunningVarTorch:\n", " def __init__(self, L=15, norm=False):\n", " self.values = None\n", " self.L = L\n", " self.norm = norm\n", "\n", " def push(self, x: torch.Tensor):\n", " assert x.dim() == 1\n", " if self.values is None:\n", " self.values = x[:, None]\n", " elif self.values.shape[1] < self.L:\n", " self.values = torch.cat((self.values, x[:, None]), 1)\n", " else:\n", " self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)\n", "\n", " def variance(self):\n", " if self.values is None:\n", " return\n", " if self.norm:\n", " return torch.var(self.values, 1) / self.values.shape[1]\n", " else:\n", " return torch.var(self.values, 1)\n", "\n", "#@title nougat things\n", "class StoppingCriteriaScores(StoppingCriteria):\n", " def __init__(self, threshold: float = 0.015, window_size: int = 200):\n", " super().__init__()\n", " self.threshold = threshold\n", " self.vars = RunningVarTorch(norm=True)\n", " self.varvars = RunningVarTorch(L=window_size)\n", " self.stop_inds = defaultdict(int)\n", " self.stopped = defaultdict(bool)\n", " self.size = 0\n", " self.window_size = window_size\n", "\n", " @torch.no_grad()\n", " def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):\n", " last_scores = scores[-1]\n", " self.vars.push(last_scores.max(1)[0].float().cpu())\n", " self.varvars.push(self.vars.variance())\n", " self.size += 1\n", " if self.size < self.window_size:\n", " return False\n", "\n", " varvar = self.varvars.variance()\n", " for b in range(len(last_scores)):\n", " if varvar[b] < self.threshold:\n", " if self.stop_inds[b] > 0 and not self.stopped[b]:\n", " self.stopped[b] = self.stop_inds[b] >= self.size\n", " else:\n", " self.stop_inds[b] = int(\n", " min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)\n", " )\n", " else:\n", " self.stop_inds[b] = 0\n", " self.stopped[b] = False\n", " return all(self.stopped.values()) and len(self.stopped) > 0" ], "metadata": { "cellView": "form", "id": "igzTKjsHBDnq" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "from pathlib import Path\n", "\n", "from tqdm.auto import tqdm\n", "\n", "\n", "import os\n", "import requests\n", "import uuid\n", "from urllib.parse import urlparse, unquote\n", "\n", "#@title inference fns\n", "def get_pdf(pdf_link):\n", " # Parse the URL to extract the filename\n", " parsed_url = urlparse(pdf_link)\n", " filename = os.path.basename(unquote(parsed_url.path))\n", "\n", " # Use a unique filename if the URL doesn't contain a filename or if it's not a PDF\n", " if not filename or not filename.lower().endswith('.pdf'):\n", " filename = f\"downloaded_paper_{uuid.uuid4().hex}.pdf\"\n", "\n", " full_path = os.path.join(os.getcwd(), filename)\n", "\n", " try:\n", " response = requests.get(pdf_link)\n", " response.raise_for_status() # Raise an HTTPError if the HTTP request returned an unsuccessful status code\n", "\n", " with open(full_path, \"wb\") as pdf_file:\n", " pdf_file.write(response.content)\n", " print(\"PDF downloaded successfully.\")\n", " except requests.HTTPError as http_err:\n", " print(f\"HTTP error occurred: {http_err}\")\n", " except Exception as err:\n", " print(f\"An error occurred: {err}\")\n", "\n", " return full_path\n", "\n", "def predict(image, fix_markdown=False):\n", " # prepare PDF image for the model\n", " image = Image.open(image)\n", " pixel_values = processor(image, return_tensors=\"pt\").pixel_values\n", "\n", " with torch.backends.cuda.sdp_kernel(\n", " enable_flash=True, enable_math=False, enable_mem_efficient=False\n", " ):\n", " outputs = model.generate(\n", " pixel_values.to(model.device),\n", " min_length=1,\n", " max_new_tokens=3584,\n", " bad_words_ids=[[processor.tokenizer.unk_token_id]],\n", " return_dict_in_generate=True,\n", " output_scores=True,\n", " stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),\n", " )\n", " generated = processor.batch_decode(outputs[0], skip_special_tokens=True)[0]\n", " generated = processor.post_process_generation(generated, fix_markdown=fix_markdown)\n", "\n", " return generated\n", "\n", "\n", "def inference(pdf_file, pdf_link=None, fix_markdown=False):\n", " if pdf_file is None:\n", " if pdf_link == \"\" or pdf_link is None:\n", " print(\"No file is uploaded and No link is provided\")\n", " return \"No data provided. Upload a pdf file or provide a pdf link and try again!\"\n", " else:\n", " file_name = get_pdf(pdf_link)\n", " else:\n", " pdf_file = Path(pdf_file)\n", " file_name = pdf_file.name\n", " pdf_name = pdf_file.name.split(\"/\")[-1].split(\".\")[0]\n", "\n", " images = rasterize_paper(file_name, return_pil=True)\n", " sequence = \"\"\n", " # infer for every page and concat\n", " for image in tqdm(images, desc=\"inference\"):\n", " sequence += predict(image, fix_markdown=fix_markdown) + \"\\n\\n\"\n", "\n", " content = (\n", " sequence.replace(r\"\\(\", \"$\")\n", " .replace(r\"\\)\", \"$\")\n", " .replace(r\"\\[\", \"$$\")\n", " .replace(r\"\\]\", \"$$\")\n", " )\n", " out_path = Path.cwd() / f\"output-{Path(file_name).stem}.md\"\n", " with open(out_path, \"w+\") as f:\n", " f.write(content)\n", " f.close()\n", "\n", " return content, str(out_path.resolve())" ], "metadata": { "id": "Z659VnsO48gb", "cellView": "form" }, "execution_count": 6, "outputs": [] }, { "cell_type": "code", "source": [ "inference(None, 'https://arxiv.org/pdf/1706.03762.pdf', fix_markdown=True)" ], "metadata": { "id": "H5RT4hlyWay4", "outputId": "9a920319-968d-4916-bae5-8ed0b3338b52", "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "bef56be303034fb4befe96ad3695e119", "cab30515b4774dd4a18149cc12014712", "08d1986360d143179f0b7fcd513e0bcc", "b360a6604e784a95a99c6476d39fe466", "61146889483246f181bfb4c241ec4ca7", "00e89bed6c604cdb9667558074e45299", "cbec35456099410794c52ce247c841b0", "ff10d84d407e4bec8f5a3de6b325a7c4", "e9134b804c234ce0a444ed95b966b6cf", "ae5318a674e44383b89d7213301dc06f", "be497defc0dd47ba8eeb7456306418ba" ] } }, "execution_count": 8, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "PDF downloaded successfully.\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "inference: 0%| | 0/15 [00:00