Spaces:
Sleeping
Sleeping
File size: 6,249 Bytes
2d8b8bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import os
import base64
import fitz
from io import BytesIO
from PIL import Image
import requests
from llama_index.llms.nvidia import NVIDIA
from llama_index.vector_stores.milvus import MilvusVectorStore
from dotenv import load_dotenv
load_dotenv()
def set_environment_variables():
    """Copy the NVIDIA_API_KEY from the host environment into os.environ.

    Raises:
        ValueError: if NVIDIA_API_KEY is unset or empty. (The original code
            assigned os.getenv()'s possible None straight into os.environ,
            which fails with an opaque TypeError; fail with the same clear
            message the API helpers in this module use.)
    """
    api_key = os.getenv("NVIDIA_API_KEY")
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
    os.environ["NVIDIA_API_KEY"] = api_key
def get_b64_image_from_content(image_content):
    """Return the base64-encoded JPEG representation of raw image bytes."""
    picture = Image.open(BytesIO(image_content))
    # JPEG cannot store an alpha channel / palette, so normalize to RGB first.
    if picture.mode != 'RGB':
        picture = picture.convert('RGB')
    jpeg_buffer = BytesIO()
    picture.save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode("utf-8")
def is_graph(image_content):
    """Return True when the image's description mentions a graph-like visual.

    Delegates to describe_image() and scans the result for chart-related
    keywords, case-insensitively.
    """
    description = describe_image(image_content).lower()
    keywords = ("graph", "plot", "chart", "table")
    return any(word in description for word in keywords)
def process_graph(image_content):
    """Process a graph image and generate a plain-English description.

    First linearizes the chart into a data table with Deplot, then asks an
    LLM to explain that table in prose.
    """
    linearized_table = process_graph_deplot(image_content)
    llm = NVIDIA(model_name="meta/llama-3.1-70b-instruct")
    prompt = (
        "Your responsibility is to explain charts. You are an expert in describing the responses of linearized tables into plain English text for LLMs to use. Explain the following linearized table. "
        + linearized_table
    )
    return llm.complete(prompt).text
