import re

import gradio as gr
import spaces
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_id = "mychen76/paligemma-receipt-json-3b-mix-448-v2b"
dtype = torch.bfloat16
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=dtype).to(device).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

MAX_TOKENS = 512
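# The checkpoint above appears to be a PaliGemma 3B "mix" 448px model fine-tuned to
# emit receipt fields as Donut-style <s_field>...</s_field> tag sequences in response
# to the "EXTRACT_JSON_RECEIPT" prompt (inferred from the model id and the code below).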
# Turn the generated tag sequence into JSON (parsing logic adapted from Donut)
def token2json(tokens, is_inner_value=False, added_vocab=None):
    """
    Convert a (generated) token sequence into an ordered JSON format.
    """
    if added_vocab is None:
        added_vocab = processor.tokenizer.get_added_vocab()

    output = {}

    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        key_escaped = re.escape(key)

        end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(
                f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
            )
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True, added_vocab=added_vocab)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]

            tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)

    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}
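# Illustrative example of the conversion (the field names here are hypothetical,
# not necessarily part of this model's vocabulary):
#   token2json("<s_total>12.50</s_total><s_item><s_name>Coffee</s_name></s_item>")
#   -> {"total": "12.50", "item": {"name": "Coffee"}}
# Nested <s_*>...</s_*> pairs become nested objects, and <sep/> between leaf values
# turns a field into a list.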
def modify_caption(caption: str) -> str:
    """
    Removes specific prefixes from captions.
    Args:
        caption (str): A string containing a caption.
    Returns:
        str: The caption with the prefix removed if it was present.
    """
    # Define the prefixes to remove
    prefix_substrings = [
        ('EXTRACT_JSON_RECEIPT', '')
    ]

    # Create a regex pattern to match any of the prefixes
    pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
    replacers = {opening: replacer for opening, replacer in prefix_substrings}

    # Replace a matched prefix with its corresponding replacement
    # (upper-case the match so case-insensitive hits still find their key)
    def replace_fn(match):
        return replacers.get(match.group(0).upper(), "")

    # Apply the regex to the caption
    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
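# For example (hypothetical input), modify_caption("EXTRACT_JSON_RECEIPT <s_total>12.50</s_total>")
# returns " <s_total>12.50</s_total>" with the task prefix stripped. Note that this
# helper is not invoked in the inference path below.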
def json_inference(image, input_text="EXTRACT_JSON_RECEIPT", device="cuda:0", max_new_tokens=512):
    inputs = processor(text=input_text, images=image, return_tensors="pt").to(device)

    # Autoregressively generate with greedy decoding; for fancier decoding strategies
    # see https://huggingface.co/blog/how-to-generate
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Turn the predicted token IDs back into a string with the decode method.
    # We chop off the prompt, which consists of the image tokens and our text prompt.
    image_token_index = model.config.image_token_index
    num_image_tokens = len(generated_ids[generated_ids == image_token_index])
    num_text_tokens = len(processor.tokenizer.encode(input_text))
    num_prompt_tokens = num_image_tokens + num_text_tokens + 2
    generated_text = processor.batch_decode(
        generated_ids[:, num_prompt_tokens:], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # Convert the tag sequence into JSON using token2json (adapted from Donut)
    generated_json = token2json(generated_text)
    return generated_text, generated_json
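# Minimal usage sketch (assumes one of the example images bundled with this Space):
#   from PIL import Image
#   raw_text, receipt = json_inference(Image.open("receipt_image1.jpg"), device=device)
#   print(receipt)
# The keys of the returned dict depend on the fields the fine-tuned model emits.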
# Uncomment the decorator below to run this function on Spaces ZeroGPU hardware
# @spaces.GPU
def create_captions_rich(image):
    torch.cuda.empty_cache()
    prompt = "EXTRACT_JSON_RECEIPT"
    generated_text, generated_json = json_inference(
        image=image, input_text=prompt, device=device, max_new_tokens=MAX_TOKENS
    )
    return generated_json
| css = """ | |
| #mkd { | |
| height: 500px; | |
| overflow: auto; | |
| border: 1px solid #ccc; | |
| } | |
| """ | |
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>PaliGemma Receipt and Invoice Model</center></h1>")
    with gr.Tab(label="Receipt or Invoice Image"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                submit_btn = gr.Button(value="Submit")
                output = gr.Text(label="Receipt JSON")

        gr.Examples(
            [["receipt_image1.jpg"], ["receipt_image2.jpg"], ["receipt_image3.png"], ["receipt_image4.png"]],
            inputs=[input_img],
            outputs=[output],
            fn=create_captions_rich,
            label="Try JSON extraction on the examples",
        )

        submit_btn.click(create_captions_rich, [input_img], [output])

demo.queue().launch(share=True, server_name="0.0.0.0", debug=True)
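# To run this demo outside of Spaces (assuming app.py and the example receipt images
# sit in the working directory, with gradio, torch, transformers and spaces installed):
#   python app.py
# then open the printed local URL or the public share link in a browser.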