def describe_image(image_content):
    """Generate a description of an image using NVIDIA's NeVA-22B VLM API.

    Args:
        image_content: raw image bytes; encoded to base64 and inlined in
            the prompt as an <img> tag.

    Returns:
        The model's description text.

    Raises:
        ValueError: if NVIDIA_API_KEY is not set.
        requests.HTTPError: if the API responds with an error status
            (checked before indexing into the JSON, so failures are
            reported clearly instead of as a KeyError).
    """
    image_b64 = get_b64_image_from_content(image_content)
    invoke_url = "https://ai.api.nvidia.com/v1/vlm/nvidia/neva-22b"
    api_key = os.getenv("NVIDIA_API_KEY")

    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json"
    }

    payload = {
        "messages": [
            {
                "role": "user",
                "content": f"""
Describe what you see in this image:
<img src="data:image/png;base64,{image_b64}" />
Also include:
1. Visible text extraction discovering names and description of products(can use ocr).
2. Inferred location or scene type in the image.
3. Date/time information and its location.
"""
            }
        ],
        "max_tokens": 1024,
        "temperature": 0.20,
        "top_p": 0.70,
        "seed": 0,
        "stream": False
    }

    # A bounded timeout prevents the caller from hanging indefinitely on a
    # stalled connection (requests has no default timeout).
    response = requests.post(invoke_url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    return response.json()["choices"][0]['message']['content']
def process_graph_deplot(image_content):
    """Linearize a chart image into a data table using NVIDIA's Deplot API.

    Args:
        image_content: raw image bytes; encoded to base64 and inlined in
            the prompt as an <img> tag.

    Returns:
        The linearized-table text produced by the model.

    Raises:
        ValueError: if NVIDIA_API_KEY is not set.
        requests.HTTPError: if the API responds with an error status
            (checked before indexing into the JSON, so failures are
            reported clearly instead of as a KeyError).
    """
    invoke_url = "https://ai.api.nvidia.com/v1/vlm/google/deplot"
    image_b64 = get_b64_image_from_content(image_content)
    api_key = os.getenv("NVIDIA_API_KEY")

    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json"
    }

    payload = {
        "messages": [
            {
                "role": "user",
                "content": f'Generate underlying data table of the figure below: <img src="data:image/png;base64,{image_b64}" />'
            }
        ],
        "max_tokens": 1024,
        "temperature": 0.20,
        "top_p": 0.20,
        "stream": False
    }

    # A bounded timeout prevents the caller from hanging indefinitely on a
    # stalled connection (requests has no default timeout).
    response = requests.post(invoke_url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    return response.json()["choices"][0]['message']['content']
def extract_text_around_item(text_blocks, bbox, page_height, threshold_percentage=0.1):
    """Extract the nearest text block above and below a bounding box.

    Args:
        text_blocks: page text blocks as (x0, y0, x1, y1, text, ...) tuples.
        bbox: rectangle (fitz.Rect-like) of the item of interest.
        page_height: page height, used to scale the vertical search window.
        threshold_percentage: fraction of page height / bbox width that
            defines how near a block must be to count as adjacent.

    Returns:
        (before_text, after_text) — either may be "" if nothing qualifies.
    """
    before_text = ""
    after_text = ""
    max_vertical_gap = page_height * threshold_percentage
    overlap_tolerance = bbox.width * threshold_percentage

    for block in text_blocks:
        rect = fitz.Rect(block[:4])
        # Distance from whichever edge of the block is closer to the bbox.
        gap = min(abs(rect.y1 - bbox.y0), abs(rect.y0 - bbox.y1))
        overlap = max(0, min(rect.x1, bbox.x1) - max(rect.x0, bbox.x0))
        if gap <= max_vertical_gap and overlap >= -overlap_tolerance:
            if rect.y1 < bbox.y0 and not before_text:
                before_text = block[4]
            elif rect.y0 > bbox.y1 and not after_text:
                after_text = block[4]
                # Blocks are in page order; once text below is found, stop.
                break

    return before_text, after_text
def process_text_blocks(text_blocks, char_count_threshold=500):
    """Merge consecutive text blocks into groups of at most ~threshold chars.

    Args:
        text_blocks: blocks as (x0, y0, x1, y1, text, ..., block_type)
            tuples; only blocks whose last element is 0 (text type) count.
        char_count_threshold: soft cap on total characters per group.

    Returns:
        A list of (first_block_of_group, joined_text) tuples.
    """
    grouped_blocks = []
    pending = []
    pending_chars = 0

    def _flush():
        # Emit the accumulated group, keyed by its first block.
        if pending:
            grouped_blocks.append((pending[0], "\n".join(b[4] for b in pending)))

    for block in text_blocks:
        if block[-1] != 0:  # skip non-text blocks (images, etc.)
            continue
        text_len = len(block[4])
        if pending_chars + text_len <= char_count_threshold:
            pending.append(block)
            pending_chars += text_len
        else:
            _flush()
            pending = [block]
            pending_chars = text_len

    _flush()
    return grouped_blocks
def save_uploaded_file(uploaded_file):
    """Save an uploaded file to a temporary directory under the working dir.

    Args:
        uploaded_file: a file-like object exposing .name and .read()
            (e.g. a Streamlit UploadedFile — TODO confirm against caller).

    Returns:
        The path of the saved file inside vectorstore/ppt_references/tmp.
    """
    temp_dir = os.path.join(os.getcwd(), "vectorstore", "ppt_references", "tmp")
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, uploaded_file.name)

    with open(temp_file_path, "wb") as temp_file:
        temp_file.write(uploaded_file.read())

    # NOTE: the pasted source ended this line with a stray "|" (scrape
    # artifact), which is a syntax error; removed here.
    return temp_file_path