Spaces:
Running
Running
Upload 42 files
Browse files- app.py +69 -137
- backup/__pycache__/model.cpython-310.pyc +0 -0
- backup/__pycache__/save_load.cpython-310.pyc +0 -0
- backup/__pycache__/train.cpython-310.pyc +0 -0
- backup/backup.py +58 -58
- backup/model.py +412 -412
- backup/modules/__pycache__/base.cpython-310.pyc +0 -0
- backup/modules/__pycache__/evaluator.cpython-310.pyc +0 -0
- backup/modules/__pycache__/layers.cpython-310.pyc +0 -0
- backup/modules/__pycache__/run_evaluation.cpython-310.pyc +0 -0
- backup/modules/__pycache__/span_rep.cpython-310.pyc +0 -0
- backup/modules/__pycache__/token_rep.cpython-310.pyc +0 -0
- backup/modules/base.py +150 -150
- backup/modules/data_proc.py +73 -73
- backup/modules/evaluator.py +152 -152
- backup/modules/layers.py +28 -28
- backup/modules/run_evaluation.py +188 -188
- backup/modules/span_rep.py +369 -369
- backup/modules/token_rep.py +54 -54
- backup/requirements.txt +5 -5
- backup/save_load.py +20 -20
- backup/train.py +132 -132
- core/__pycache__/base.cpython-310.pyc +0 -0
- core/__pycache__/gradio_ocr.cpython-310.pyc +0 -0
- core/__pycache__/ner_engine.cpython-310.pyc +0 -0
- core/__pycache__/ocr_engine.cpython-310.pyc +0 -0
- core/__pycache__/vlm_engine.cpython-310.pyc +0 -0
- core/base.py +22 -0
- core/gradio_ocr.py +50 -0
- core/ner_engine.py +49 -0
- core/ocr_engine.py +114 -0
- core/vlm_engine.py +91 -0
- requirements.txt +18 -16
- static/uploads/IN_Standard-Visiting-Cards_Overview.png +0 -0
- templates/index.html +236 -284
- templates/result.html +326 -248
- utility/__pycache__/utils.cpython-310.pyc +0 -0
- utility/__pycache__/utils.cpython-312.pyc +0 -0
- utility/__pycache__/utils.cpython-313.pyc +0 -0
- utility/utils.py +120 -688
app.py
CHANGED
|
@@ -1,186 +1,118 @@
|
|
| 1 |
-
# libraries
|
| 2 |
-
from flask import Flask, render_template, request, redirect, url_for, flash, session, send_from_directory
|
| 3 |
import os
|
| 4 |
import logging
|
| 5 |
-
from
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
|
| 9 |
-
# Configure logging
|
| 10 |
-
logging.basicConfig(
|
| 11 |
-
level=logging.INFO,
|
| 12 |
-
handlers=[
|
| 13 |
-
logging.StreamHandler() # Remove FileHandler and log only to the console
|
| 14 |
-
]
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
# Flask App
|
| 18 |
-
app = Flask(__name__)
|
| 19 |
-
app.secret_key = 'your_secret_key'
|
| 20 |
-
app.config['UPLOAD_FOLDER'] = 'uploads/'
|
| 21 |
-
app.config['RESULT_FOLDER'] = 'results/'
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
| 26 |
-
os.makedirs(RESULT_FOLDER, exist_ok=True)
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
os.
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
os.
|
| 40 |
-
logging.info("Created PaddleOCR home directory.")
|
| 41 |
-
else:
|
| 42 |
-
logging.info("PaddleOCR home directory exists.")
|
| 43 |
|
| 44 |
@app.route('/')
|
| 45 |
def index():
|
| 46 |
uploaded_files = session.get('uploaded_files', [])
|
| 47 |
-
logging.info(f"Accessed index page, uploaded files: {uploaded_files}")
|
| 48 |
return render_template('index.html', uploaded_files=uploaded_files)
|
| 49 |
|
| 50 |
@app.route('/upload', methods=['POST'])
|
| 51 |
def upload_file():
|
|
|
|
| 52 |
if 'files' not in request.files:
|
|
|
|
| 53 |
flash('No file part')
|
| 54 |
-
logging.warning("No file part found in the request")
|
| 55 |
return redirect(request.url)
|
| 56 |
|
| 57 |
files = request.files.getlist('files')
|
| 58 |
if not files or all(file.filename == '' for file in files):
|
|
|
|
| 59 |
flash('No selected files')
|
| 60 |
-
logging.warning("No files selected for upload")
|
| 61 |
return redirect(request.url)
|
| 62 |
|
| 63 |
-
uploaded_files =
|
| 64 |
for file in files:
|
| 65 |
if file:
|
| 66 |
filename = file.filename
|
| 67 |
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
| 68 |
file.save(file_path)
|
| 69 |
-
print(f"file path --->{file_path}")
|
| 70 |
uploaded_files.append(filename)
|
| 71 |
-
logging.info(f"
|
| 72 |
|
| 73 |
session['uploaded_files'] = uploaded_files
|
| 74 |
-
flash('Files successfully uploaded')
|
| 75 |
-
logging.info(f"Files successfully uploaded: {uploaded_files}")
|
| 76 |
return process_file()
|
| 77 |
|
| 78 |
-
@app.route('/
|
| 79 |
-
def remove_file():
|
| 80 |
-
uploaded_files = session.get('uploaded_files', [])
|
| 81 |
-
if uploaded_file:
|
| 82 |
-
for filename in uploaded_files:
|
| 83 |
-
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
| 84 |
-
if os.path.exists(file_path):
|
| 85 |
-
os.remove(file_path)
|
| 86 |
-
logging.info(f"Removed file: {filename}")
|
| 87 |
-
else:
|
| 88 |
-
logging.warning(f"File not found for removal: {file_path}") # More specific log
|
| 89 |
-
|
| 90 |
-
session.pop('uploaded_files', None)
|
| 91 |
-
flash('Files successfully removed')
|
| 92 |
-
logging.info("All uploaded files removed")
|
| 93 |
-
else:
|
| 94 |
-
flash('No file to remove.')
|
| 95 |
-
logging.warning("File not found for removal")
|
| 96 |
-
return redirect(url_for('index'))
|
| 97 |
-
|
| 98 |
-
@app.route('/reset_upload')
|
| 99 |
-
def reset_upload():
|
| 100 |
-
"""Reset the uploaded file and the processed data."""
|
| 101 |
-
uploaded_files = session.get('uploaded_files', [])
|
| 102 |
-
if uploaded_file:
|
| 103 |
-
for filename in uploaded_files:
|
| 104 |
-
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
| 105 |
-
if os.path.exists(file_path):
|
| 106 |
-
os.remove(file_path)
|
| 107 |
-
logging.info(f"Removed file: {filename}")
|
| 108 |
-
else:
|
| 109 |
-
logging.warning(f"File not found for removal: {file_path}") # More specific log
|
| 110 |
-
|
| 111 |
-
session.pop('uploaded_files', None)
|
| 112 |
-
flash('Files successfully removed')
|
| 113 |
-
logging.info("All uploaded files removed")
|
| 114 |
-
else:
|
| 115 |
-
flash('No file to remove.')
|
| 116 |
-
logging.warning("File not found for removal")
|
| 117 |
-
return redirect(url_for('index'))
|
| 118 |
-
|
| 119 |
-
@app.route('/process', methods=['GET','POST'])
|
| 120 |
def process_file():
|
|
|
|
| 121 |
uploaded_files = session.get('uploaded_files', [])
|
| 122 |
if not uploaded_files:
|
|
|
|
| 123 |
flash('No files selected for processing')
|
| 124 |
-
logging.warning("No files selected for processing")
|
| 125 |
return redirect(url_for('index'))
|
| 126 |
|
| 127 |
-
file_paths = [os.path.join(app.config['UPLOAD_FOLDER'],
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
extracted_text = {}
|
| 131 |
-
processed_Img = {}
|
| 132 |
-
|
| 133 |
try:
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
logging.info(f"Processed images: {processed_Img}")
|
| 137 |
-
|
| 138 |
-
llmText = json_to_llm_str(extracted_text)
|
| 139 |
-
logging.info(f"LLM text: {llmText}")
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
| 144 |
|
|
|
|
|
|
|
|
|
|
| 145 |
except Exception as e:
|
| 146 |
-
logging.
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
extracted_text, processed_Img = extract_text_from_images(file_paths)
|
| 151 |
-
logging.info(f"Extracted text(Backup): {extracted_text}")
|
| 152 |
-
logging.info(f"Processed images(Backup): {processed_Img}")
|
| 153 |
-
if extracted_text:
|
| 154 |
-
text = json_to_llm_str(extracted_text)
|
| 155 |
-
LLMdata = NER_Model(text)
|
| 156 |
-
logging.info(f"NER model data: {LLMdata}")
|
| 157 |
-
else:
|
| 158 |
-
logging.warning("No extracted text available for backup model")
|
| 159 |
-
|
| 160 |
-
cont_data = process_extracted_text(extracted_text)
|
| 161 |
-
logging.info(f"Contextual data: {cont_data}")
|
| 162 |
-
|
| 163 |
-
processed_data = process_resume_data(LLMdata, cont_data, extracted_text)
|
| 164 |
-
logging.info(f"Processed data: {processed_data}")
|
| 165 |
-
|
| 166 |
-
session['processed_data'] = processed_data
|
| 167 |
-
session['processed_Img'] = processed_Img
|
| 168 |
-
flash('Data processed and analyzed successfully')
|
| 169 |
-
logging.info("Data processed and analyzed successfully")
|
| 170 |
-
return redirect(url_for('result'))
|
| 171 |
-
|
| 172 |
@app.route('/result')
|
| 173 |
def result():
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
| 178 |
|
| 179 |
-
@app.route('/
|
| 180 |
-
def
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
if __name__ == '__main__':
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import logging
|
| 3 |
+
from flask import Flask, render_template, request, redirect, url_for, flash, session, send_from_directory
|
| 4 |
+
from utility.utils import process_image_pipeline
|
| 5 |
+
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Configure logging: INFO and above, console only.
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler()])

app = Flask(__name__)

# The secret key signs the session cookie. A hard-coded fallback would be a
# publicly known key, letting anyone forge session contents, so when
# SECRET_KEY is not configured a random per-process key is generated instead.
# NOTE: with a random key, sessions do not survive a restart and are not
# shared between worker processes -- set SECRET_KEY for real deployments.
_secret_key = os.getenv('SECRET_KEY')
if not _secret_key:
    _secret_key = os.urandom(32).hex()
    logging.warning("Config: SECRET_KEY not set; using a random per-process key")
app.secret_key = _secret_key

app.config['UPLOAD_FOLDER'] = 'static/uploads/'
app.config['RESULT_FOLDER'] = 'static/results/'

# Make sure both working directories exist before the first request.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
|
| 20 |
|
| 21 |
+
@app.template_filter('basename')
def basename_filter(path):
    """Jinja filter ``basename``: return the final component of *path*."""
    _, tail = os.path.split(path)
    return tail
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
@app.route('/')
def index():
    """Render the landing page with the file names stored in the session."""
    return render_template(
        'index.html',
        uploaded_files=session.get('uploaded_files', []),
    )
|
| 29 |
|
| 30 |
@app.route('/upload', methods=['POST'])
def upload_file():
    """Save the uploaded files and hand them straight to ``process_file``.

    Reads the multipart field ``files`` from the request, stores each file
    under ``app.config['UPLOAD_FOLDER']``, records the saved names in
    ``session['uploaded_files']`` and returns the response produced by
    ``process_file()``. When nothing usable was submitted, a message is
    flashed and the client is redirected to the index page.
    """
    logging.info("Request: /upload received")
    if 'files' not in request.files:
        logging.warning("Upload: No file part in request")
        flash('No file part')
        # This route only accepts POST, so redirecting to request.url would
        # make the browser issue a GET to /upload and receive a 405.
        return redirect(url_for('index'))

    files = request.files.getlist('files')
    if not files or all(file.filename == '' for file in files):
        logging.warning("Upload: No files selected")
        flash('No selected files')
        return redirect(url_for('index'))

    uploaded_files = []
    for file in files:
        if not file:
            continue
        # The filename is chosen by the client. Keep only its final path
        # component (treating both separator styles alike) so a name such as
        # "../../app.py" cannot escape the upload folder.
        filename = os.path.basename(file.filename.replace('\\', '/'))
        if filename in ('', '.', '..'):
            logging.warning("Upload: Skipping unsafe filename %r", file.filename)
            continue
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)
        uploaded_files.append(filename)
        logging.info(f"Upload: Successfully saved {filename}")

    session['uploaded_files'] = uploaded_files
    return process_file()
|
| 55 |
|
| 56 |
+
@app.route('/process', methods=['GET', 'POST'])
def process_file():
    """Run the image pipeline on the files named in the session.

    On success the pipeline output and a name-to-path mapping of the uploaded
    images are stored in the session and the client is redirected to the
    result page. If the session lists no files, or the pipeline raises, a
    message is flashed and the client is redirected to the index page.
    """
    logging.info("Request: /process started")
    uploaded_files = session.get('uploaded_files', [])
    if not uploaded_files:
        logging.warning("Process: No files in session")
        flash('No files selected for processing')
        return redirect(url_for('index'))

    upload_dir = app.config['UPLOAD_FOLDER']
    file_paths = [os.path.join(upload_dir, name) for name in uploaded_files]

    try:
        logging.info(f"Process: Sending {len(file_paths)} files to pipeline")
        processed_data = process_image_pipeline(file_paths)

        # Format images for result.html: uploaded name -> stored path.
        processed_Img = dict(zip(uploaded_files, file_paths))

        session['processed_data'] = processed_data
        session['processed_Img'] = processed_Img

        logging.info("Process: Pipeline completed successfully")
        flash('Data processed successfully')
        return redirect(url_for('result'))
    except Exception as e:
        # Top-level request boundary: log with traceback and report to the user.
        logging.exception(f"Process: Critical failure: {e}")
        flash(f'Processing error: {str(e)}')
        return redirect(url_for('index'))
|
| 84 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
@app.route('/result')
def result():
    """Render the result page, or go back to the index if nothing was processed."""
    data = session.get('processed_data', {})
    if data:
        Img = session.get('processed_Img', {})
        return render_template('result.html', data=data, Img=Img)
    return redirect(url_for('index'))
|
| 92 |
|
| 93 |
+
@app.route('/reset_upload')
def reset_upload():
    """Delete the session's uploaded files and clear all per-session state.

    Removes every file named in ``session['uploaded_files']`` from the upload
    folder, drops the ``uploaded_files``, ``processed_data`` and
    ``processed_Img`` session entries, flashes a confirmation and redirects
    to the index page.
    """
    for name in session.get('uploaded_files', []):
        path = os.path.join(app.config['UPLOAD_FOLDER'], name)
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone (never saved, or removed by another request):
            # nothing left to clean up for this entry.
            pass

    for key in ('uploaded_files', 'processed_data', 'processed_Img'):
        session.pop(key, None)
    flash('System reset successful.')
    return redirect(url_for('index'))
|
| 106 |
|
| 107 |
if __name__ == '__main__':
    from utility.utils import get_ocr, get_ner

    logging.info("Core: Pre-initializing engines (this may take a minute)...")
    # Trigger lazy load at startup to avoid reloader issues and request timeouts.
    try:
        get_ocr()
        get_ner()
    except Exception as e:
        # Best effort: the server still starts. logging.exception keeps the
        # traceback so the failure can be diagnosed.
        logging.exception("Core: Failed to pre-initialize engines: %s", e)
    else:
        logging.info("Core: Engines pre-initialized successfully.")

    # Debug mode enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution from the browser. It is therefore opt-in via
    # FLASK_DEBUG rather than always on.
    debug_enabled = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, use_reloader=False, port=int(os.getenv('PORT', 5000)))
|
backup/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (9.97 kB). View file
|
|
|
backup/__pycache__/save_load.cpython-310.pyc
ADDED
|
Binary file (765 Bytes). View file
|
|
|
backup/__pycache__/train.cpython-310.pyc
ADDED
|
Binary file (2.95 kB). View file
|
|
|
backup/backup.py
CHANGED
|
@@ -1,59 +1,59 @@
|
|
| 1 |
-
from .model import GLiNER
|
| 2 |
-
|
| 3 |
-
# Initialize GLiNER with the base model
|
| 4 |
-
model = GLiNER.from_pretrained("urchade/gliner_mediumv2.1")
|
| 5 |
-
|
| 6 |
-
# Sample text for entity prediction
|
| 7 |
-
text = """
|
| 8 |
-
lenskart m: (0)9428002330 Lenskart Store,Surat m: (0)9723817060) e:lenskartsurat@gmail.com Store Address UG-4.Ascon City.Opp.Maheshwari Bhavan,Citylight,Surat-395007"""
|
| 9 |
-
|
| 10 |
-
def NER_Model(text):
|
| 11 |
-
|
| 12 |
-
labels = ["Person", "Mail", "Number", "Address", "Organization","Designation","Link"]
|
| 13 |
-
|
| 14 |
-
# Perform entity prediction
|
| 15 |
-
entities = model.predict_entities(text, labels, threshold=0.
|
| 16 |
-
|
| 17 |
-
# Initialize the processed data dictionary
|
| 18 |
-
processed_data = {
|
| 19 |
-
"Name": [],
|
| 20 |
-
"Contact": [],
|
| 21 |
-
"Designation": [],
|
| 22 |
-
"Address": [],
|
| 23 |
-
"Link": [],
|
| 24 |
-
"Company": [],
|
| 25 |
-
"Email": [],
|
| 26 |
-
"extracted_text": "",
|
| 27 |
-
}
|
| 28 |
-
|
| 29 |
-
for entity in entities:
|
| 30 |
-
|
| 31 |
-
print(entity["text"], "=>", entity["label"])
|
| 32 |
-
|
| 33 |
-
#loading the data into json
|
| 34 |
-
if entity["label"]==labels[0]:
|
| 35 |
-
processed_data['Name'].extend([entity["text"]])
|
| 36 |
-
|
| 37 |
-
if entity["label"]==labels[1]:
|
| 38 |
-
processed_data['Email'].extend([entity["text"]])
|
| 39 |
-
|
| 40 |
-
if entity["label"]==labels[2]:
|
| 41 |
-
processed_data['Contact'].extend([entity["text"]])
|
| 42 |
-
|
| 43 |
-
if entity["label"]==labels[3]:
|
| 44 |
-
processed_data['Address'].extend([entity["text"]])
|
| 45 |
-
|
| 46 |
-
if entity["label"]==labels[4]:
|
| 47 |
-
processed_data['Company'].extend([entity["text"]])
|
| 48 |
-
|
| 49 |
-
if entity["label"]==labels[5]:
|
| 50 |
-
processed_data['Designation'].extend([entity["text"]])
|
| 51 |
-
|
| 52 |
-
if entity["label"]==labels[6]:
|
| 53 |
-
processed_data['Link'].extend([entity["text"]])
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
processed_data['Address']=[', '.join(processed_data['Address'])]
|
| 57 |
-
processed_data['extracted_text']=[text]
|
| 58 |
-
|
| 59 |
return processed_data
|
|
|
|
| 1 |
+
from .model import GLiNER
|
| 2 |
+
|
| 3 |
+
# Initialize GLiNER with the base model
|
| 4 |
+
model = GLiNER.from_pretrained("urchade/gliner_mediumv2.1")
|
| 5 |
+
|
| 6 |
+
# Sample text for entity prediction
|
| 7 |
+
text = """
|
| 8 |
+
lenskart m: (0)9428002330 Lenskart Store,Surat m: (0)9723817060) e:lenskartsurat@gmail.com Store Address UG-4.Ascon City.Opp.Maheshwari Bhavan,Citylight,Surat-395007"""
|
| 9 |
+
|
| 10 |
+
def NER_Model(text):
    """Extract contact-card entities from *text* with the module-level model.

    Calls ``model.predict_entities`` with a fixed label set and a threshold
    of 0.3, prints each predicted entity, and groups the entity texts by
    output field.

    Returns a dict with list values under "Name", "Contact", "Designation",
    "Address", "Link", "Company" and "Email". "Address" is collapsed into a
    single comma-joined string inside a one-element list, and
    "extracted_text" is a one-element list holding the input text.
    """
    labels = ["Person", "Mail", "Number", "Address", "Organization","Designation","Link"]

    # Output field for each label, in the same order as `labels`.
    field_names = ["Name", "Email", "Contact", "Address", "Company", "Designation", "Link"]
    field_for_label = dict(zip(labels, field_names))

    # Perform entity prediction
    entities = model.predict_entities(text, labels, threshold=0.3)

    # Result skeleton; key order matches the original layout.
    processed_data = {
        "Name": [],
        "Contact": [],
        "Designation": [],
        "Address": [],
        "Link": [],
        "Company": [],
        "Email": [],
        "extracted_text": "",
    }

    for entity in entities:
        print(entity["text"], "=>", entity["label"])

        # Route the entity text to its field; unknown labels are ignored.
        field = field_for_label.get(entity["label"])
        if field is not None:
            processed_data[field].append(entity["text"])

    processed_data['Address'] = [', '.join(processed_data['Address'])]
    processed_data['extracted_text'] = [text]

    return processed_data
|
backup/model.py
CHANGED
|
@@ -1,412 +1,412 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import json
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import re
|
| 5 |
-
from typing import Dict, Optional, Union
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
from .modules.layers import LstmSeq2SeqEncoder
|
| 9 |
-
from .modules.base import InstructBase
|
| 10 |
-
from .modules.evaluator import Evaluator, greedy_search
|
| 11 |
-
from .modules.span_rep import SpanRepLayer
|
| 12 |
-
from .modules.token_rep import TokenRepLayer
|
| 13 |
-
from torch import nn
|
| 14 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 15 |
-
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
| 16 |
-
from huggingface_hub.utils import HfHubHTTPError
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class GLiNER(InstructBase, PyTorchModelHubMixin):
|
| 21 |
-
def __init__(self, config):
|
| 22 |
-
super().__init__(config)
|
| 23 |
-
|
| 24 |
-
self.config = config
|
| 25 |
-
|
| 26 |
-
# [ENT] token
|
| 27 |
-
self.entity_token = "<<ENT>>"
|
| 28 |
-
self.sep_token = "<<SEP>>"
|
| 29 |
-
|
| 30 |
-
# usually a pretrained bidirectional transformer, returns first subtoken representation
|
| 31 |
-
self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
|
| 32 |
-
subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
|
| 33 |
-
add_tokens=[self.entity_token, self.sep_token])
|
| 34 |
-
|
| 35 |
-
# hierarchical representation of tokens
|
| 36 |
-
self.rnn = LstmSeq2SeqEncoder(
|
| 37 |
-
input_size=config.hidden_size,
|
| 38 |
-
hidden_size=config.hidden_size // 2,
|
| 39 |
-
num_layers=1,
|
| 40 |
-
bidirectional=True,
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
# span representation
|
| 44 |
-
self.span_rep_layer = SpanRepLayer(
|
| 45 |
-
span_mode=config.span_mode,
|
| 46 |
-
hidden_size=config.hidden_size,
|
| 47 |
-
max_width=config.max_width,
|
| 48 |
-
dropout=config.dropout,
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
# prompt representation (FFN)
|
| 52 |
-
self.prompt_rep_layer = nn.Sequential(
|
| 53 |
-
nn.Linear(config.hidden_size, config.hidden_size * 4),
|
| 54 |
-
nn.Dropout(config.dropout),
|
| 55 |
-
nn.ReLU(),
|
| 56 |
-
nn.Linear(config.hidden_size * 4, config.hidden_size)
|
| 57 |
-
)
|
| 58 |
-
|
| 59 |
-
def compute_score_train(self, x):
|
| 60 |
-
span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
|
| 61 |
-
|
| 62 |
-
new_length = x['seq_length'].clone()
|
| 63 |
-
new_tokens = []
|
| 64 |
-
all_len_prompt = []
|
| 65 |
-
num_classes_all = []
|
| 66 |
-
|
| 67 |
-
# add prompt to the tokens
|
| 68 |
-
for i in range(len(x['tokens'])):
|
| 69 |
-
all_types_i = list(x['classes_to_id'][i].keys())
|
| 70 |
-
# multiple entity types in all_types. Prompt is appended at the start of tokens
|
| 71 |
-
entity_prompt = []
|
| 72 |
-
num_classes_all.append(len(all_types_i))
|
| 73 |
-
# add enity types to prompt
|
| 74 |
-
for entity_type in all_types_i:
|
| 75 |
-
entity_prompt.append(self.entity_token) # [ENT] token
|
| 76 |
-
entity_prompt.append(entity_type) # entity type
|
| 77 |
-
entity_prompt.append(self.sep_token) # [SEP] token
|
| 78 |
-
|
| 79 |
-
# prompt format:
|
| 80 |
-
# [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]
|
| 81 |
-
|
| 82 |
-
# add prompt to the tokens
|
| 83 |
-
tokens_p = entity_prompt + x['tokens'][i]
|
| 84 |
-
|
| 85 |
-
# input format:
|
| 86 |
-
# [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n
|
| 87 |
-
|
| 88 |
-
# update length of the sequence (add prompt length to the original length)
|
| 89 |
-
new_length[i] = new_length[i] + len(entity_prompt)
|
| 90 |
-
# update tokens
|
| 91 |
-
new_tokens.append(tokens_p)
|
| 92 |
-
# store prompt length
|
| 93 |
-
all_len_prompt.append(len(entity_prompt))
|
| 94 |
-
|
| 95 |
-
# create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise)
|
| 96 |
-
max_num_classes = max(num_classes_all)
|
| 97 |
-
entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
|
| 98 |
-
x['span_mask'].device)
|
| 99 |
-
entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
|
| 100 |
-
x['span_mask'].device) # [batch_size, max_num_classes]
|
| 101 |
-
|
| 102 |
-
# compute all token representations
|
| 103 |
-
bert_output = self.token_rep_layer(new_tokens, new_length)
|
| 104 |
-
word_rep_w_prompt = bert_output["embeddings"] # embeddings for all tokens (with prompt)
|
| 105 |
-
mask_w_prompt = bert_output["mask"] # mask for all tokens (with prompt)
|
| 106 |
-
|
| 107 |
-
# get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
|
| 108 |
-
word_rep = [] # word representation (after [SEP])
|
| 109 |
-
mask = [] # mask (after [SEP])
|
| 110 |
-
entity_type_rep = [] # entity type representation (before [SEP])
|
| 111 |
-
for i in range(len(x['tokens'])):
|
| 112 |
-
prompt_entity_length = all_len_prompt[i] # length of prompt for this example
|
| 113 |
-
# get word representation (after [SEP])
|
| 114 |
-
word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
|
| 115 |
-
# get mask (after [SEP])
|
| 116 |
-
mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
|
| 117 |
-
|
| 118 |
-
# get entity type representation (before [SEP])
|
| 119 |
-
entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1] # remove [SEP]
|
| 120 |
-
entity_rep = entity_rep[0::2] # it means that we take every second element starting from the second one
|
| 121 |
-
entity_type_rep.append(entity_rep)
|
| 122 |
-
|
| 123 |
-
# padding for word_rep, mask and entity_type_rep
|
| 124 |
-
word_rep = pad_sequence(word_rep, batch_first=True) # [batch_size, seq_len, hidden_size]
|
| 125 |
-
mask = pad_sequence(mask, batch_first=True) # [batch_size, seq_len]
|
| 126 |
-
entity_type_rep = pad_sequence(entity_type_rep, batch_first=True) # [batch_size, len_types, hidden_size]
|
| 127 |
-
|
| 128 |
-
# compute span representation
|
| 129 |
-
word_rep = self.rnn(word_rep, mask)
|
| 130 |
-
span_rep = self.span_rep_layer(word_rep, span_idx)
|
| 131 |
-
|
| 132 |
-
# compute final entity type representation (FFN)
|
| 133 |
-
entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
|
| 134 |
-
num_classes = entity_type_rep.shape[1] # number of entity types
|
| 135 |
-
|
| 136 |
-
# similarity score
|
| 137 |
-
scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
|
| 138 |
-
|
| 139 |
-
return scores, num_classes, entity_type_mask
|
| 140 |
-
|
| 141 |
-
def forward(self, x):
|
| 142 |
-
# compute span representation
|
| 143 |
-
scores, num_classes, entity_type_mask = self.compute_score_train(x)
|
| 144 |
-
batch_size = scores.shape[0]
|
| 145 |
-
|
| 146 |
-
# loss for filtering classifier
|
| 147 |
-
logits_label = scores.view(-1, num_classes)
|
| 148 |
-
labels = x["span_label"].view(-1) # (batch_size * num_spans)
|
| 149 |
-
mask_label = labels != -1 # (batch_size * num_spans)
|
| 150 |
-
labels.masked_fill_(~mask_label, 0) # Set the labels of padding tokens to 0
|
| 151 |
-
|
| 152 |
-
# one-hot encoding
|
| 153 |
-
labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
|
| 154 |
-
labels_one_hot.scatter_(1, labels.unsqueeze(1), 1) # Set the corresponding index to 1
|
| 155 |
-
labels_one_hot = labels_one_hot[:, 1:] # Remove the first column
|
| 156 |
-
# Shape of labels_one_hot: (batch_size * num_spans, num_classes)
|
| 157 |
-
|
| 158 |
-
# compute loss (without reduction)
|
| 159 |
-
all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
|
| 160 |
-
reduction='none')
|
| 161 |
-
# mask loss using entity_type_mask (B, C)
|
| 162 |
-
masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
|
| 163 |
-
all_losses = masked_loss.view(-1, num_classes)
|
| 164 |
-
# expand mask_label to all_losses
|
| 165 |
-
mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
|
| 166 |
-
# put lower loss for in label_one_hot (2 for positive, 1 for negative)
|
| 167 |
-
weight_c = labels_one_hot + 1
|
| 168 |
-
# apply mask
|
| 169 |
-
all_losses = all_losses * mask_label.float() * weight_c
|
| 170 |
-
return all_losses.sum()
|
| 171 |
-
|
| 172 |
-
def compute_score_eval(self, x, device):
|
| 173 |
-
# check if classes_to_id is dict
|
| 174 |
-
assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"
|
| 175 |
-
|
| 176 |
-
span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)
|
| 177 |
-
|
| 178 |
-
all_types = list(x['classes_to_id'].keys())
|
| 179 |
-
# multiple entity types in all_types. Prompt is appended at the start of tokens
|
| 180 |
-
entity_prompt = []
|
| 181 |
-
|
| 182 |
-
# add enity types to prompt
|
| 183 |
-
for entity_type in all_types:
|
| 184 |
-
entity_prompt.append(self.entity_token)
|
| 185 |
-
entity_prompt.append(entity_type)
|
| 186 |
-
|
| 187 |
-
entity_prompt.append(self.sep_token)
|
| 188 |
-
|
| 189 |
-
prompt_entity_length = len(entity_prompt)
|
| 190 |
-
|
| 191 |
-
# add prompt
|
| 192 |
-
tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
|
| 193 |
-
seq_length_p = x['seq_length'] + prompt_entity_length
|
| 194 |
-
|
| 195 |
-
out = self.token_rep_layer(tokens_p, seq_length_p)
|
| 196 |
-
|
| 197 |
-
word_rep_w_prompt = out["embeddings"]
|
| 198 |
-
mask_w_prompt = out["mask"]
|
| 199 |
-
|
| 200 |
-
# remove prompt
|
| 201 |
-
word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
|
| 202 |
-
mask = mask_w_prompt[:, prompt_entity_length:]
|
| 203 |
-
|
| 204 |
-
# get_entity_type_rep
|
| 205 |
-
entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
|
| 206 |
-
# extract [ENT] tokens (which are at even positions in entity_type_rep)
|
| 207 |
-
entity_type_rep = entity_type_rep[:, 0::2, :]
|
| 208 |
-
|
| 209 |
-
entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
|
| 210 |
-
|
| 211 |
-
word_rep = self.rnn(word_rep, mask)
|
| 212 |
-
|
| 213 |
-
span_rep = self.span_rep_layer(word_rep, span_idx)
|
| 214 |
-
|
| 215 |
-
local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
|
| 216 |
-
|
| 217 |
-
return local_scores
|
| 218 |
-
|
| 219 |
-
@torch.no_grad()
|
| 220 |
-
def predict(self, x, flat_ner=False, threshold=0.5):
|
| 221 |
-
self.eval()
|
| 222 |
-
local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
|
| 223 |
-
spans = []
|
| 224 |
-
for i, _ in enumerate(x["tokens"]):
|
| 225 |
-
local_i = local_scores[i]
|
| 226 |
-
wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)]
|
| 227 |
-
span_i = []
|
| 228 |
-
for s, k, c in zip(*wh_i):
|
| 229 |
-
if s + k < len(x["tokens"][i]):
|
| 230 |
-
span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
|
| 231 |
-
span_i = greedy_search(span_i, flat_ner)
|
| 232 |
-
spans.append(span_i)
|
| 233 |
-
return spans
|
| 234 |
-
|
| 235 |
-
def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
|
| 236 |
-
tokens = []
|
| 237 |
-
start_token_idx_to_text_idx = []
|
| 238 |
-
end_token_idx_to_text_idx = []
|
| 239 |
-
for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
|
| 240 |
-
tokens.append(match.group())
|
| 241 |
-
start_token_idx_to_text_idx.append(match.start())
|
| 242 |
-
end_token_idx_to_text_idx.append(match.end())
|
| 243 |
-
|
| 244 |
-
input_x = {"tokenized_text": tokens, "ner": None}
|
| 245 |
-
x = self.collate_fn([input_x], labels)
|
| 246 |
-
output = self.predict(x, flat_ner=flat_ner, threshold=threshold)
|
| 247 |
-
|
| 248 |
-
entities = []
|
| 249 |
-
for start_token_idx, end_token_idx, ent_type in output[0]:
|
| 250 |
-
start_text_idx = start_token_idx_to_text_idx[start_token_idx]
|
| 251 |
-
end_text_idx = end_token_idx_to_text_idx[end_token_idx]
|
| 252 |
-
entities.append({
|
| 253 |
-
"start": start_token_idx_to_text_idx[start_token_idx],
|
| 254 |
-
"end": end_token_idx_to_text_idx[end_token_idx],
|
| 255 |
-
"text": text[start_text_idx:end_text_idx],
|
| 256 |
-
"label": ent_type,
|
| 257 |
-
})
|
| 258 |
-
return entities
|
| 259 |
-
|
| 260 |
-
def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
|
| 261 |
-
self.eval()
|
| 262 |
-
data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
|
| 263 |
-
device = next(self.parameters()).device
|
| 264 |
-
all_preds = []
|
| 265 |
-
all_trues = []
|
| 266 |
-
for x in data_loader:
|
| 267 |
-
for k, v in x.items():
|
| 268 |
-
if isinstance(v, torch.Tensor):
|
| 269 |
-
x[k] = v.to(device)
|
| 270 |
-
batch_predictions = self.predict(x, flat_ner, threshold)
|
| 271 |
-
all_preds.extend(batch_predictions)
|
| 272 |
-
all_trues.extend(x["entities"])
|
| 273 |
-
evaluator = Evaluator(all_trues, all_preds)
|
| 274 |
-
out, f1 = evaluator.evaluate()
|
| 275 |
-
return out, f1
|
| 276 |
-
|
| 277 |
-
@classmethod
def _from_pretrained(
    cls,
    *,
    model_id: str,
    revision: Optional[str],
    cache_dir: Optional[Union[str, Path]],
    force_download: bool,
    proxies: Optional[Dict],
    resume_download: bool,
    local_files_only: bool,
    token: Union[str, bool, None],
    map_location: str = "cpu",
    strict: bool = False,
    **model_kwargs,
):
    """Load a model from a local directory or from the Hugging Face Hub.

    Two layouts are tried in order:

    1. Legacy single-file checkpoints ``gliner_base.pt`` / ``gliner_multi.pt``
       holding both ``config`` and ``model_weights``.
    2. ``pytorch_model.bin`` plus ``gliner_config.json``.

    For every file, ``Path(model_id) / filename`` is used when it exists;
    otherwise the file is fetched with ``hf_hub_download``.

    Args:
        model_id: Local directory or Hub repository id.
        revision, cache_dir, force_download, proxies, resume_download,
            local_files_only, token: Forwarded to ``hf_hub_download``.
        map_location: Device the weights are loaded onto and the model is
            moved to.
        strict: Forwarded to ``load_state_dict``.
        **model_kwargs: Accepted for interface compatibility; unused here.

    Returns:
        The instantiated model with weights loaded.
    """
    # One shared argument set for every Hub download issued below.
    download_kwargs = dict(
        repo_id=model_id,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        token=token,
        local_files_only=local_files_only,
    )
    device = torch.device(map_location)

    # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
    legacy_checkpoints = (
        ("gliner_base.pt", "microsoft/deberta-v3-base"),
        ("gliner_multi.pt", "microsoft/mdeberta-v3-base"),
    )
    for filename, backbone_name in legacy_checkpoints:
        model_file = Path(model_id) / filename
        if not model_file.exists():
            try:
                model_file = hf_hub_download(filename=filename, **download_kwargs)
            except HfHubHTTPError:
                # This legacy file is not available; try the next layout.
                continue
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from sources you trust.
        dict_load = torch.load(model_file, map_location=device)
        config = dict_load["config"]
        state_dict = dict_load["model_weights"]
        config.model_name = backbone_name
        model = cls(config)
        model.load_state_dict(state_dict, strict=strict, assign=True)
        # Required to update flair's internals as well:
        model.to(map_location)
        return model

    # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
    from .train import load_config_as_namespace

    model_file = Path(model_id) / "pytorch_model.bin"
    if not model_file.exists():
        model_file = hf_hub_download(filename="pytorch_model.bin", **download_kwargs)

    config_file = Path(model_id) / "gliner_config.json"
    if not config_file.exists():
        config_file = hf_hub_download(filename="gliner_config.json", **download_kwargs)

    config = load_config_as_namespace(config_file)
    model = cls(config)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    state_dict = torch.load(model_file, map_location=device)
    model.load_state_dict(state_dict, strict=strict, assign=True)
    model.to(map_location)
    return model
|
| 357 |
-
|
| 358 |
-
def save_pretrained(
    self,
    save_directory: Union[str, Path],
    *,
    config: Optional[Union[dict, "DataclassInstance"]] = None,
    repo_id: Optional[str] = None,
    push_to_hub: bool = False,
    **push_to_hub_kwargs,
) -> Optional[str]:
    """
    Save weights in local directory.

    Args:
        save_directory (`str` or `Path`):
            Path to directory in which the model weights and configuration will be saved.
        config (`dict` or `DataclassInstance`, *optional*):
            Model configuration specified as a key/value dictionary or a dataclass instance.
            Defaults to `self.config`. `argparse.Namespace` and dataclass instances are
            converted to plain dicts before being written as JSON.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Huggingface Hub after saving it.
        repo_id (`str`, *optional*):
            ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
            not provided.
        kwargs:
            Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

    Returns:
        The value returned by `push_to_hub` when `push_to_hub=True`, otherwise `None`.
    """
    import dataclasses

    save_directory = Path(save_directory)
    save_directory.mkdir(parents=True, exist_ok=True)

    # save model weights/files
    torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

    # save config (falls back to the config the model was built with)
    if config is None:
        config = self.config
    if config is not None:
        if isinstance(config, argparse.Namespace):
            config = vars(config)
        elif dataclasses.is_dataclass(config) and not isinstance(config, type):
            # json.dumps cannot serialise dataclass instances directly.
            config = dataclasses.asdict(config)
        (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

    # push to the Hub if required
    if push_to_hub:
        kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
        if config is not None:  # kwarg for `push_to_hub`
            kwargs["config"] = config
        if repo_id is None:
            repo_id = save_directory.name  # Defaults to `save_directory` name
        return self.push_to_hub(repo_id=repo_id, **kwargs)
    return None
|
| 406 |
-
|
| 407 |
-
def to(self, device):
    """Move the model to ``device`` and point flair's global device at it.

    Returns:
        ``self``, so calls can be chained.
    """
    super().to(device)

    # flair keeps a module-level ``device``; it is set to the same target
    # here so it stays in step with the model.
    import flair

    flair.device = device
    return self
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import re
|
| 5 |
+
from typing import Dict, Optional, Union
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from .modules.layers import LstmSeq2SeqEncoder
|
| 9 |
+
from .modules.base import InstructBase
|
| 10 |
+
from .modules.evaluator import Evaluator, greedy_search
|
| 11 |
+
from .modules.span_rep import SpanRepLayer
|
| 12 |
+
from .modules.token_rep import TokenRepLayer
|
| 13 |
+
from torch import nn
|
| 14 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 15 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
| 16 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GLiNER(InstructBase, PyTorchModelHubMixin):
|
| 21 |
+
def __init__(self, config):
    """Build the model layers from ``config``.

    Reads ``model_name``, ``fine_tune``, ``subtoken_pooling``,
    ``hidden_size``, ``span_mode``, ``max_width`` and ``dropout`` from
    ``config``.
    """
    super().__init__(config)

    self.config = config
    hidden_size = config.hidden_size

    # Special prompt tokens: one marks each entity type, the other separates
    # the prompt from the text tokens.
    self.entity_token = "<<ENT>>"
    self.sep_token = "<<SEP>>"

    # usually a pretrained bidirectional transformer, returns first subtoken representation
    self.token_rep_layer = TokenRepLayer(
        model_name=config.model_name,
        fine_tune=config.fine_tune,
        subtoken_pooling=config.subtoken_pooling,
        hidden_size=hidden_size,
        add_tokens=[self.entity_token, self.sep_token],
    )

    # hierarchical representation of tokens
    self.rnn = LstmSeq2SeqEncoder(
        input_size=hidden_size,
        hidden_size=hidden_size // 2,
        num_layers=1,
        bidirectional=True,
    )

    # span representation
    self.span_rep_layer = SpanRepLayer(
        span_mode=config.span_mode,
        hidden_size=hidden_size,
        max_width=config.max_width,
        dropout=config.dropout,
    )

    # prompt representation (FFN)
    self.prompt_rep_layer = nn.Sequential(
        nn.Linear(hidden_size, hidden_size * 4),
        nn.Dropout(config.dropout),
        nn.ReLU(),
        nn.Linear(hidden_size * 4, hidden_size),
    )
|
| 58 |
+
|
| 59 |
+
def compute_score_train(self, x):
    """Score every candidate span against every entity type (training path).

    Each example gets its own prompt, built from the keys of
    ``x['classes_to_id'][i]``, prepended to its tokens:

        [ENT] type_1 [ENT] type_2 ... [ENT] type_m [SEP] token_1 ... token_n

    The representation at each ``[ENT]`` position is used as the
    representation of the entity type that follows it.

    Args:
        x: Batch dict with ``tokens``, ``seq_length``, ``span_idx``,
            ``span_mask`` and a per-example list ``classes_to_id``.

    Returns:
        A tuple ``(scores, num_classes, entity_type_mask)``: ``scores`` is
        indexed ``[B, L, K, C]`` per the einsum subscripts, ``num_classes``
        is the padded number of entity types ``C``, and ``entity_type_mask``
        is a ``[B, C]`` boolean tensor that is False for padded type slots.
    """
    # Zero out the indices of masked (invalid) spans.
    span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
    device = x['span_mask'].device

    new_length = x['seq_length'].clone()
    new_tokens = []
    prompt_lengths = []
    num_classes_all = []

    # add prompt to the tokens
    for i, (tokens, classes_to_id) in enumerate(zip(x['tokens'], x['classes_to_id'])):
        entity_types = list(classes_to_id.keys())
        num_classes_all.append(len(entity_types))

        # [ENT] entity_type pairs, closed by a single [SEP]
        entity_prompt = []
        for entity_type in entity_types:
            entity_prompt.append(self.entity_token)
            entity_prompt.append(entity_type)
        entity_prompt.append(self.sep_token)

        # update length of the sequence (add prompt length to the original length)
        new_length[i] = new_length[i] + len(entity_prompt)
        new_tokens.append(entity_prompt + tokens)
        prompt_lengths.append(len(entity_prompt))

    # mask over entity-type slots: True while the slot index is below the
    # example's own number of classes  -> [batch_size, max_num_classes]
    max_num_classes = max(num_classes_all)
    class_counts = torch.tensor(num_classes_all, device=device).unsqueeze(-1)
    entity_type_mask = torch.arange(max_num_classes, device=device).unsqueeze(0) < class_counts

    # compute all token representations
    bert_output = self.token_rep_layer(new_tokens, new_length)
    word_rep_w_prompt = bert_output["embeddings"]  # embeddings for all tokens (with prompt)
    mask_w_prompt = bert_output["mask"]  # mask for all tokens (with prompt)

    # split each example into text part (after [SEP]) and entity-type part (before [SEP])
    word_rep = []
    mask = []
    entity_type_rep = []
    for i, prompt_len in enumerate(prompt_lengths):
        text_end = prompt_len + int(x['seq_length'][i])
        word_rep.append(word_rep_w_prompt[i, prompt_len:text_end])
        mask.append(mask_w_prompt[i, prompt_len:text_end])

        # drop the trailing [SEP], then take every second position starting
        # from the first one, i.e. the [ENT] positions
        entity_type_rep.append(word_rep_w_prompt[i, :prompt_len - 1][0::2])

    # padding for word_rep, mask and entity_type_rep
    word_rep = pad_sequence(word_rep, batch_first=True)  # [batch_size, seq_len, hidden_size]
    mask = pad_sequence(mask, batch_first=True)  # [batch_size, seq_len]
    entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)  # [batch_size, len_types, hidden_size]

    # compute span representation
    word_rep = self.rnn(word_rep, mask)
    span_rep = self.span_rep_layer(word_rep, span_idx)

    # compute final entity type representation (FFN)
    entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)
    num_classes = entity_type_rep.shape[1]  # number of entity types

    # similarity score
    scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

    return scores, num_classes, entity_type_mask
|
| 140 |
+
|
| 141 |
+
def forward(self, x):
    """Compute the summed training loss for a batch.

    Every (span, entity type) pair is treated as an independent binary
    decision and scored with binary cross-entropy on the logits from
    ``compute_score_train``. Padded spans (label ``-1``) and padded
    entity-type slots contribute nothing; positive pairs are weighted 2,
    negative pairs 1.

    Args:
        x: Batch dict; ``x["span_label"]`` holds the class id per span
            (``0`` = no entity, ``-1`` = padding). It is not modified.

    Returns:
        A scalar tensor: the sum of all weighted, masked losses.
    """
    # compute span/type scores
    scores, num_classes, entity_type_mask = self.compute_score_train(x)
    batch_size = scores.shape[0]

    logits_label = scores.view(-1, num_classes)
    labels = x["span_label"].view(-1)  # (batch_size * num_spans)
    mask_label = labels != -1  # (batch_size * num_spans)
    # Out-of-place fill: ``labels`` is a view of the caller's tensor, so an
    # in-place fill would overwrite the -1 padding markers in the batch.
    labels = labels.masked_fill(~mask_label, 0)

    # one-hot encoding with an extra leading column for the null label 0,
    # which is dropped afterwards -> (batch_size * num_spans, num_classes)
    labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
    labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
    labels_one_hot = labels_one_hot[:, 1:]

    # compute loss (without reduction)
    all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot, reduction='none')

    # zero the loss of padded entity-type slots, entity_type_mask is (B, C)
    masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
    all_losses = masked_loss.view(-1, num_classes)

    # expand the span mask to one entry per (span, class)
    mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)

    # weight 2 for positive pairs, 1 for negative pairs
    weight_c = labels_one_hot + 1

    # apply span mask and class weights
    all_losses = all_losses * mask_label.float() * weight_c
    return all_losses.sum()
|
| 171 |
+
|
| 172 |
+
def compute_score_eval(self, x, device):
    """Score every candidate span against every entity type (inference path).

    Unlike the training path, one label set is shared by the whole batch:
    ``x['classes_to_id']`` must be a single dict, and the same prompt

        [ENT] type_1 [ENT] type_2 ... [ENT] type_m [SEP]

    is prepended to every example.

    Args:
        x: Batch dict with ``tokens``, ``seq_length``, ``span_idx``,
            ``span_mask`` and a dict ``classes_to_id``.
        device: Device the span indices are moved to.

    Returns:
        Score tensor indexed ``[B, L, K, C]`` per the einsum subscripts.
    """
    # check if classes_to_id is dict
    assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

    # Zero out the indices of masked (invalid) spans.
    span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

    # Shared prompt: an [ENT] marker before each entity type, one closing [SEP].
    entity_prompt = [
        piece
        for entity_type in x['classes_to_id'].keys()
        for piece in (self.entity_token, entity_type)
    ]
    entity_prompt.append(self.sep_token)
    prompt_entity_length = len(entity_prompt)

    # add prompt
    tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
    seq_length_p = x['seq_length'] + prompt_entity_length

    out = self.token_rep_layer(tokens_p, seq_length_p)
    word_rep_w_prompt = out["embeddings"]
    mask_w_prompt = out["mask"]

    # remove prompt
    word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
    mask = mask_w_prompt[:, prompt_entity_length:]

    # entity type representations: drop the trailing [SEP], then keep the
    # [ENT] positions (even indices of the prompt)
    entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :][:, 0::2, :]
    entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)

    word_rep = self.rnn(word_rep, mask)
    span_rep = self.span_rep_layer(word_rep, span_idx)

    local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
    return local_scores
|
| 218 |
+
|
| 219 |
+
@torch.no_grad()
def predict(self, x, flat_ner=False, threshold=0.5):
    """Predict entity spans for a collated batch.

    Switches the model to eval mode, scores all spans with
    ``compute_score_eval`` and keeps every (start, width, class) whose
    sigmoid score exceeds ``threshold`` and whose end lies inside the
    example. The surviving candidates of each example are handed to
    ``greedy_search``.

    Args:
        x: Batch dict with ``tokens`` and ``id_to_classes`` plus whatever
            ``compute_score_eval`` reads.
        flat_ner: Forwarded to ``greedy_search``.
        threshold: Minimum sigmoid score for a candidate to be kept.

    Returns:
        A list with one entry per example: the result of ``greedy_search``
        on that example's candidates.
    """
    self.eval()
    local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)

    spans = []
    for i, tokens in enumerate(x["tokens"]):
        local_i = local_scores[i]
        # Index lists (start, width, class) of all entries above threshold.
        hits = [idx.tolist() for idx in torch.where(torch.sigmoid(local_i) > threshold)]
        span_i = []
        for s, k, c in zip(*hits):
            # Skip spans whose end token falls outside this example.
            if s + k < len(tokens):
                # Class ids start at 1; the raw logit is passed as the score.
                span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
        spans.append(greedy_search(span_i, flat_ner))
    return spans
|
| 234 |
+
|
| 235 |
+
def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
    """Extract entities of the requested types from a raw text string.

    The text is split into word-level tokens (runs of word characters,
    optionally joined by ``-``/``_``, or single non-space characters). The
    character offsets of every token are recorded so that the token spans
    returned by ``predict`` can be mapped back onto ``text``.

    Args:
        text: Raw input string.
        labels: Entity type names; forwarded to ``collate_fn``.
        flat_ner: Forwarded to ``predict``.
        threshold: Forwarded to ``predict``.

    Returns:
        A list of dicts with keys ``start`` and ``end`` (character offsets
        into ``text``, end exclusive), ``text`` and ``label``. Empty when
        ``text`` contains no tokens.
    """
    tokens = []
    token_starts = []
    token_ends = []
    for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
        tokens.append(match.group())
        token_starts.append(match.start())
        token_ends.append(match.end())

    if not tokens:
        # Empty or whitespace-only input: there is nothing to predict, so
        # skip the model call instead of collating a zero-length example.
        return []

    batch = self.collate_fn([{"tokenized_text": tokens, "ner": None}], labels)
    predictions = self.predict(batch, flat_ner=flat_ner, threshold=threshold)

    entities = []
    # Only one example was collated, so only predictions[0] is relevant.
    for start_token_idx, end_token_idx, ent_type in predictions[0]:
        start_char = token_starts[start_token_idx]
        end_char = token_ends[end_token_idx]
        entities.append({
            "start": start_char,
            "end": end_char,
            "text": text[start_char:end_char],
            "label": ent_type,
        })
    return entities
|
| 259 |
+
|
| 260 |
+
def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
    """Run the model over ``test_data`` and score it with ``Evaluator``.

    Args:
        test_data: Dataset handed to ``create_dataloader``.
        flat_ner: Forwarded to ``predict``.
        threshold: Forwarded to ``predict``.
        batch_size: Batch size for the data loader.
        entity_types: Forwarded to ``create_dataloader``.

    Returns:
        The ``(out, f1)`` pair produced by ``Evaluator.evaluate()``.
    """
    self.eval()
    loader = self.create_dataloader(
        test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False
    )
    device = next(self.parameters()).device

    predictions = []
    references = []
    for batch in loader:
        # Move tensor entries to the model's device; non-tensor entries
        # (token lists, label maps, gold entities) are left untouched.
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(device)
        predictions.extend(self.predict(batch, flat_ner, threshold))
        references.extend(batch["entities"])

    out, f1 = Evaluator(references, predictions).evaluate()
    return out, f1
|
| 276 |
+
|
| 277 |
+
@classmethod
def _from_pretrained(
    cls,
    *,
    model_id: str,
    revision: Optional[str],
    cache_dir: Optional[Union[str, Path]],
    force_download: bool,
    proxies: Optional[Dict],
    resume_download: bool,
    local_files_only: bool,
    token: Union[str, bool, None],
    map_location: str = "cpu",
    strict: bool = False,
    **model_kwargs,
):
    """Load a model from a local directory or from the Hugging Face Hub.

    Two layouts are tried in order:

    1. Legacy single-file checkpoints ``gliner_base.pt`` / ``gliner_multi.pt``
       holding both ``config`` and ``model_weights``.
    2. ``pytorch_model.bin`` plus ``gliner_config.json``.

    For every file, ``Path(model_id) / filename`` is used when it exists;
    otherwise the file is fetched with ``hf_hub_download``.

    Args:
        model_id: Local directory or Hub repository id.
        revision, cache_dir, force_download, proxies, resume_download,
            local_files_only, token: Forwarded to ``hf_hub_download``.
        map_location: Device the weights are loaded onto and the model is
            moved to.
        strict: Forwarded to ``load_state_dict``.
        **model_kwargs: Accepted for interface compatibility; unused here.

    Returns:
        The instantiated model with weights loaded.
    """
    # One shared argument set for every Hub download issued below.
    download_kwargs = dict(
        repo_id=model_id,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        token=token,
        local_files_only=local_files_only,
    )
    device = torch.device(map_location)

    # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
    legacy_checkpoints = (
        ("gliner_base.pt", "microsoft/deberta-v3-base"),
        ("gliner_multi.pt", "microsoft/mdeberta-v3-base"),
    )
    for filename, backbone_name in legacy_checkpoints:
        model_file = Path(model_id) / filename
        if not model_file.exists():
            try:
                model_file = hf_hub_download(filename=filename, **download_kwargs)
            except HfHubHTTPError:
                # This legacy file is not available; try the next layout.
                continue
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from sources you trust.
        dict_load = torch.load(model_file, map_location=device)
        config = dict_load["config"]
        state_dict = dict_load["model_weights"]
        config.model_name = backbone_name
        model = cls(config)
        model.load_state_dict(state_dict, strict=strict, assign=True)
        # Required to update flair's internals as well:
        model.to(map_location)
        return model

    # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
    from .train import load_config_as_namespace

    model_file = Path(model_id) / "pytorch_model.bin"
    if not model_file.exists():
        model_file = hf_hub_download(filename="pytorch_model.bin", **download_kwargs)

    config_file = Path(model_id) / "gliner_config.json"
    if not config_file.exists():
        config_file = hf_hub_download(filename="gliner_config.json", **download_kwargs)

    config = load_config_as_namespace(config_file)
    model = cls(config)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    state_dict = torch.load(model_file, map_location=device)
    model.load_state_dict(state_dict, strict=strict, assign=True)
    model.to(map_location)
    return model
|
| 357 |
+
|
| 358 |
+
def save_pretrained(
    self,
    save_directory: Union[str, Path],
    *,
    config: Optional[Union[dict, "DataclassInstance"]] = None,
    repo_id: Optional[str] = None,
    push_to_hub: bool = False,
    **push_to_hub_kwargs,
) -> Optional[str]:
    """
    Save weights in local directory.

    Args:
        save_directory (`str` or `Path`):
            Path to directory in which the model weights and configuration will be saved.
        config (`dict` or `DataclassInstance`, *optional*):
            Model configuration specified as a key/value dictionary or a dataclass instance.
            Defaults to `self.config`. `argparse.Namespace` and dataclass instances are
            converted to plain dicts before being written as JSON.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Huggingface Hub after saving it.
        repo_id (`str`, *optional*):
            ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
            not provided.
        kwargs:
            Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

    Returns:
        The value returned by `push_to_hub` when `push_to_hub=True`, otherwise `None`.
    """
    import dataclasses

    save_directory = Path(save_directory)
    save_directory.mkdir(parents=True, exist_ok=True)

    # save model weights/files
    torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

    # save config (falls back to the config the model was built with)
    if config is None:
        config = self.config
    if config is not None:
        if isinstance(config, argparse.Namespace):
            config = vars(config)
        elif dataclasses.is_dataclass(config) and not isinstance(config, type):
            # json.dumps cannot serialise dataclass instances directly.
            config = dataclasses.asdict(config)
        (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

    # push to the Hub if required
    if push_to_hub:
        kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
        if config is not None:  # kwarg for `push_to_hub`
            kwargs["config"] = config
        if repo_id is None:
            repo_id = save_directory.name  # Defaults to `save_directory` name
        return self.push_to_hub(repo_id=repo_id, **kwargs)
    return None
|
| 406 |
+
|
| 407 |
+
def to(self, device):
    """Move the model to ``device`` and point flair's global device at it.

    Returns:
        ``self``, so calls can be chained.
    """
    super().to(device)

    # flair keeps a module-level ``device``; it is set to the same target
    # here so it stays in step with the model.
    import flair

    flair.device = device
    return self
|
backup/modules/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (5.06 kB). View file
|
|
|
backup/modules/__pycache__/evaluator.cpython-310.pyc
ADDED
|
Binary file (4.44 kB). View file
|
|
|
backup/modules/__pycache__/layers.cpython-310.pyc
ADDED
|
Binary file (1.23 kB). View file
|
|
|
backup/modules/__pycache__/run_evaluation.cpython-310.pyc
ADDED
|
Binary file (4.31 kB). View file
|
|
|
backup/modules/__pycache__/span_rep.cpython-310.pyc
ADDED
|
Binary file (9.62 kB). View file
|
|
|
backup/modules/__pycache__/token_rep.cpython-310.pyc
ADDED
|
Binary file (2.4 kB). View file
|
|
|
backup/modules/base.py
CHANGED
|
@@ -1,150 +1,150 @@
|
|
| 1 |
-
from collections import defaultdict
|
| 2 |
-
from typing import List, Tuple, Dict
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 7 |
-
from torch.utils.data import DataLoader
|
| 8 |
-
import random
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class InstructBase(nn.Module):
    """Base module providing span enumeration and batch collation.

    ``config`` must expose ``max_width`` and ``max_len``; the training-style
    collation path additionally reads ``max_neg_type_ratio``,
    ``random_drop`` and ``max_types``.
    """

    def __init__(self, config):
        super().__init__()
        self.max_width = config.max_width
        self.base_config = config

    def get_dict(self, spans, classes_to_id):
        """Map ``(start, end)`` to a class id for every span with a known type.

        Args:
            spans: Iterable of sequences whose first three items are
                ``start``, ``end`` and the entity type.
            classes_to_id: Mapping from entity type to class id.

        Returns:
            A ``defaultdict(int)``; spans not present read as ``0``.
        """
        dict_tag = defaultdict(int)
        for span in spans:
            if span[2] in classes_to_id:
                dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
        return dict_tag

    def preprocess_spans(self, tokens, ner, classes_to_id):
        """Enumerate candidate spans of one example and label them.

        Tokens beyond ``config.max_len`` are dropped. For every start token
        ``i`` the spans ``(i, i), (i, i + 1), ..., (i, i + max_width - 1)``
        are generated; spans whose end runs past the last token get label
        ``-1``, spans without a gold entity get ``0``.

        Args:
            tokens: List of token strings.
            ner: Gold spans ``(start, end, type)``, or a falsy value when
                there are none.
            classes_to_id: Mapping from entity type to class id.

        Returns:
            Dict with ``tokens``, ``span_idx`` (``[num_spans, 2]``),
            ``span_label`` (``[num_spans]``), ``seq_length`` and
            ``entities`` (``ner`` unchanged).
        """
        max_len = self.base_config.max_len
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        length = len(tokens)

        span_pairs = [
            (start, start + width)
            for start in range(length)
            for width in range(self.max_width)
        ]

        dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)

        # 0 for null labels
        span_label = torch.LongTensor([dict_lab.get(pair, 0) for pair in span_pairs])
        # Explicit shape so an example with zero tokens yields a (0, 2)
        # tensor that can still be indexed by column.
        spans_idx = torch.LongTensor(span_pairs).reshape(len(span_pairs), 2)

        # spans whose end index lies beyond the last token are invalid
        invalid_span_mask = spans_idx[:, 1] > length - 1
        span_label = span_label.masked_fill(invalid_span_mask, -1)

        return {
            'tokens': tokens,
            'span_idx': spans_idx,
            'span_label': span_label,
            'seq_length': length,
            'entities': ner,
        }

    def collate_fn(self, batch_list, entity_types=None):
        """Turn a list of examples into a padded batch dict.

        Args:
            batch_list: List of dicts with ``tokenized_text`` and ``ner``
                (optionally ``label``).
            entity_types: When given, this fixed label set is shared by the
                whole batch and ``classes_to_id`` / ``id_to_classes`` are
                single dicts. When ``None``, a label set is sampled per
                example (gold types plus negatives drawn from the batch)
                and both entries are per-example lists.

        Returns:
            Dict with ``seq_length``, ``span_idx``, ``tokens``,
            ``span_mask``, ``span_label``, ``entities``, ``classes_to_id``
            and ``id_to_classes``.
        """
        if entity_types is None:
            negs = self.get_negatives(batch_list, 100)
            class_to_ids = []
            id_to_classes = []
            for b in batch_list:
                random.shuffle(negs)

                max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)
                if max_neg_type_ratio == 0:
                    # no negatives
                    neg_type_ratio = 0
                else:
                    neg_type_ratio = random.randint(0, max_neg_type_ratio)

                if neg_type_ratio == 0:
                    # no negatives
                    negs_i = []
                else:
                    negs_i = negs[:len(b['ner']) * neg_type_ratio]

                # this is the list of all possible entity types (positive and negative)
                types = list(set([el[-1] for el in b['ner']] + negs_i))

                # shuffle (every epoch)
                random.shuffle(types)

                # random drop: keep a random-sized prefix of the shuffled types
                if len(types) != 0 and self.base_config.random_drop:
                    num_ents = random.randint(1, len(types))
                    types = types[:num_ents]

                # maximum number of entity types
                types = types[:int(self.base_config.max_types)]

                # supervised training: an explicit label list overrides sampling
                if "label" in b:
                    types = sorted(b["label"])

                # class ids start at 1; 0 is reserved for "no entity"
                class_to_id = {k: v for v, k in enumerate(types, start=1)}
                id_to_class = {k: v for v, k in class_to_id.items()}
                class_to_ids.append(class_to_id)
                id_to_classes.append(id_to_class)

            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i])
                for i, b in enumerate(batch_list)
            ]
        else:
            class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
            id_to_classes = {k: v for v, k in class_to_ids.items()}
            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids)
                for b in batch_list
            ]

        span_idx = pad_sequence(
            [b['span_idx'] for b in batch], batch_first=True, padding_value=0
        )
        # padded span slots are labelled -1, like invalid spans
        span_label = pad_sequence(
            [el['span_label'] for el in batch], batch_first=True, padding_value=-1
        )

        return {
            'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
            'span_idx': span_idx,
            'tokens': [el['tokens'] for el in batch],
            'span_mask': span_label != -1,
            'span_label': span_label,
            'entities': [el['entities'] for el in batch],
            'classes_to_id': class_to_ids,
            'id_to_classes': id_to_classes,
        }

    @staticmethod
    def get_negatives(batch_list, sampled_neg=5):
        """Return up to ``sampled_neg`` distinct entity types found in the batch, shuffled."""
        ent_types = []
        for b in batch_list:
            ent_types.extend({el[-1] for el in b['ner']})
        ent_types = list(set(ent_types))
        # sample negatives
        random.shuffle(ent_types)
        return ent_types[:sampled_neg]

    def create_dataloader(self, data, entity_types=None, **kwargs):
        """Build a ``DataLoader`` over ``data`` that collates with ``collate_fn``.

        Args:
            data: Dataset of example dicts.
            entity_types: Forwarded to ``collate_fn``.
            **kwargs: Forwarded to ``DataLoader``.
        """
        from functools import partial

        # A partial of the bound method (rather than a lambda) keeps the
        # collate function picklable for multi-process data loading.
        collate = partial(self.collate_fn, entity_types=entity_types)
        return DataLoader(data, collate_fn=collate, **kwargs)
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from typing import List, Tuple, Dict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class InstructBase(nn.Module):
|
| 12 |
+
def __init__(self, config):
    """Store the config and the maximum span width taken from it."""
    super().__init__()
    # Kept as its own attribute in addition to the full config object.
    self.max_width = config.max_width
    self.base_config = config
|
| 16 |
+
|
| 17 |
+
def get_dict(self, spans, classes_to_id):
    """Map ``(start, end)`` to a class id for every span with a known type.

    Args:
        spans: Iterable of sequences whose first three items are ``start``,
            ``end`` and the entity type.
        classes_to_id: Mapping from entity type to class id.

    Returns:
        A ``defaultdict(int)``; spans not present read as ``0``. Spans whose
        type is missing from ``classes_to_id`` are skipped.
    """
    dict_tag = defaultdict(int)
    for span in spans:
        if span[2] in classes_to_id:
            dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
    return dict_tag
|
| 23 |
+
|
| 24 |
+
def preprocess_spans(self, tokens, ner, classes_to_id):
    """Enumerate candidate spans of one example and label them.

    Tokens beyond ``config.max_len`` are dropped. For every start token
    ``i`` the spans ``(i, i), (i, i + 1), ..., (i, i + max_width - 1)`` are
    generated; spans whose end runs past the last token get label ``-1``,
    spans without a gold entity get ``0``.

    Args:
        tokens: List of token strings.
        ner: Gold spans ``(start, end, type)``, or a falsy value when there
            are none.
        classes_to_id: Mapping from entity type to class id.

    Returns:
        Dict with ``tokens``, ``span_idx`` (``[num_spans, 2]``),
        ``span_label`` (``[num_spans]``), ``seq_length`` and ``entities``
        (``ner`` unchanged).
    """
    max_len = self.base_config.max_len
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
    length = len(tokens)

    span_pairs = [
        (start, start + width)
        for start in range(length)
        for width in range(self.max_width)
    ]

    dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)

    # 0 for null labels
    span_label = torch.LongTensor([dict_lab.get(pair, 0) for pair in span_pairs])
    # Explicit shape so an example with zero tokens yields a (0, 2) tensor
    # that can still be indexed by column.
    spans_idx = torch.LongTensor(span_pairs).reshape(len(span_pairs), 2)

    # spans whose end index lies beyond the last token are invalid
    invalid_span_mask = spans_idx[:, 1] > length - 1
    span_label = span_label.masked_fill(invalid_span_mask, -1)

    return {
        'tokens': tokens,
        'span_idx': spans_idx,
        'span_label': span_label,
        'seq_length': length,
        'entities': ner,
    }
|
| 57 |
+
|
| 58 |
+
def collate_fn(self, batch_list, entity_types=None):
|
| 59 |
+
# batch_list: list of dict containing tokens, ner
|
| 60 |
+
if entity_types is None:
|
| 61 |
+
negs = self.get_negatives(batch_list, 100)
|
| 62 |
+
class_to_ids = []
|
| 63 |
+
id_to_classes = []
|
| 64 |
+
for b in batch_list:
|
| 65 |
+
# negs = b["negative"]
|
| 66 |
+
random.shuffle(negs)
|
| 67 |
+
|
| 68 |
+
# negs = negs[:sampled_neg]
|
| 69 |
+
max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)
|
| 70 |
+
|
| 71 |
+
if max_neg_type_ratio == 0:
|
| 72 |
+
# no negatives
|
| 73 |
+
neg_type_ratio = 0
|
| 74 |
+
else:
|
| 75 |
+
neg_type_ratio = random.randint(0, max_neg_type_ratio)
|
| 76 |
+
|
| 77 |
+
if neg_type_ratio == 0:
|
| 78 |
+
# no negatives
|
| 79 |
+
negs_i = []
|
| 80 |
+
else:
|
| 81 |
+
negs_i = negs[:len(b['ner']) * neg_type_ratio]
|
| 82 |
+
|
| 83 |
+
# this is the list of all possible entity types (positive and negative)
|
| 84 |
+
types = list(set([el[-1] for el in b['ner']] + negs_i))
|
| 85 |
+
|
| 86 |
+
# shuffle (every epoch)
|
| 87 |
+
random.shuffle(types)
|
| 88 |
+
|
| 89 |
+
if len(types) != 0:
|
| 90 |
+
# prob of higher number shoul
|
| 91 |
+
# random drop
|
| 92 |
+
if self.base_config.random_drop:
|
| 93 |
+
num_ents = random.randint(1, len(types))
|
| 94 |
+
types = types[:num_ents]
|
| 95 |
+
|
| 96 |
+
# maximum number of entities types
|
| 97 |
+
types = types[:int(self.base_config.max_types)]
|
| 98 |
+
|
| 99 |
+
# supervised training
|
| 100 |
+
if "label" in b:
|
| 101 |
+
types = sorted(b["label"])
|
| 102 |
+
|
| 103 |
+
class_to_id = {k: v for v, k in enumerate(types, start=1)}
|
| 104 |
+
id_to_class = {k: v for v, k in class_to_id.items()}
|
| 105 |
+
class_to_ids.append(class_to_id)
|
| 106 |
+
id_to_classes.append(id_to_class)
|
| 107 |
+
|
| 108 |
+
batch = [
|
| 109 |
+
self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i]) for i, b in enumerate(batch_list)
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
else:
|
| 113 |
+
class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
|
| 114 |
+
id_to_classes = {k: v for v, k in class_to_ids.items()}
|
| 115 |
+
batch = [
|
| 116 |
+
self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids) for b in batch_list
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
span_idx = pad_sequence(
|
| 120 |
+
[b['span_idx'] for b in batch], batch_first=True, padding_value=0
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
span_label = pad_sequence(
|
| 124 |
+
[el['span_label'] for el in batch], batch_first=True, padding_value=-1
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
return {
|
| 128 |
+
'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
|
| 129 |
+
'span_idx': span_idx,
|
| 130 |
+
'tokens': [el['tokens'] for el in batch],
|
| 131 |
+
'span_mask': span_label != -1,
|
| 132 |
+
'span_label': span_label,
|
| 133 |
+
'entities': [el['entities'] for el in batch],
|
| 134 |
+
'classes_to_id': class_to_ids,
|
| 135 |
+
'id_to_classes': id_to_classes,
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
@staticmethod
|
| 139 |
+
def get_negatives(batch_list, sampled_neg=5):
|
| 140 |
+
ent_types = []
|
| 141 |
+
for b in batch_list:
|
| 142 |
+
types = set([el[-1] for el in b['ner']])
|
| 143 |
+
ent_types.extend(list(types))
|
| 144 |
+
ent_types = list(set(ent_types))
|
| 145 |
+
# sample negatives
|
| 146 |
+
random.shuffle(ent_types)
|
| 147 |
+
return ent_types[:sampled_neg]
|
| 148 |
+
|
| 149 |
+
def create_dataloader(self, data, entity_types=None, **kwargs):
|
| 150 |
+
return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)
|
backup/modules/data_proc.py
CHANGED
|
@@ -1,73 +1,73 @@
|
|
| 1 |
-
import json
|
| 2 |
-
from tqdm import tqdm
|
| 3 |
-
# ast.literal_eval
|
| 4 |
-
import ast, re
|
| 5 |
-
|
| 6 |
-
path = 'train.json'
|
| 7 |
-
|
| 8 |
-
with open(path, 'r') as f:
|
| 9 |
-
data = json.load(f)
|
| 10 |
-
|
| 11 |
-
def tokenize_text(text):
|
| 12 |
-
return re.findall(r'\w+(?:[-_]\w+)*|\S', text)
|
| 13 |
-
|
| 14 |
-
def extract_entity_spans(entry):
|
| 15 |
-
text = ""
|
| 16 |
-
len_start = len("What describes ")
|
| 17 |
-
len_end = len(" in the text?")
|
| 18 |
-
entity_types = []
|
| 19 |
-
entity_texts = []
|
| 20 |
-
|
| 21 |
-
for c in entry['conversations']:
|
| 22 |
-
if c['from'] == 'human' and c['value'].startswith('Text: '):
|
| 23 |
-
text = c['value'][len('Text: '):]
|
| 24 |
-
tokenized_text = tokenize_text(text)
|
| 25 |
-
|
| 26 |
-
if c['from'] == 'human' and c['value'].startswith('What describes '):
|
| 27 |
-
|
| 28 |
-
c_type = c['value'][len_start:-len_end]
|
| 29 |
-
c_type = c_type.replace(' ', '_')
|
| 30 |
-
entity_types.append(c_type)
|
| 31 |
-
|
| 32 |
-
elif c['from'] == 'gpt' and c['value'].startswith('['):
|
| 33 |
-
if c['value'] == '[]':
|
| 34 |
-
entity_types = entity_types[:-1]
|
| 35 |
-
continue
|
| 36 |
-
|
| 37 |
-
texts_ents = ast.literal_eval(c['value'])
|
| 38 |
-
# replace space to _ in texts_ents
|
| 39 |
-
entity_texts.extend(texts_ents)
|
| 40 |
-
num_repeat = len(texts_ents) - 1
|
| 41 |
-
entity_types.extend([entity_types[-1]] * num_repeat)
|
| 42 |
-
|
| 43 |
-
entity_spans = []
|
| 44 |
-
for j, entity_text in enumerate(entity_texts):
|
| 45 |
-
entity_tokens = tokenize_text(entity_text)
|
| 46 |
-
matches = []
|
| 47 |
-
for i in range(len(tokenized_text) - len(entity_tokens) + 1):
|
| 48 |
-
if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
|
| 49 |
-
matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
|
| 50 |
-
if matches:
|
| 51 |
-
entity_spans.extend(matches)
|
| 52 |
-
|
| 53 |
-
return entity_spans, tokenized_text
|
| 54 |
-
|
| 55 |
-
# Usage:
|
| 56 |
-
# Replace 'entry' with the specific entry from your JSON data
|
| 57 |
-
entry = data[17818] # For example, taking the first entry
|
| 58 |
-
entity_spans, tokenized_text = extract_entity_spans(entry)
|
| 59 |
-
print("Entity Spans:", entity_spans)
|
| 60 |
-
#print("Tokenized Text:", tokenized_text)
|
| 61 |
-
|
| 62 |
-
# create a dict: {"tokenized_text": tokenized_text, "entity_spans": entity_spans}
|
| 63 |
-
|
| 64 |
-
all_data = []
|
| 65 |
-
|
| 66 |
-
for entry in tqdm(data):
|
| 67 |
-
entity_spans, tokenized_text = extract_entity_spans(entry)
|
| 68 |
-
all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
with open('train_instruct.json', 'w') as f:
|
| 72 |
-
json.dump(all_data, f)
|
| 73 |
-
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
# ast.literal_eval
|
| 4 |
+
import ast, re
|
| 5 |
+
|
| 6 |
+
# Source dataset to convert. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default encoding.
path = 'train.json'

with open(path, 'r', encoding='utf-8') as f:
    data = json.load(f)
|
| 10 |
+
|
| 11 |
+
def tokenize_text(text):
    """Split ``text`` into word tokens and single non-space symbols.

    Runs of word characters joined by ``-`` or ``_`` stay together as one
    token; every other non-whitespace character becomes its own token.
    """
    token_pattern = r'\w+(?:[-_]\w+)*|\S'
    tokens = re.findall(token_pattern, text)
    return tokens
|
| 13 |
+
|
| 14 |
+
def extract_entity_spans(entry):
    """Convert one conversation-style entry into token-level entity spans.

    The entry's ``conversations`` list is scanned for three kinds of turns:
    a human ``"Text: ..."`` turn carrying the passage, human
    ``"What describes <type> in the text?"`` questions, and ``gpt`` answers
    holding a Python-literal list of entity strings for the latest question.

    Args:
        entry: dict with a ``conversations`` list of ``{"from", "value"}`` dicts.

    Returns:
        ``(entity_spans, tokenized_text)`` where ``entity_spans`` is a list of
        ``(start_token, end_token, entity_type)`` tuples (end inclusive) for
        every case-insensitive occurrence of each entity in the passage.
        ``tokenized_text`` is an empty list when the entry has no passage turn.
    """
    question_prefix = "What describes "
    question_suffix = " in the text?"

    # Initialised up front so an entry without a "Text: " turn returns an
    # empty result instead of raising UnboundLocalError.
    tokenized_text = []
    current_type = None
    typed_entities = []  # (entity_text, entity_type) pairs, in answer order

    for c in entry['conversations']:
        sender = c['from']
        value = c['value']

        if sender == 'human' and value.startswith('Text: '):
            tokenized_text = tokenize_text(value[len('Text: '):])

        if sender == 'human' and value.startswith(question_prefix):
            c_type = value[len(question_prefix):-len(question_suffix)]
            current_type = c_type.replace(' ', '_')

        elif sender == 'gpt' and value.startswith('['):
            # Nothing to record for an empty answer, or for an answer that
            # does not follow any question.
            if value == '[]' or current_type is None:
                continue
            # Pair each entity with its type right away so the two can never
            # drift out of alignment.
            texts_ents = ast.literal_eval(value)
            typed_entities.extend((text, current_type) for text in texts_ents)

    # Lower-case once; comparing token lists is equivalent to comparing the
    # space-joined strings because tokens never contain whitespace.
    lowered_text = [token.lower() for token in tokenized_text]

    entity_spans = []
    for entity_text, entity_type in typed_entities:
        entity_tokens = [token.lower() for token in tokenize_text(entity_text)]
        width = len(entity_tokens)
        if width == 0:
            # An empty entity would otherwise "match" at every position.
            continue
        for i in range(len(lowered_text) - width + 1):
            if lowered_text[i:i + width] == entity_tokens:
                entity_spans.append((i, i + width - 1, entity_type))

    return entity_spans, tokenized_text
|
| 54 |
+
|
| 55 |
+
# Usage: print the spans of one sample entry as a sanity check. The index is
# an arbitrary pick, so the check is skipped when the dataset is smaller.
sample_index = 17818
if len(data) > sample_index:
    entry = data[sample_index]
    entity_spans, tokenized_text = extract_entity_spans(entry)
    print("Entity Spans:", entity_spans)

# Convert every entry to {"tokenized_text": [...], "ner": [...]}.
all_data = []

for entry in tqdm(data):
    entity_spans, tokenized_text = extract_entity_spans(entry)
    all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})

with open('train_instruct.json', 'w') as f:
    json.dump(all_data, f)
|
| 73 |
+
|
backup/modules/evaluator.py
CHANGED
|
@@ -1,152 +1,152 @@
|
|
| 1 |
-
from collections import defaultdict
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
from seqeval.metrics.v1 import _prf_divide
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def extract_tp_actual_correct(y_true, y_pred):
|
| 9 |
-
entities_true = defaultdict(set)
|
| 10 |
-
entities_pred = defaultdict(set)
|
| 11 |
-
|
| 12 |
-
for type_name, (start, end), idx in y_true:
|
| 13 |
-
entities_true[type_name].add((start, end, idx))
|
| 14 |
-
for type_name, (start, end), idx in y_pred:
|
| 15 |
-
entities_pred[type_name].add((start, end, idx))
|
| 16 |
-
|
| 17 |
-
target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
|
| 18 |
-
|
| 19 |
-
tp_sum = np.array([], dtype=np.int32)
|
| 20 |
-
pred_sum = np.array([], dtype=np.int32)
|
| 21 |
-
true_sum = np.array([], dtype=np.int32)
|
| 22 |
-
for type_name in target_names:
|
| 23 |
-
entities_true_type = entities_true.get(type_name, set())
|
| 24 |
-
entities_pred_type = entities_pred.get(type_name, set())
|
| 25 |
-
tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
|
| 26 |
-
pred_sum = np.append(pred_sum, len(entities_pred_type))
|
| 27 |
-
true_sum = np.append(true_sum, len(entities_true_type))
|
| 28 |
-
|
| 29 |
-
return pred_sum, tp_sum, true_sum, target_names
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def flatten_for_eval(y_true, y_pred):
|
| 33 |
-
all_true = []
|
| 34 |
-
all_pred = []
|
| 35 |
-
|
| 36 |
-
for i, (true, pred) in enumerate(zip(y_true, y_pred)):
|
| 37 |
-
all_true.extend([t + [i] for t in true])
|
| 38 |
-
all_pred.extend([p + [i] for p in pred])
|
| 39 |
-
|
| 40 |
-
return all_true, all_pred
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def compute_prf(y_true, y_pred, average='micro'):
|
| 44 |
-
y_true, y_pred = flatten_for_eval(y_true, y_pred)
|
| 45 |
-
|
| 46 |
-
pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)
|
| 47 |
-
|
| 48 |
-
if average == 'micro':
|
| 49 |
-
tp_sum = np.array([tp_sum.sum()])
|
| 50 |
-
pred_sum = np.array([pred_sum.sum()])
|
| 51 |
-
true_sum = np.array([true_sum.sum()])
|
| 52 |
-
|
| 53 |
-
precision = _prf_divide(
|
| 54 |
-
numerator=tp_sum,
|
| 55 |
-
denominator=pred_sum,
|
| 56 |
-
metric='precision',
|
| 57 |
-
modifier='predicted',
|
| 58 |
-
average=average,
|
| 59 |
-
warn_for=('precision', 'recall', 'f-score'),
|
| 60 |
-
zero_division='warn'
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
recall = _prf_divide(
|
| 64 |
-
numerator=tp_sum,
|
| 65 |
-
denominator=true_sum,
|
| 66 |
-
metric='recall',
|
| 67 |
-
modifier='true',
|
| 68 |
-
average=average,
|
| 69 |
-
warn_for=('precision', 'recall', 'f-score'),
|
| 70 |
-
zero_division='warn'
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
denominator = precision + recall
|
| 74 |
-
denominator[denominator == 0.] = 1
|
| 75 |
-
f_score = 2 * (precision * recall) / denominator
|
| 76 |
-
|
| 77 |
-
return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class Evaluator:
|
| 81 |
-
def __init__(self, all_true, all_outs):
|
| 82 |
-
self.all_true = all_true
|
| 83 |
-
self.all_outs = all_outs
|
| 84 |
-
|
| 85 |
-
def get_entities_fr(self, ents):
|
| 86 |
-
all_ents = []
|
| 87 |
-
for s, e, lab in ents:
|
| 88 |
-
all_ents.append([lab, (s, e)])
|
| 89 |
-
return all_ents
|
| 90 |
-
|
| 91 |
-
def transform_data(self):
|
| 92 |
-
all_true_ent = []
|
| 93 |
-
all_outs_ent = []
|
| 94 |
-
for i, j in zip(self.all_true, self.all_outs):
|
| 95 |
-
e = self.get_entities_fr(i)
|
| 96 |
-
all_true_ent.append(e)
|
| 97 |
-
e = self.get_entities_fr(j)
|
| 98 |
-
all_outs_ent.append(e)
|
| 99 |
-
return all_true_ent, all_outs_ent
|
| 100 |
-
|
| 101 |
-
@torch.no_grad()
|
| 102 |
-
def evaluate(self):
|
| 103 |
-
all_true_typed, all_outs_typed = self.transform_data()
|
| 104 |
-
precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values()
|
| 105 |
-
output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
|
| 106 |
-
return output_str, f1
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def is_nested(idx1, idx2):
|
| 110 |
-
# Return True if idx2 is nested inside idx1 or vice versa
|
| 111 |
-
return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def has_overlapping(idx1, idx2):
|
| 115 |
-
overlapping = True
|
| 116 |
-
if idx1[:2] == idx2[:2]:
|
| 117 |
-
return overlapping
|
| 118 |
-
if (idx1[0] > idx2[1] or idx2[0] > idx1[1]):
|
| 119 |
-
overlapping = False
|
| 120 |
-
return overlapping
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def has_overlapping_nested(idx1, idx2):
|
| 124 |
-
# Return True if idx1 and idx2 overlap, but neither is nested inside the other
|
| 125 |
-
if idx1[:2] == idx2[:2]:
|
| 126 |
-
return True
|
| 127 |
-
if ((idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2)) and idx1 != idx2:
|
| 128 |
-
return False
|
| 129 |
-
else:
|
| 130 |
-
return True
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def greedy_search(spans, flat_ner=True): # start, end, class, score
|
| 134 |
-
|
| 135 |
-
if flat_ner:
|
| 136 |
-
has_ov = has_overlapping
|
| 137 |
-
else:
|
| 138 |
-
has_ov = has_overlapping_nested
|
| 139 |
-
|
| 140 |
-
new_list = []
|
| 141 |
-
span_prob = sorted(spans, key=lambda x: -x[-1])
|
| 142 |
-
for i in range(len(spans)):
|
| 143 |
-
b = span_prob[i]
|
| 144 |
-
flag = False
|
| 145 |
-
for new in new_list:
|
| 146 |
-
if has_ov(b[:-1], new):
|
| 147 |
-
flag = True
|
| 148 |
-
break
|
| 149 |
-
if not flag:
|
| 150 |
-
new_list.append(b[:-1])
|
| 151 |
-
new_list = sorted(new_list, key=lambda x: x[0])
|
| 152 |
-
return new_list
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from seqeval.metrics.v1 import _prf_divide
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def extract_tp_actual_correct(y_true, y_pred):
    """Count predicted, true-positive and gold entities per entity type.

    Args:
        y_true: iterable of ``(type_name, (start, end), sample_idx)`` gold entities.
        y_pred: iterable of predicted entities in the same layout.

    Returns:
        ``(pred_sum, tp_sum, true_sum, target_names)``: three int32 arrays
        aligned with ``target_names``, the sorted list of every type that
        occurs in either input. An entity is a true positive when the same
        ``(start, end, sample_idx)`` appears under the same type in both.
    """
    entities_true = defaultdict(set)
    entities_pred = defaultdict(set)

    for type_name, (start, end), idx in y_true:
        entities_true[type_name].add((start, end, idx))
    for type_name, (start, end), idx in y_pred:
        entities_pred[type_name].add((start, end, idx))

    target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))

    empty = set()
    gold_by_type = [entities_true.get(name, empty) for name in target_names]
    pred_by_type = [entities_pred.get(name, empty) for name in target_names]

    # Build each array in one go; appending inside a loop reallocates on
    # every step and silently drops the requested int32 dtype.
    tp_sum = np.array(
        [len(gold & pred) for gold, pred in zip(gold_by_type, pred_by_type)],
        dtype=np.int32,
    )
    pred_sum = np.array([len(pred) for pred in pred_by_type], dtype=np.int32)
    true_sum = np.array([len(gold) for gold in gold_by_type], dtype=np.int32)

    return pred_sum, tp_sum, true_sum, target_names
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def flatten_for_eval(y_true, y_pred):
    """Flatten per-sample entity lists, tagging each entity with its sample index.

    ``y_true`` and ``y_pred`` are walked in lockstep; every entity (a list)
    gets the position of its sample appended, so entities from different
    samples stay distinguishable after flattening.
    """
    indexed_pairs = list(enumerate(zip(y_true, y_pred)))
    all_true = [
        entity + [sample_idx]
        for sample_idx, (gold, _) in indexed_pairs
        for entity in gold
    ]
    all_pred = [
        entity + [sample_idx]
        for sample_idx, (_, predicted) in indexed_pairs
        for entity in predicted
    ]
    return all_true, all_pred
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _average_scores(scores, weights=None):
    """Reduce per-label scores to one float; 0.0 when there is nothing to average."""
    if scores.size == 0:
        return 0.0
    if weights is not None and weights.sum() == 0:
        return 0.0
    return float(np.average(scores, weights=weights))


def compute_prf(y_true, y_pred, average='micro'):
    """Compute entity-level precision, recall and F-score.

    Args:
        y_true: per-sample lists of gold entities, each ``[type, (start, end)]``.
        y_pred: per-sample lists of predicted entities in the same layout.
        average: ``'micro'`` (pool counts over all labels), ``'macro'``
            (unweighted mean of per-label scores) or ``'weighted'``
            (per-label scores weighted by gold-entity count).

    Returns:
        dict with keys ``precision``, ``recall`` and ``f_score``, in that order.

    Raises:
        ValueError: if ``average`` is not one of the supported modes.
    """
    if average not in ('micro', 'macro', 'weighted'):
        raise ValueError(
            f"average must be 'micro', 'macro' or 'weighted', got {average!r}"
        )

    y_true, y_pred = flatten_for_eval(y_true, y_pred)

    pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)

    if average == 'micro':
        tp_sum = np.array([tp_sum.sum()])
        pred_sum = np.array([pred_sum.sum()])
        true_sum = np.array([true_sum.sum()])

    precision = _prf_divide(
        numerator=tp_sum,
        denominator=pred_sum,
        metric='precision',
        modifier='predicted',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )

    recall = _prf_divide(
        numerator=tp_sum,
        denominator=true_sum,
        metric='recall',
        modifier='true',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )

    # Avoid 0/0 where both precision and recall are zero.
    denominator = precision + recall
    denominator[denominator == 0.] = 1
    f_score = 2 * (precision * recall) / denominator

    if average == 'micro':
        return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}

    # Per-label scores must be reduced over all labels; taking element 0
    # would report the first label only.
    weights = true_sum if average == 'weighted' else None
    return {
        'precision': _average_scores(precision, weights),
        'recall': _average_scores(recall, weights),
        'f_score': _average_scores(f_score, weights),
    }
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Evaluator:
    """Score predicted entity spans against gold spans.

    Both inputs are per-sample lists of ``(start, end, label)`` triples.
    """

    def __init__(self, all_true, all_outs):
        self.all_true = all_true
        self.all_outs = all_outs

    def get_entities_fr(self, ents):
        """Reorder ``(start, end, label)`` triples into ``[label, (start, end)]`` lists."""
        return [[label, (start, end)] for start, end, label in ents]

    def transform_data(self):
        """Convert gold and predicted samples, pairwise, to the scoring layout."""
        gold_samples = []
        predicted_samples = []
        for gold, predicted in zip(self.all_true, self.all_outs):
            gold_samples.append(self.get_entities_fr(gold))
            predicted_samples.append(self.get_entities_fr(predicted))
        return gold_samples, predicted_samples

    @torch.no_grad()
    def evaluate(self):
        """Return a formatted P/R/F1 summary line together with the F1 value."""
        all_true_typed, all_outs_typed = self.transform_data()
        scores = compute_prf(all_true_typed, all_outs_typed)
        # Relies on the insertion order of the score dict.
        precision, recall, f1 = scores.values()
        output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
        return output_str, f1
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def is_nested(idx1, idx2):
    """Return True when one span fully contains the other (bounds inclusive)."""
    first_contains_second = idx1[0] <= idx2[0] and idx1[1] >= idx2[1]
    if first_contains_second:
        return True
    second_contains_first = idx2[0] <= idx1[0] and idx2[1] >= idx1[1]
    return second_contains_first
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def has_overlapping(idx1, idx2):
    """Return True when two spans share at least one position.

    Only the first two elements (start, end) of each span are compared, so
    spans may carry extra fields such as a label. Identical boundaries
    always count as overlapping.
    """
    if idx1[:2] == idx2[:2]:
        return True
    # The spans intersect unless one starts after the other ends.
    return idx1[0] <= idx2[1] and idx2[0] <= idx1[1]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def has_overlapping_nested(idx1, idx2):
    """Return True when two spans conflict under nested-NER rules.

    Spans with identical boundaries always conflict. Otherwise they conflict
    only if they partially overlap: disjoint spans and spans where one is
    nested inside the other are both allowed.
    """
    if idx1[:2] == idx2[:2]:
        return True
    disjoint = idx1[0] > idx2[1] or idx2[0] > idx1[1]
    compatible = (disjoint or is_nested(idx1, idx2)) and idx1 != idx2
    return not compatible
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def greedy_search(spans, flat_ner=True):  # start, end, class, score
    """Greedily keep the highest-scoring spans that do not conflict.

    Args:
        spans: sequence of spans whose last element is the score.
        flat_ner: when True any overlap is a conflict; when False nested
            spans are allowed and only partial overlaps conflict.

    Returns:
        The kept spans with the score stripped, sorted by start index.
    """
    conflicts = has_overlapping if flat_ner else has_overlapping_nested

    kept = []
    # Visit candidates from highest to lowest score.
    for candidate in sorted(spans, key=lambda item: -item[-1]):
        span = candidate[:-1]
        if not any(conflicts(span, chosen) for chosen in kept):
            kept.append(span)

    kept.sort(key=lambda item: item[0])
    return kept
|
backup/modules/layers.py
CHANGED
|
@@ -1,28 +1,28 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn.functional as F
|
| 3 |
-
from torch import nn
|
| 4 |
-
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class LstmSeq2SeqEncoder(nn.Module):
|
| 8 |
-
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
|
| 9 |
-
super(LstmSeq2SeqEncoder, self).__init__()
|
| 10 |
-
self.lstm = nn.LSTM(input_size=input_size,
|
| 11 |
-
hidden_size=hidden_size,
|
| 12 |
-
num_layers=num_layers,
|
| 13 |
-
dropout=dropout,
|
| 14 |
-
bidirectional=bidirectional,
|
| 15 |
-
batch_first=True)
|
| 16 |
-
|
| 17 |
-
def forward(self, x, mask, hidden=None):
|
| 18 |
-
# Packing the input sequence
|
| 19 |
-
lengths = mask.sum(dim=1).cpu()
|
| 20 |
-
packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
|
| 21 |
-
|
| 22 |
-
# Passing packed sequence through LSTM
|
| 23 |
-
packed_output, hidden = self.lstm(packed_x, hidden)
|
| 24 |
-
|
| 25 |
-
# Unpacking the output sequence
|
| 26 |
-
output, _ = pad_packed_sequence(packed_output, batch_first=True)
|
| 27 |
-
|
| 28 |
-
return output
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LstmSeq2SeqEncoder(nn.Module):
    """LSTM encoder over padded batches that skips padded positions.

    Args:
        input_size: feature size of each input step.
        hidden_size: LSTM hidden size (per direction).
        num_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers.
        bidirectional: whether to run the LSTM in both directions.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout,
                            bidirectional=bidirectional,
                            batch_first=True)

    def forward(self, x, mask, hidden=None):
        """Encode ``x`` (batch-first), using ``mask`` to find each sequence length.

        ``mask`` is summed over dim 1 to obtain the lengths. The returned
        tensor keeps the time dimension of ``x``; positions past a sequence's
        length are zero.
        """
        # pack_padded_sequence expects the lengths on the CPU.
        lengths = mask.sum(dim=1).cpu()
        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        packed_output, hidden = self.lstm(packed_x, hidden)

        # Without total_length the output is only as long as the longest
        # sequence in the batch, which can be shorter than x and its mask.
        output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=x.size(1)
        )

        return output
|
backup/modules/run_evaluation.py
CHANGED
|
@@ -1,188 +1,188 @@
|
|
| 1 |
-
import glob
|
| 2 |
-
import json
|
| 3 |
-
import os
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
from tqdm import tqdm
|
| 8 |
-
import random
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def open_content(path):
|
| 12 |
-
paths = glob.glob(os.path.join(path, "*.json"))
|
| 13 |
-
train, dev, test, labels = None, None, None, None
|
| 14 |
-
for p in paths:
|
| 15 |
-
if "train" in p:
|
| 16 |
-
with open(p, "r") as f:
|
| 17 |
-
train = json.load(f)
|
| 18 |
-
elif "dev" in p:
|
| 19 |
-
with open(p, "r") as f:
|
| 20 |
-
dev = json.load(f)
|
| 21 |
-
elif "test" in p:
|
| 22 |
-
with open(p, "r") as f:
|
| 23 |
-
test = json.load(f)
|
| 24 |
-
elif "labels" in p:
|
| 25 |
-
with open(p, "r") as f:
|
| 26 |
-
labels = json.load(f)
|
| 27 |
-
return train, dev, test, labels
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def process(data):
|
| 31 |
-
words = data['sentence'].split()
|
| 32 |
-
entities = [] # List of entities (start, end, type)
|
| 33 |
-
|
| 34 |
-
for entity in data['entities']:
|
| 35 |
-
start_char, end_char = entity['pos']
|
| 36 |
-
|
| 37 |
-
# Initialize variables to keep track of word positions
|
| 38 |
-
start_word = None
|
| 39 |
-
end_word = None
|
| 40 |
-
|
| 41 |
-
# Iterate through words and find the word positions
|
| 42 |
-
char_count = 0
|
| 43 |
-
for i, word in enumerate(words):
|
| 44 |
-
word_length = len(word)
|
| 45 |
-
if char_count == start_char:
|
| 46 |
-
start_word = i
|
| 47 |
-
if char_count + word_length == end_char:
|
| 48 |
-
end_word = i
|
| 49 |
-
break
|
| 50 |
-
char_count += word_length + 1 # Add 1 for the space
|
| 51 |
-
|
| 52 |
-
# Append the word positions to the list
|
| 53 |
-
entities.append((start_word, end_word, entity['type']))
|
| 54 |
-
|
| 55 |
-
# Create a list of word positions for each entity
|
| 56 |
-
sample = {
|
| 57 |
-
"tokenized_text": words,
|
| 58 |
-
"ner": entities
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
return sample
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# create dataset
|
| 65 |
-
def create_dataset(path):
|
| 66 |
-
train, dev, test, labels = open_content(path)
|
| 67 |
-
train_dataset = []
|
| 68 |
-
dev_dataset = []
|
| 69 |
-
test_dataset = []
|
| 70 |
-
for data in train:
|
| 71 |
-
train_dataset.append(process(data))
|
| 72 |
-
for data in dev:
|
| 73 |
-
dev_dataset.append(process(data))
|
| 74 |
-
for data in test:
|
| 75 |
-
test_dataset.append(process(data))
|
| 76 |
-
return train_dataset, dev_dataset, test_dataset, labels
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
@torch.no_grad()
|
| 80 |
-
def get_for_one_path(path, model):
|
| 81 |
-
# load the dataset
|
| 82 |
-
_, _, test_dataset, entity_types = create_dataset(path)
|
| 83 |
-
|
| 84 |
-
data_name = path.split("/")[-1] # get the name of the dataset
|
| 85 |
-
|
| 86 |
-
# check if the dataset is flat_ner
|
| 87 |
-
flat_ner = True
|
| 88 |
-
if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]):
|
| 89 |
-
flat_ner = False
|
| 90 |
-
|
| 91 |
-
# evaluate the model
|
| 92 |
-
results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
|
| 93 |
-
entity_types=entity_types)
|
| 94 |
-
return data_name, results, f1
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def get_for_all_path(model, steps, log_dir, data_paths):
|
| 98 |
-
all_paths = glob.glob(f"{data_paths}/*")
|
| 99 |
-
|
| 100 |
-
all_paths = sorted(all_paths)
|
| 101 |
-
|
| 102 |
-
# move the model to the device
|
| 103 |
-
device = next(model.parameters()).device
|
| 104 |
-
model.to(device)
|
| 105 |
-
# set the model to eval mode
|
| 106 |
-
model.eval()
|
| 107 |
-
|
| 108 |
-
# log the results
|
| 109 |
-
save_path = os.path.join(log_dir, "results.txt")
|
| 110 |
-
|
| 111 |
-
with open(save_path, "a") as f:
|
| 112 |
-
f.write("##############################################\n")
|
| 113 |
-
# write step
|
| 114 |
-
f.write("step: " + str(steps) + "\n")
|
| 115 |
-
|
| 116 |
-
zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
|
| 117 |
-
"CrossNER_politics", "CrossNER_science"]
|
| 118 |
-
|
| 119 |
-
zero_shot_benc_results = {}
|
| 120 |
-
all_results = {} # without crossNER
|
| 121 |
-
|
| 122 |
-
for p in tqdm(all_paths):
|
| 123 |
-
if "sample_" not in p:
|
| 124 |
-
data_name, results, f1 = get_for_one_path(p, model)
|
| 125 |
-
# write to file
|
| 126 |
-
with open(save_path, "a") as f:
|
| 127 |
-
f.write(data_name + "\n")
|
| 128 |
-
f.write(str(results) + "\n")
|
| 129 |
-
|
| 130 |
-
if data_name in zero_shot_benc:
|
| 131 |
-
zero_shot_benc_results[data_name] = f1
|
| 132 |
-
else:
|
| 133 |
-
all_results[data_name] = f1
|
| 134 |
-
|
| 135 |
-
avg_all = sum(all_results.values()) / len(all_results)
|
| 136 |
-
avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results)
|
| 137 |
-
|
| 138 |
-
save_path_table = os.path.join(log_dir, "tables.txt")
|
| 139 |
-
|
| 140 |
-
# results for all datasets except crossNER
|
| 141 |
-
table_bench_all = ""
|
| 142 |
-
for k, v in all_results.items():
|
| 143 |
-
table_bench_all += f"{k:20}: {v:.1%}\n"
|
| 144 |
-
# (20 size aswell for average i.e. :20)
|
| 145 |
-
table_bench_all += f"{'Average':20}: {avg_all:.1%}"
|
| 146 |
-
|
| 147 |
-
# results for zero-shot benchmark
|
| 148 |
-
table_bench_zeroshot = ""
|
| 149 |
-
for k, v in zero_shot_benc_results.items():
|
| 150 |
-
table_bench_zeroshot += f"{k:20}: {v:.1%}\n"
|
| 151 |
-
table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}"
|
| 152 |
-
|
| 153 |
-
# write to file
|
| 154 |
-
with open(save_path_table, "a") as f:
|
| 155 |
-
f.write("##############################################\n")
|
| 156 |
-
f.write("step: " + str(steps) + "\n")
|
| 157 |
-
f.write("Table for all datasets except crossNER\n")
|
| 158 |
-
f.write(table_bench_all + "\n\n")
|
| 159 |
-
f.write("Table for zero-shot benchmark\n")
|
| 160 |
-
f.write(table_bench_zeroshot + "\n")
|
| 161 |
-
f.write("##############################################\n\n")
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def sample_train_data(data_paths, sample_size=10000):
|
| 165 |
-
all_paths = glob.glob(f"{data_paths}/*")
|
| 166 |
-
|
| 167 |
-
all_paths = sorted(all_paths)
|
| 168 |
-
|
| 169 |
-
# to exclude the zero-shot benchmark datasets
|
| 170 |
-
zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
|
| 171 |
-
"CrossNER_politics", "CrossNER_science", "ACE 2004"]
|
| 172 |
-
|
| 173 |
-
new_train = []
|
| 174 |
-
# take 10k samples from each dataset
|
| 175 |
-
for p in tqdm(all_paths):
|
| 176 |
-
if any([i in p for i in zero_shot_benc]):
|
| 177 |
-
continue
|
| 178 |
-
train, dev, test, labels = create_dataset(p)
|
| 179 |
-
|
| 180 |
-
# add label key to the train data
|
| 181 |
-
for i in range(len(train)):
|
| 182 |
-
train[i]["label"] = labels
|
| 183 |
-
|
| 184 |
-
random.shuffle(train)
|
| 185 |
-
train = train[:sample_size]
|
| 186 |
-
new_train.extend(train)
|
| 187 |
-
|
| 188 |
-
return new_train
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def open_content(path):
    """Load the train/dev/test/labels JSON files found directly inside ``path``.

    A file is assigned to the first of ``train``, ``dev``, ``test``,
    ``labels`` that occurs in its file name. Splits with no matching file
    are returned as ``None``.

    Returns:
        ``(train, dev, test, labels)`` with the parsed JSON content of each file.
    """
    # Insertion order doubles as the matching precedence.
    contents = {"train": None, "dev": None, "test": None, "labels": None}

    # Sorted so that the result does not depend on directory listing order.
    for p in sorted(glob.glob(os.path.join(path, "*.json"))):
        # Match on the file name only: a directory called e.g. "train_data"
        # must not make every file look like the training split.
        file_name = os.path.basename(p)
        for split in contents:
            if split in file_name:
                with open(p, "r", encoding="utf-8") as f:
                    contents[split] = json.load(f)
                break

    return contents["train"], contents["dev"], contents["test"], contents["labels"]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def process(data):
    """Convert a character-offset annotated sentence into a word-level sample.

    Args:
        data: dict with ``sentence`` (str) and ``entities``, a list of dicts
            holding ``pos`` (``[start_char, end_char]``, end exclusive) and
            ``type``.

    Returns:
        dict with ``tokenized_text`` (the whitespace-split words) and ``ner``,
        a list of ``(start_word, end_word, type)`` tuples with inclusive word
        indices. A boundary that does not coincide with a word boundary is
        reported as ``None``.
    """
    sentence = data['sentence']
    words = sentence.split()

    # Record the real character offsets of every word once. Searching in the
    # sentence (instead of assuming one space between words) keeps offsets
    # right with repeated or leading whitespace.
    start_to_word = {}
    end_to_word = {}
    cursor = 0
    for i, word in enumerate(words):
        begin = sentence.index(word, cursor)
        cursor = begin + len(word)
        start_to_word[begin] = i
        end_to_word[cursor] = i

    entities = []  # List of entities (start, end, type)
    for entity in data['entities']:
        start_char, end_char = entity['pos']
        entities.append(
            (start_to_word.get(start_char), end_to_word.get(end_char), entity['type'])
        )

    sample = {
        "tokenized_text": words,
        "ner": entities
    }

    return sample
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# create dataset
|
| 65 |
+
def create_dataset(path):
    """Load a dataset directory and convert every split to word-level samples.

    Args:
        path: Dataset directory, forwarded to ``open_content``.

    Returns:
        ``(train_dataset, dev_dataset, test_dataset, labels)`` where each
        dataset is a list of ``process`` outputs and ``labels`` is returned
        as loaded.
    """
    raw_train, raw_dev, raw_test, labels = open_content(path)
    train_dataset = [process(example) for example in raw_train]
    dev_dataset = [process(example) for example in raw_dev]
    test_dataset = [process(example) for example in raw_test]
    return train_dataset, dev_dataset, test_dataset, labels
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@torch.no_grad()
def get_for_one_path(path, model):
    """Evaluate ``model`` on the test split of the dataset stored at ``path``.

    Args:
        path: Dataset directory; its final component is used as the dataset
            name.
        model: Object exposing ``evaluate(test_dataset, flat_ner=...,
            threshold=..., batch_size=..., entity_types=...)`` that returns
            ``(results, f1)``.

    Returns:
        ``(data_name, results, f1)``.
    """
    # load the dataset
    _, _, test_dataset, entity_types = create_dataset(path)

    # Normalising first makes the name robust to a trailing separator and to
    # platform-specific separators, which a plain split("/") is not.
    data_name = os.path.basename(os.path.normpath(path))

    # Datasets whose name contains one of these markers are evaluated with
    # flat_ner=False; everything else is treated as flat NER.
    flat_ner = not any(marker in data_name for marker in ("ACE", "GENIA", "Corpus"))

    # evaluate the model
    results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
                                 entity_types=entity_types)
    return data_name, results, f1
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _format_score_table(scores):
    """Render ``{dataset: f1}`` as aligned percentage rows plus an average row."""
    rows = [f"{name:20}: {score:.1%}" for name, score in scores.items()]
    # An empty group has no average; report "n/a" rather than divide by zero.
    average = f"{sum(scores.values()) / len(scores):.1%}" if scores else "n/a"
    rows.append(f"{'Average':20}: {average}")
    return "\n".join(rows)


def get_for_all_path(model, steps, log_dir, data_paths):
    """Evaluate ``model`` on every dataset directory and append the scores to logs.

    Per-dataset results are appended to ``<log_dir>/results.txt`` and two
    summary tables (zero-shot benchmark datasets and all remaining datasets)
    are appended to ``<log_dir>/tables.txt``.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        steps: Training step written into both log files.
        log_dir: Directory receiving ``results.txt`` and ``tables.txt``.
        data_paths: Directory whose entries are dataset directories. Entries
            whose path contains ``"sample_"`` are skipped.
    """
    all_paths = sorted(glob.glob(f"{data_paths}/*"))

    # move the model to the device of its first parameter
    device = next(model.parameters()).device
    model.to(device)
    # set the model to eval mode
    model.eval()

    # log the results
    save_path = os.path.join(log_dir, "results.txt")

    with open(save_path, "a") as f:
        f.write("##############################################\n")
        # write step
        f.write("step: " + str(steps) + "\n")

    zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
                      "CrossNER_politics", "CrossNER_science"]

    zero_shot_benc_results = {}
    all_results = {}  # every dataset that is not part of the zero-shot benchmark

    for p in tqdm(all_paths):
        if "sample_" in p:
            continue
        data_name, results, f1 = get_for_one_path(p, model)
        # write to file
        with open(save_path, "a") as f:
            f.write(data_name + "\n")
            f.write(str(results) + "\n")

        if data_name in zero_shot_benc:
            zero_shot_benc_results[data_name] = f1
        else:
            all_results[data_name] = f1

    save_path_table = os.path.join(log_dir, "tables.txt")

    table_bench_all = _format_score_table(all_results)
    table_bench_zeroshot = _format_score_table(zero_shot_benc_results)

    # write to file
    with open(save_path_table, "a") as f:
        f.write("##############################################\n")
        f.write("step: " + str(steps) + "\n")
        f.write("Table for all datasets except crossNER\n")
        f.write(table_bench_all + "\n\n")
        f.write("Table for zero-shot benchmark\n")
        f.write(table_bench_zeroshot + "\n")
        f.write("##############################################\n\n")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def sample_train_data(data_paths, sample_size=10000):
    """Build a mixed training set from a random sample of every dataset.

    Args:
        data_paths: Directory whose entries are dataset directories.
        sample_size: Maximum number of training examples kept per dataset.

    Returns:
        A list of training samples; each sample gains a ``"label"`` key
        holding the labels of the dataset it came from.
    """
    # Directories whose path contains one of these names are left out.
    excluded = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
                "CrossNER_politics", "CrossNER_science", "ACE 2004"]

    sampled = []
    for dataset_dir in tqdm(sorted(glob.glob(f"{data_paths}/*"))):
        if any([name in dataset_dir for name in excluded]):
            continue
        train, _dev, _test, labels = create_dataset(dataset_dir)

        # attach the dataset's labels to each of its training examples
        for example in train:
            example["label"] = labels

        random.shuffle(train)
        sampled.extend(train[:sample_size])

    return sampled
|
backup/modules/span_rep.py
CHANGED
|
@@ -1,369 +1,369 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn.functional as F
|
| 3 |
-
from torch import nn
|
| 4 |
-
|
| 5 |
-
def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
    """Build a two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    The intermediate layer is four times as wide as the output.

    Args:
        hidden_size: Size of the input features.
        dropout: Dropout probability applied after the activation.
        out_dim: Size of the output features; defaults to ``hidden_size``.

    Returns:
        The assembled ``nn.Sequential``.
    """
    target_dim = hidden_size if out_dim is None else out_dim
    expanded_dim = target_dim * 4
    layers = [
        nn.Linear(hidden_size, expanded_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(expanded_dim, target_dim),
    ]
    return nn.Sequential(*layers)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class SpanQuery(nn.Module):
    """Span representation built from one learned query vector per span width."""

    def __init__(self, hidden_size, max_width, trainable=True):
        super().__init__()

        # One column per width, re-initialised uniformly in [-1, 1].
        self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
        nn.init.uniform_(self.query_seg, a=-1, b=1)
        if not trainable:
            self.query_seg.requires_grad_(False)

        self.project = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

    def forward(self, h, *args):
        """Map ``h`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        # Broadcasting [B, L, 1, D] against [max_width, D] scales each token
        # element-wise by every width's query vector (no summation involved).
        weighted = h.unsqueeze(2) * self.query_seg.transpose(0, 1)
        return self.project(weighted)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class SpanMLP(nn.Module):
    """Span representation from a single linear layer that emits all widths at once."""

    def __init__(self, hidden_size, max_width):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size * max_width)

    def forward(self, h, *args):
        """Map ``h`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        batch, length, dim = h.size()
        widened = self.mlp(h)
        # Split the widened feature axis into (width, D) before the activation.
        return torch.relu(widened.view(batch, length, -1, dim))
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
class SpanCAT(nn.Module):
    """Span representation from token vectors concatenated with a learned width vector."""

    def __init__(self, hidden_size, max_width):
        super().__init__()
        self.max_width = max_width
        # 128 learned features per span width.
        self.query_seg = nn.Parameter(torch.randn(128, max_width))
        self.project = nn.Sequential(
            nn.Linear(hidden_size + 128, hidden_size),
            nn.ReLU(),
        )

    def forward(self, h, *args):
        """Map ``h`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        batch, length, dim = h.size()
        widths = self.max_width

        tokens = h.view(batch, length, 1, dim).expand(-1, -1, widths, -1)
        # NOTE: the [128, max_width] parameter is reinterpreted (view, not
        # transpose) as [max_width, 128]; kept as is so the layout of trained
        # weights does not change.
        queries = self.query_seg.view(1, 1, widths, -1).expand(batch, length, -1, -1)

        return self.project(torch.cat([tokens, queries], dim=-1))
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
class SpanConvBlock(nn.Module):
    """Pool or convolve a fixed-size window starting at every token.

    Supported ``span_mode`` values are ``'conv_conv'`` (learned ``Conv1d``),
    ``'conv_max'``, ``'conv_mean'`` and ``'conv_sum'``. The sequence is
    zero-padded on the right so the output keeps the input length.

    Raises:
        ValueError: If ``span_mode`` is not one of the supported values. This
            includes the default ``'conv_normal'``, which never had an
            implementation and previously failed only inside ``forward``.
    """

    def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
        super().__init__()

        if span_mode == 'conv_conv':
            self.conv = nn.Conv1d(hidden_size, hidden_size,
                                  kernel_size=kernel_size)
            # initialize the weights
            nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
        elif span_mode == 'conv_max':
            self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
        elif span_mode in ('conv_mean', 'conv_sum'):
            # 'conv_sum' reuses average pooling and rescales in forward().
            self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
        else:
            # Fail fast instead of leaving self.conv undefined.
            raise ValueError(f'Unknown span mode {span_mode}')

        self.span_mode = span_mode
        self.pad = kernel_size - 1

    def forward(self, x):
        """Map ``x`` of shape [B, L, D] to a tensor of the same shape."""
        x = x.permute(0, 2, 1)  # [B, D, L] as expected by the 1-d ops

        if self.pad > 0:
            x = F.pad(x, (0, self.pad), "constant", 0)

        x = self.conv(x)

        if self.span_mode == "conv_sum":
            # mean * window size == sum over the window
            x = x * (self.pad + 1)

        return x.permute(0, 2, 1)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
class SpanConv(nn.Module):
    """Span representation stacking the tokens with one ``SpanConvBlock`` per width >= 2."""

    def __init__(self, hidden_size, max_width, span_mode):
        super().__init__()

        # Width 1 is the token itself; widths 2..max_width each get a block.
        self.convs = nn.ModuleList(
            [SpanConvBlock(hidden_size, width, span_mode) for width in range(2, max_width + 1)]
        )

        self.project = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x, *args):
        """Map ``x`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        per_width = [x]
        per_width.extend(conv(x) for conv in self.convs)
        return self.project(torch.stack(per_width, dim=-2))
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
class SpanEndpointsBlock(nn.Module):
    """Gather the first and last token vectors of the fixed-width span starting at each token."""

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        """Map ``x`` of shape [B, L, D] to [B, L, 2, D] (start vector, end vector)."""
        batch, length, dim = x.size()
        reach = self.kernel_size - 1

        starts = torch.arange(length, dtype=torch.long, device=x.device)
        endpoints = torch.stack([starts, starts + reach], dim=-1)  # [L, 2]

        # Zero-pad the sequence axis so end indices past the last token are valid.
        padded = F.pad(x, (0, 0, 0, reach), "constant", 0)

        gathered = torch.index_select(padded, dim=1, index=endpoints.view(-1))
        return gathered.view(batch, length, 2, dim)
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
class ConvShare(nn.Module):
    """Span representation from convolutions that share one weight tensor across widths."""

    def __init__(self, hidden_size, max_width):
        super().__init__()
        self.max_width = max_width

        # Attribute name (including its spelling) is part of the state dict.
        self.conv_weigth = nn.Parameter(
            torch.randn(hidden_size, hidden_size, max_width))
        nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')

        self.project = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x, *args):
        """Map ``x`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        channels_first = x.permute(0, 2, 1)

        per_width = []
        for width in range(1, self.max_width + 1):
            # Right-pad so the output keeps length L, then convolve with the
            # first `width` taps of the shared kernel.
            padded = F.pad(channels_first, (0, width - 1), "constant", 0)
            convolved = F.conv1d(padded, self.conv_weigth[:, :, :width])
            per_width.append(convolved.transpose(-1, -2))

        return self.project(torch.stack(per_width, dim=-2))
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
def extract_elements(sequence, indices):
    """Pick rows of ``sequence`` along dim 1 according to ``indices``.

    Args:
        sequence: Tensor of shape [B, L, D].
        indices: Integer tensor of shape [B, K] with positions along dim 1.

    Returns:
        Tensor of shape [B, K, D] where ``out[b, k] == sequence[b, indices[b, k]]``.
    """
    _, _, dim = sequence.shape
    # torch.gather needs an index of the output's shape: repeat each index over D.
    gather_index = indices.unsqueeze(2).expand(-1, -1, dim)
    return torch.gather(sequence, 1, gather_index)
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
class SpanMarker(nn.Module):
    """Span representation from separately projected start and end token vectors."""

    def __init__(self, hidden_size, max_width, dropout=0.4):
        super().__init__()
        self.max_width = max_width

        self.project_start = self._endpoint_mlp(hidden_size, dropout)
        self.project_end = self._endpoint_mlp(hidden_size, dropout)
        self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)

    @staticmethod
    def _endpoint_mlp(hidden_size, dropout):
        """Build the Linear -> ReLU -> Dropout -> Linear projection used per endpoint."""
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2, bias=True),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, hidden_size, bias=True),
        )

    def forward(self, h, span_idx):
        """Compute span representations of shape [B, L, max_width, D].

        Args:
            h: Token representations of shape [B, L, D].
            span_idx: Index tensor whose last axis holds (start, end) token
                positions; the final reshape requires ``L * max_width`` spans
                per batch item.
        """
        batch, length, dim = h.size()

        # project start and end, then pick the vectors at the span boundaries
        starts = extract_elements(self.project_start(h), span_idx[:, :, 0])
        ends = extract_elements(self.project_end(h), span_idx[:, :, 1])

        # concat start and end, activate, project back to D
        fused = self.out_project(torch.cat([starts, ends], dim=-1).relu())

        return fused.view(batch, length, self.max_width, dim)
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
class SpanMarkerV0(nn.Module):
    """
    Marks and projects span endpoints using MLPs built by ``create_projection_layer``.
    """

    def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
        super().__init__()
        self.max_width = max_width

        self.project_start = create_projection_layer(hidden_size, dropout)
        self.project_end = create_projection_layer(hidden_size, dropout)
        # Takes the concatenated (start, end) pair and maps it back to hidden_size.
        self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)

    def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
        """Return span representations of shape [B, L, max_width, D] for ``h`` of shape [B, L, D]."""
        batch, length, dim = h.size()

        starts = extract_elements(self.project_start(h), span_idx[:, :, 0])
        ends = extract_elements(self.project_end(h), span_idx[:, :, 1])

        fused = torch.cat([starts, ends], dim=-1).relu()
        projected = self.out_project(fused)

        return projected.view(batch, length, self.max_width, dim)
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
class ConvShareV2(nn.Module):
    """Shared-weight span convolution without the output projection of ``ConvShare``."""

    def __init__(self, hidden_size, max_width):
        super().__init__()
        self.max_width = max_width

        # Attribute name (including its spelling) is part of the state dict.
        self.conv_weigth = nn.Parameter(
            torch.randn(hidden_size, hidden_size, max_width)
        )
        nn.init.xavier_normal_(self.conv_weigth)

    def forward(self, x, *args):
        """Map ``x`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        channels_first = x.permute(0, 2, 1)

        per_width = []
        for width in range(1, self.max_width + 1):
            # Right-pad so the output keeps length L, then convolve with the
            # first `width` taps of the shared kernel.
            padded = F.pad(channels_first, (0, width - 1), "constant", 0)
            convolved = F.conv1d(padded, self.conv_weigth[:, :, :width])
            per_width.append(convolved.transpose(-1, -2))

        return torch.stack(per_width, dim=-2)
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
class SpanRepLayer(nn.Module):
    """
    Facade that selects one span representation module by ``span_mode``.

    Extra keyword arguments are forwarded only to the ``'marker'`` and
    ``'markerV0'`` modes. An unrecognised mode raises ``ValueError``.
    """

    def __init__(self, hidden_size, max_width, span_mode, **kwargs):
        super().__init__()

        if span_mode in ('marker', 'markerV0'):
            marker_cls = SpanMarker if span_mode == 'marker' else SpanMarkerV0
            layer = marker_cls(hidden_size, max_width, **kwargs)
        elif span_mode == 'query':
            layer = SpanQuery(hidden_size, max_width, trainable=True)
        elif span_mode == 'mlp':
            layer = SpanMLP(hidden_size, max_width)
        elif span_mode == 'cat':
            layer = SpanCAT(hidden_size, max_width)
        elif span_mode in ('conv_conv', 'conv_max', 'conv_mean', 'conv_sum'):
            # All pooling/convolution variants share SpanConv and differ by mode.
            layer = SpanConv(hidden_size, max_width, span_mode=span_mode)
        elif span_mode == 'conv_share':
            layer = ConvShare(hidden_size, max_width)
        else:
            raise ValueError(f'Unknown span mode {span_mode}')

        self.span_rep_layer = layer

    def forward(self, x, *args):
        """Delegate to the selected span representation module."""
        return self.span_rep_layer(x, *args)
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
    """Build a two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    The intermediate layer is four times as wide as the output.

    Args:
        hidden_size: Size of the input features.
        dropout: Dropout probability applied after the activation.
        out_dim: Size of the output features; defaults to ``hidden_size``.

    Returns:
        The assembled ``nn.Sequential``.
    """
    target_dim = hidden_size if out_dim is None else out_dim
    expanded_dim = target_dim * 4
    layers = [
        nn.Linear(hidden_size, expanded_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(expanded_dim, target_dim),
    ]
    return nn.Sequential(*layers)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SpanQuery(nn.Module):
    """Span representation built from one learned query vector per span width."""

    def __init__(self, hidden_size, max_width, trainable=True):
        super().__init__()

        # One column per width, re-initialised uniformly in [-1, 1].
        self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
        nn.init.uniform_(self.query_seg, a=-1, b=1)
        if not trainable:
            self.query_seg.requires_grad_(False)

        self.project = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

    def forward(self, h, *args):
        """Map ``h`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        # Broadcasting [B, L, 1, D] against [max_width, D] scales each token
        # element-wise by every width's query vector (no summation involved).
        weighted = h.unsqueeze(2) * self.query_seg.transpose(0, 1)
        return self.project(weighted)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SpanMLP(nn.Module):
    """Span representation from a single linear layer that emits all widths at once."""

    def __init__(self, hidden_size, max_width):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size * max_width)

    def forward(self, h, *args):
        """Map ``h`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        batch, length, dim = h.size()
        widened = self.mlp(h)
        # Split the widened feature axis into (width, D) before the activation.
        return torch.relu(widened.view(batch, length, -1, dim))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class SpanCAT(nn.Module):
    """Span representation from token vectors concatenated with a learned width vector."""

    def __init__(self, hidden_size, max_width):
        super().__init__()
        self.max_width = max_width
        # 128 learned features per span width.
        self.query_seg = nn.Parameter(torch.randn(128, max_width))
        self.project = nn.Sequential(
            nn.Linear(hidden_size + 128, hidden_size),
            nn.ReLU(),
        )

    def forward(self, h, *args):
        """Map ``h`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        batch, length, dim = h.size()
        widths = self.max_width

        tokens = h.view(batch, length, 1, dim).expand(-1, -1, widths, -1)
        # NOTE: the [128, max_width] parameter is reinterpreted (view, not
        # transpose) as [max_width, 128]; kept as is so the layout of trained
        # weights does not change.
        queries = self.query_seg.view(1, 1, widths, -1).expand(batch, length, -1, -1)

        return self.project(torch.cat([tokens, queries], dim=-1))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class SpanConvBlock(nn.Module):
    """Pool or convolve a fixed-size window starting at every token.

    Supported ``span_mode`` values are ``'conv_conv'`` (learned ``Conv1d``),
    ``'conv_max'``, ``'conv_mean'`` and ``'conv_sum'``. The sequence is
    zero-padded on the right so the output keeps the input length.

    Raises:
        ValueError: If ``span_mode`` is not one of the supported values. This
            includes the default ``'conv_normal'``, which never had an
            implementation and previously failed only inside ``forward``.
    """

    def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
        super().__init__()

        if span_mode == 'conv_conv':
            self.conv = nn.Conv1d(hidden_size, hidden_size,
                                  kernel_size=kernel_size)
            # initialize the weights
            nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
        elif span_mode == 'conv_max':
            self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
        elif span_mode in ('conv_mean', 'conv_sum'):
            # 'conv_sum' reuses average pooling and rescales in forward().
            self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
        else:
            # Fail fast instead of leaving self.conv undefined.
            raise ValueError(f'Unknown span mode {span_mode}')

        self.span_mode = span_mode
        self.pad = kernel_size - 1

    def forward(self, x):
        """Map ``x`` of shape [B, L, D] to a tensor of the same shape."""
        x = x.permute(0, 2, 1)  # [B, D, L] as expected by the 1-d ops

        if self.pad > 0:
            x = F.pad(x, (0, self.pad), "constant", 0)

        x = self.conv(x)

        if self.span_mode == "conv_sum":
            # mean * window size == sum over the window
            x = x * (self.pad + 1)

        return x.permute(0, 2, 1)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class SpanConv(nn.Module):
    """Span representation stacking the tokens with one ``SpanConvBlock`` per width >= 2."""

    def __init__(self, hidden_size, max_width, span_mode):
        super().__init__()

        # Width 1 is the token itself; widths 2..max_width each get a block.
        self.convs = nn.ModuleList(
            [SpanConvBlock(hidden_size, width, span_mode) for width in range(2, max_width + 1)]
        )

        self.project = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x, *args):
        """Map ``x`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        per_width = [x]
        per_width.extend(conv(x) for conv in self.convs)
        return self.project(torch.stack(per_width, dim=-2))
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class SpanEndpointsBlock(nn.Module):
    """Gather the first and last token vectors of the fixed-width span starting at each token."""

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        """Map ``x`` of shape [B, L, D] to [B, L, 2, D] (start vector, end vector)."""
        batch, length, dim = x.size()
        reach = self.kernel_size - 1

        starts = torch.arange(length, dtype=torch.long, device=x.device)
        endpoints = torch.stack([starts, starts + reach], dim=-1)  # [L, 2]

        # Zero-pad the sequence axis so end indices past the last token are valid.
        padded = F.pad(x, (0, 0, 0, reach), "constant", 0)

        gathered = torch.index_select(padded, dim=1, index=endpoints.view(-1))
        return gathered.view(batch, length, 2, dim)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class ConvShare(nn.Module):
    """Span representation from convolutions that share one weight tensor across widths."""

    def __init__(self, hidden_size, max_width):
        super().__init__()
        self.max_width = max_width

        # Attribute name (including its spelling) is part of the state dict.
        self.conv_weigth = nn.Parameter(
            torch.randn(hidden_size, hidden_size, max_width))
        nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')

        self.project = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x, *args):
        """Map ``x`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        channels_first = x.permute(0, 2, 1)

        per_width = []
        for width in range(1, self.max_width + 1):
            # Right-pad so the output keeps length L, then convolve with the
            # first `width` taps of the shared kernel.
            padded = F.pad(channels_first, (0, width - 1), "constant", 0)
            convolved = F.conv1d(padded, self.conv_weigth[:, :, :width])
            per_width.append(convolved.transpose(-1, -2))

        return self.project(torch.stack(per_width, dim=-2))
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def extract_elements(sequence, indices):
    """Pick rows of ``sequence`` along dim 1 according to ``indices``.

    Args:
        sequence: Tensor of shape [B, L, D].
        indices: Integer tensor of shape [B, K] with positions along dim 1.

    Returns:
        Tensor of shape [B, K, D] where ``out[b, k] == sequence[b, indices[b, k]]``.
    """
    _, _, dim = sequence.shape
    # torch.gather needs an index of the output's shape: repeat each index over D.
    gather_index = indices.unsqueeze(2).expand(-1, -1, dim)
    return torch.gather(sequence, 1, gather_index)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class SpanMarker(nn.Module):
    """Span representation from separately projected start and end token vectors."""

    def __init__(self, hidden_size, max_width, dropout=0.4):
        super().__init__()
        self.max_width = max_width

        self.project_start = self._endpoint_mlp(hidden_size, dropout)
        self.project_end = self._endpoint_mlp(hidden_size, dropout)
        self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)

    @staticmethod
    def _endpoint_mlp(hidden_size, dropout):
        """Build the Linear -> ReLU -> Dropout -> Linear projection used per endpoint."""
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2, bias=True),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, hidden_size, bias=True),
        )

    def forward(self, h, span_idx):
        """Compute span representations of shape [B, L, max_width, D].

        Args:
            h: Token representations of shape [B, L, D].
            span_idx: Index tensor whose last axis holds (start, end) token
                positions; the final reshape requires ``L * max_width`` spans
                per batch item.
        """
        batch, length, dim = h.size()

        # project start and end, then pick the vectors at the span boundaries
        starts = extract_elements(self.project_start(h), span_idx[:, :, 0])
        ends = extract_elements(self.project_end(h), span_idx[:, :, 1])

        # concat start and end, activate, project back to D
        fused = self.out_project(torch.cat([starts, ends], dim=-1).relu())

        return fused.view(batch, length, self.max_width, dim)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class SpanMarkerV0(nn.Module):
    """
    Marks and projects span endpoints using MLPs built by ``create_projection_layer``.
    """

    def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
        super().__init__()
        self.max_width = max_width

        self.project_start = create_projection_layer(hidden_size, dropout)
        self.project_end = create_projection_layer(hidden_size, dropout)
        # Takes the concatenated (start, end) pair and maps it back to hidden_size.
        self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)

    def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
        """Return span representations of shape [B, L, max_width, D] for ``h`` of shape [B, L, D]."""
        batch, length, dim = h.size()

        starts = extract_elements(self.project_start(h), span_idx[:, :, 0])
        ends = extract_elements(self.project_end(h), span_idx[:, :, 1])

        fused = torch.cat([starts, ends], dim=-1).relu()
        projected = self.out_project(fused)

        return projected.view(batch, length, self.max_width, dim)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class ConvShareV2(nn.Module):
    """Shared-weight span convolution without the output projection of ``ConvShare``."""

    def __init__(self, hidden_size, max_width):
        super().__init__()
        self.max_width = max_width

        # Attribute name (including its spelling) is part of the state dict.
        self.conv_weigth = nn.Parameter(
            torch.randn(hidden_size, hidden_size, max_width)
        )
        nn.init.xavier_normal_(self.conv_weigth)

    def forward(self, x, *args):
        """Map ``x`` of shape [B, L, D] to [B, L, max_width, D]; extra args are ignored."""
        channels_first = x.permute(0, 2, 1)

        per_width = []
        for width in range(1, self.max_width + 1):
            # Right-pad so the output keeps length L, then convolve with the
            # first `width` taps of the shared kernel.
            padded = F.pad(channels_first, (0, width - 1), "constant", 0)
            convolved = F.conv1d(padded, self.conv_weigth[:, :, :width])
            per_width.append(convolved.transpose(-1, -2))

        return torch.stack(per_width, dim=-2)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class SpanRepLayer(nn.Module):
    """
    Facade that selects one span representation module by ``span_mode``.

    Extra keyword arguments are forwarded only to the ``'marker'`` and
    ``'markerV0'`` modes. An unrecognised mode raises ``ValueError``.
    """

    def __init__(self, hidden_size, max_width, span_mode, **kwargs):
        super().__init__()

        if span_mode in ('marker', 'markerV0'):
            marker_cls = SpanMarker if span_mode == 'marker' else SpanMarkerV0
            layer = marker_cls(hidden_size, max_width, **kwargs)
        elif span_mode == 'query':
            layer = SpanQuery(hidden_size, max_width, trainable=True)
        elif span_mode == 'mlp':
            layer = SpanMLP(hidden_size, max_width)
        elif span_mode == 'cat':
            layer = SpanCAT(hidden_size, max_width)
        elif span_mode in ('conv_conv', 'conv_max', 'conv_mean', 'conv_sum'):
            # All pooling/convolution variants share SpanConv and differ by mode.
            layer = SpanConv(hidden_size, max_width, span_mode=span_mode)
        elif span_mode == 'conv_share':
            layer = ConvShare(hidden_size, max_width)
        else:
            raise ValueError(f'Unknown span mode {span_mode}')

        self.span_rep_layer = layer

    def forward(self, x, *args):
        """Delegate to the selected span representation module."""
        return self.span_rep_layer(x, *args)
|
backup/modules/token_rep.py
CHANGED
|
@@ -1,54 +1,54 @@
|
|
| 1 |
-
from typing import List
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from flair.data import Sentence
|
| 5 |
-
from flair.embeddings import TransformerWordEmbeddings
|
| 6 |
-
from torch import nn
|
| 7 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
# flair.cache_root = '/gpfswork/rech/pds/upa43yu/.cache'
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class TokenRepLayer(nn.Module):
    """Word-level token embeddings from a flair ``TransformerWordEmbeddings`` model.

    Args:
        model_name: Name of the transformer passed to flair.
        fine_tune: Forwarded to ``TransformerWordEmbeddings``.
        subtoken_pooling: Forwarded to ``TransformerWordEmbeddings``.
        hidden_size: Desired embedding size; a linear projection is added only
            when it differs from the transformer's embedding length.
        add_tokens: Extra tokens added to the tokenizer vocabulary before the
            embedding matrix is resized.
    """

    def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
                 hidden_size: int = 768,
                 add_tokens=("[SEP]", "[ENT]")
                 ):
        super().__init__()

        self.bert_layer = TransformerWordEmbeddings(
            model_name,
            fine_tune=fine_tune,
            subtoken_pooling=subtoken_pooling,
            allow_long_sentences=True
        )

        # add tokens to vocabulary; the default is an immutable tuple, so a
        # fresh list is handed to the tokenizer on every construction
        self.bert_layer.tokenizer.add_tokens(list(add_tokens))

        # resize token embeddings
        self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))

        bert_hidden_size = self.bert_layer.embedding_length

        # The attribute exists only when a projection is needed; forward()
        # relies on hasattr to detect it.
        if hidden_size != bert_hidden_size:
            self.projection = nn.Linear(bert_hidden_size, hidden_size)

    def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
        """Embed a batch of tokenised sentences.

        Args:
            tokens: One list of word strings per sentence.
            lengths: Tensor with the number of words of each sentence.

        Returns:
            A dict with ``"embeddings"`` (padded, batch-first) and ``"mask"``
            (1 for real tokens, 0 for padding; a long tensor on the
            embeddings' device).
        """
        token_embeddings = self.compute_word_embedding(tokens)

        if hasattr(self, "projection"):
            token_embeddings = self.projection(token_embeddings)

        max_length = lengths.max()
        # Broadcast positions [1, max_length] against lengths [B, 1].
        positions = torch.arange(max_length).unsqueeze(0)
        mask = (positions < lengths.cpu().unsqueeze(1)).to(token_embeddings.device).long()
        return {"embeddings": token_embeddings, "mask": mask}

    def compute_word_embedding(self, tokens):
        """Run flair on ``tokens`` and return the zero-padded, batch-first embeddings."""
        sentences = [Sentence(i) for i in tokens]
        self.bert_layer.embed(sentences)
        token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
        return token_embeddings
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from flair.data import Sentence
|
| 5 |
+
from flair.embeddings import TransformerWordEmbeddings
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# flair.cache_root = '/gpfswork/rech/pds/upa43yu/.cache'
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TokenRepLayer(nn.Module):
    """Word-level token encoder built on flair's ``TransformerWordEmbeddings``.

    The layer embeds pre-tokenized sentences, optionally projects the
    embeddings to ``hidden_size`` and returns them together with a padding
    mask derived from the sentence lengths.
    """

    def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
                 hidden_size: int = 768,
                 add_tokens=None
                 ):
        """Create the embedding layer and register extra vocabulary tokens.

        Args:
            model_name: Transformer name passed to ``TransformerWordEmbeddings``.
            fine_tune: Forwarded to ``TransformerWordEmbeddings``.
            subtoken_pooling: Forwarded to ``TransformerWordEmbeddings``.
            hidden_size: Output width; a linear projection is added only when
                it differs from the embedding length of the transformer.
            add_tokens: Extra tokens added to the tokenizer vocabulary.
                ``None`` (the default) means ``["[SEP]", "[ENT]"]``.
        """
        super().__init__()

        # Resolve the default here rather than in the signature so that no
        # list object is shared between calls (mutable-default pitfall).
        if add_tokens is None:
            add_tokens = ["[SEP]", "[ENT]"]

        self.bert_layer = TransformerWordEmbeddings(
            model_name,
            fine_tune=fine_tune,
            subtoken_pooling=subtoken_pooling,
            allow_long_sentences=True
        )

        # add tokens to vocabulary
        self.bert_layer.tokenizer.add_tokens(list(add_tokens))

        # resize token embeddings so the new vocabulary entries have rows
        self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))

        bert_hidden_size = self.bert_layer.embedding_length

        # The projection attribute only exists when the sizes differ;
        # forward() checks for it with hasattr().
        if hidden_size != bert_hidden_size:
            self.projection = nn.Linear(bert_hidden_size, hidden_size)

    def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
        """Embed a batch of tokenized sentences.

        Args:
            tokens: One list of word strings per sentence.
            lengths: Tensor holding the number of tokens of each sentence.

        Returns:
            A dict with ``"embeddings"`` (the padded, optionally projected
            token embeddings) and ``"mask"`` (a long tensor on the same
            device as the embeddings, 1 where the position index is smaller
            than the sentence length and 0 elsewhere).
        """
        token_embeddings = self.compute_word_embedding(tokens)

        if hasattr(self, "projection"):
            token_embeddings = self.projection(token_embeddings)

        B = len(lengths)
        max_length = lengths.max()
        # Built on CPU from the lengths, then moved next to the embeddings.
        mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to(
            token_embeddings.device).long()
        return {"embeddings": token_embeddings, "mask": mask}

    def compute_word_embedding(self, tokens):
        """Embed each token list with flair and pad the batch to equal length."""
        sentences = [Sentence(i) for i in tokens]
        self.bert_layer.embed(sentences)
        token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
        return token_embeddings
|
backup/requirements.txt
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
torch
|
| 2 |
-
transformers
|
| 3 |
-
huggingface_hub
|
| 4 |
-
flair
|
| 5 |
-
seqeval
|
| 6 |
tqdm
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
huggingface_hub
|
| 4 |
+
flair
|
| 5 |
+
seqeval
|
| 6 |
tqdm
|
backup/save_load.py
CHANGED
|
@@ -1,20 +1,20 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from .model import GLiNER
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
def save_model(current_model, path):
|
| 6 |
-
config = current_model.config
|
| 7 |
-
dict_save = {"model_weights": current_model.state_dict(), "config": config}
|
| 8 |
-
torch.save(dict_save, path)
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def load_model(path, model_name=None, device=None):
|
| 12 |
-
dict_load = torch.load(path, map_location=torch.device('cpu'))
|
| 13 |
-
config = dict_load["config"]
|
| 14 |
-
|
| 15 |
-
if model_name is not None:
|
| 16 |
-
config.model_name = model_name
|
| 17 |
-
|
| 18 |
-
loaded_model = GLiNER(config)
|
| 19 |
-
loaded_model.load_state_dict(dict_load["model_weights"])
|
| 20 |
-
return loaded_model.to(device) if device is not None else loaded_model
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .model import GLiNER
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def save_model(current_model, path):
    """Write a checkpoint containing the model weights and its config.

    The checkpoint is a dict with the keys ``"model_weights"`` (the result of
    ``current_model.state_dict()``) and ``"config"`` (``current_model.config``),
    serialized with ``torch.save`` to ``path``.
    """
    checkpoint = {
        "model_weights": current_model.state_dict(),
        "config": current_model.config,
    }
    torch.save(checkpoint, path)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_model(path, model_name=None, device=None):
    """Rebuild a ``GLiNER`` model from a checkpoint written by ``save_model``.

    Args:
        path: Checkpoint location handed to ``torch.load``.
        model_name: When given, overrides ``config.model_name`` before the
            model is constructed.
        device: When given, the model is moved there before being returned.

    Returns:
        The model with the stored weights loaded.
    """
    # NOTE(review): torch.load unpickles the stored config object; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    cfg = checkpoint["config"]

    if model_name is not None:
        cfg.model_name = model_name

    model = GLiNER(cfg)
    model.load_state_dict(checkpoint["model_weights"])

    if device is None:
        return model
    return model.to(device)
|
backup/train.py
CHANGED
|
@@ -1,132 +1,132 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import yaml
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
from transformers import get_cosine_schedule_with_warmup
|
| 8 |
-
|
| 9 |
-
# from model_nested import NerFilteredSemiCRF
|
| 10 |
-
from .model import GLiNER
|
| 11 |
-
from .modules.run_evaluation import get_for_all_path, sample_train_data
|
| 12 |
-
from .save_load import save_model, load_model
|
| 13 |
-
|
| 14 |
-
import json
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
# train function
|
| 18 |
-
def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
|
| 19 |
-
train_batch_size=8, device='cuda'):
|
| 20 |
-
model.train()
|
| 21 |
-
|
| 22 |
-
# initialize data loaders
|
| 23 |
-
train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)
|
| 24 |
-
|
| 25 |
-
pbar = tqdm(range(num_steps))
|
| 26 |
-
|
| 27 |
-
if warmup_ratio < 1:
|
| 28 |
-
num_warmup_steps = int(num_steps * warmup_ratio)
|
| 29 |
-
else:
|
| 30 |
-
num_warmup_steps = int(warmup_ratio)
|
| 31 |
-
|
| 32 |
-
scheduler = get_cosine_schedule_with_warmup(
|
| 33 |
-
optimizer,
|
| 34 |
-
num_warmup_steps=num_warmup_steps,
|
| 35 |
-
num_training_steps=num_steps
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
iter_train_loader = iter(train_loader)
|
| 39 |
-
|
| 40 |
-
for step in pbar:
|
| 41 |
-
try:
|
| 42 |
-
x = next(iter_train_loader)
|
| 43 |
-
except StopIteration:
|
| 44 |
-
iter_train_loader = iter(train_loader)
|
| 45 |
-
x = next(iter_train_loader)
|
| 46 |
-
|
| 47 |
-
for k, v in x.items():
|
| 48 |
-
if isinstance(v, torch.Tensor):
|
| 49 |
-
x[k] = v.to(device)
|
| 50 |
-
|
| 51 |
-
try:
|
| 52 |
-
loss = model(x) # Forward pass
|
| 53 |
-
except:
|
| 54 |
-
continue
|
| 55 |
-
|
| 56 |
-
# check if loss is nan
|
| 57 |
-
if torch.isnan(loss):
|
| 58 |
-
continue
|
| 59 |
-
|
| 60 |
-
loss.backward() # Compute gradients
|
| 61 |
-
optimizer.step() # Update parameters
|
| 62 |
-
scheduler.step() # Update learning rate schedule
|
| 63 |
-
optimizer.zero_grad() # Reset gradients
|
| 64 |
-
|
| 65 |
-
description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
|
| 66 |
-
|
| 67 |
-
if (step + 1) % eval_every == 0:
|
| 68 |
-
current_path = os.path.join(log_dir, f'model_{step + 1}')
|
| 69 |
-
save_model(model, current_path)
|
| 70 |
-
#val_data_dir = "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
|
| 71 |
-
#get_for_all_path(model, step, log_dir, val_data_dir) # you can remove this comment if you want to evaluate the model
|
| 72 |
-
|
| 73 |
-
model.train()
|
| 74 |
-
|
| 75 |
-
pbar.set_description(description)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def create_parser():
|
| 79 |
-
parser = argparse.ArgumentParser(description="Span-based NER")
|
| 80 |
-
parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
|
| 81 |
-
parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
|
| 82 |
-
return parser
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def load_config_as_namespace(config_file):
|
| 86 |
-
with open(config_file, 'r') as f:
|
| 87 |
-
config_dict = yaml.safe_load(f)
|
| 88 |
-
return argparse.Namespace(**config_dict)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
if __name__ == "__main__":
|
| 92 |
-
# parse args
|
| 93 |
-
parser = create_parser()
|
| 94 |
-
args = parser.parse_args()
|
| 95 |
-
|
| 96 |
-
# load config
|
| 97 |
-
config = load_config_as_namespace(args.config)
|
| 98 |
-
|
| 99 |
-
config.log_dir = args.log_dir
|
| 100 |
-
|
| 101 |
-
try:
|
| 102 |
-
with open(config.train_data, 'r') as f:
|
| 103 |
-
data = json.load(f)
|
| 104 |
-
except:
|
| 105 |
-
data = sample_train_data(config.train_data, 10000)
|
| 106 |
-
|
| 107 |
-
if config.prev_path != "none":
|
| 108 |
-
model = load_model(config.prev_path)
|
| 109 |
-
model.config = config
|
| 110 |
-
else:
|
| 111 |
-
model = GLiNER(config)
|
| 112 |
-
|
| 113 |
-
if torch.cuda.is_available():
|
| 114 |
-
model = model.cuda()
|
| 115 |
-
|
| 116 |
-
lr_encoder = float(config.lr_encoder)
|
| 117 |
-
lr_others = float(config.lr_others)
|
| 118 |
-
|
| 119 |
-
optimizer = torch.optim.AdamW([
|
| 120 |
-
# encoder
|
| 121 |
-
{'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
|
| 122 |
-
{'params': model.rnn.parameters(), 'lr': lr_others},
|
| 123 |
-
# projection layers
|
| 124 |
-
{'params': model.span_rep_layer.parameters(), 'lr': lr_others},
|
| 125 |
-
{'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
|
| 126 |
-
])
|
| 127 |
-
|
| 128 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 129 |
-
|
| 130 |
-
train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
|
| 131 |
-
log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
|
| 132 |
-
device=device)
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import yaml
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 8 |
+
|
| 9 |
+
# from model_nested import NerFilteredSemiCRF
|
| 10 |
+
from .model import GLiNER
|
| 11 |
+
from .modules.run_evaluation import get_for_all_path, sample_train_data
|
| 12 |
+
from .save_load import save_model, load_model
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# train function
|
| 18 |
+
def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
          train_batch_size=8, device='cuda'):
    """Run a fixed number of optimization steps and checkpoint periodically.

    Args:
        model: Model exposing ``create_dataloader`` and returning a scalar
            loss when called with a batch dict.
        optimizer: Optimizer stepping the model parameters.
        train_data: Data passed to ``model.create_dataloader``.
        num_steps: Total number of loop iterations.
        eval_every: A checkpoint is saved whenever ``step + 1`` is a multiple
            of this value.
        log_dir: Directory receiving ``model_<step>`` checkpoints.
        warmup_ratio: Values below 1 are a fraction of ``num_steps``; values
            of 1 or more are taken as an absolute number of warmup steps.
        train_batch_size: Batch size for the training dataloader.
        device: Device to which tensor entries of each batch are moved.

    Batches whose forward pass raises, or whose loss is NaN, are skipped; the
    rest of that iteration (including a checkpoint due on that step) does not
    run.
    """
    model.train()

    # Checkpoints are written into log_dir; make sure it exists so the first
    # save cannot fail on a missing directory.
    os.makedirs(log_dir, exist_ok=True)

    # initialize data loaders
    train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)

    pbar = tqdm(range(num_steps))

    if warmup_ratio < 1:
        num_warmup_steps = int(num_steps * warmup_ratio)
    else:
        num_warmup_steps = int(warmup_ratio)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_steps
    )

    iter_train_loader = iter(train_loader)

    for step in pbar:
        # Restart the dataloader transparently when an epoch is exhausted.
        try:
            x = next(iter_train_loader)
        except StopIteration:
            iter_train_loader = iter(train_loader)
            x = next(iter_train_loader)

        for k, v in x.items():
            if isinstance(v, torch.Tensor):
                x[k] = v.to(device)

        try:
            loss = model(x)  # Forward pass
        except Exception as exc:
            # Best effort: a failing batch must not abort the run, but it is
            # reported. KeyboardInterrupt/SystemExit are no longer swallowed.
            tqdm.write(f"step {step}: forward pass failed, skipping batch ({exc!r})")
            continue

        # check if loss is nan
        if torch.isnan(loss):
            continue

        loss.backward()  # Compute gradients
        optimizer.step()  # Update parameters
        scheduler.step()  # Update learning rate schedule
        optimizer.zero_grad()  # Reset gradients

        description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"

        if (step + 1) % eval_every == 0:
            current_path = os.path.join(log_dir, f'model_{step + 1}')
            save_model(model, current_path)
            #val_data_dir = "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
            #get_for_all_path(model, step, log_dir, val_data_dir)  # you can remove this comment if you want to evaluate the model

            model.train()

        pbar.set_description(description)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def create_parser():
    """Build the command-line parser of the training script.

    Returns:
        An ``argparse.ArgumentParser`` with the string options ``--config``
        (default ``config.yaml``) and ``--log_dir`` (default ``logs``).
    """
    cli = argparse.ArgumentParser(description="Span-based NER")
    option_specs = (
        ("--config", "config.yaml", "Path to config file"),
        ('--log_dir', 'logs', 'Path to the log directory'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def load_config_as_namespace(config_file):
    """Load a YAML config file and expose its top-level keys as attributes.

    Args:
        config_file: Path of the YAML file.

    Returns:
        An ``argparse.Namespace`` with one attribute per top-level key.

    Raises:
        TypeError: If the file does not contain a mapping at top level
            (for example when it is empty).
    """
    # YAML is read as UTF-8 explicitly instead of the locale default.
    with open(config_file, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)

    if not isinstance(config_dict, dict):
        # Same exception type the ``**`` expansion would raise, but with a
        # message that names the actual problem.
        raise TypeError(
            f"Config file {config_file!r} must contain a YAML mapping at top level, "
            f"got {type(config_dict).__name__}"
        )
    return argparse.Namespace(**config_dict)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
    # parse args
    parser = create_parser()
    args = parser.parse_args()

    # load config; the log directory always comes from the command line
    config = load_config_as_namespace(args.config)
    config.log_dir = args.log_dir

    # train_data is first tried as a JSON file; when it cannot be opened or
    # parsed it is handed to sample_train_data instead.
    try:
        with open(config.train_data, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing path / directory; ValueError: malformed JSON or
        # undecodable bytes. Anything else is a real error and propagates.
        data = sample_train_data(config.train_data, 10000)

    if config.prev_path != "none":
        # Resume from a checkpoint but train with the current config.
        model = load_model(config.prev_path)
        model.config = config
    else:
        model = GLiNER(config)

    # Resolve the device once and use it for both the model and the batches.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    lr_encoder = float(config.lr_encoder)
    lr_others = float(config.lr_others)

    optimizer = torch.optim.AdamW([
        # encoder
        {'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
        {'params': model.rnn.parameters(), 'lr': lr_others},
        # projection layers
        {'params': model.span_rep_layer.parameters(), 'lr': lr_others},
        {'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
    ])

    train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
          log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
          device=device)
|
core/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (1.59 kB). View file
|
|
|
core/__pycache__/gradio_ocr.cpython-310.pyc
ADDED
|
Binary file (1.93 kB). View file
|
|
|
core/__pycache__/ner_engine.cpython-310.pyc
ADDED
|
Binary file (2.57 kB). View file
|
|
|
core/__pycache__/ocr_engine.cpython-310.pyc
ADDED
|
Binary file (4.04 kB). View file
|
|
|
core/__pycache__/vlm_engine.cpython-310.pyc
ADDED
|
Binary file (3.21 kB). View file
|
|
|
core/base.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
+
|
| 4 |
+
class BaseEngine(ABC):
    """Abstract root of all engines; subclasses must implement ``process``."""

    @abstractmethod
    def process(self, image_path: str) -> Any:
        """Run the engine on the file at ``image_path`` and return its result."""


class BaseOCR(BaseEngine):
    """Abstract OCR engine: adds ``extract_text`` to the engine interface."""

    @abstractmethod
    def extract_text(self, image_path: str) -> str:
        """Return the text read from the image at ``image_path``."""


class BaseVLM(BaseEngine):
    """Abstract vision-language engine producing structured data."""

    @abstractmethod
    def extract_structured_data(self, image_path: str, prompt: str) -> Dict[str, Any]:
        """Return a dict extracted from the image at ``image_path`` using ``prompt``."""


class BaseNER(BaseEngine):
    """Abstract named-entity engine working on text.

    NOTE(review): the inherited ``process`` is declared with an
    ``image_path`` parameter although this interface operates on text —
    confirm the intended signature.
    """

    @abstractmethod
    def extract_entities(self, text: str, labels: List[str]) -> Dict[str, List[str]]:
        """Return a mapping from label to the entity strings found in ``text``."""
|
core/gradio_ocr.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from gradio_client import Client, handle_file
|
| 4 |
+
from .base import BaseOCR
|
| 5 |
+
|
| 6 |
+
class GradioOCREngine(BaseOCR):
    """OCR backend that sends the image to a remote Gradio Space.

    The client is created eagerly in the constructor; when that fails the
    error is logged and every extraction returns an empty string.
    """

    def __init__(self, space_name="WebAshlarWA/glm-ocr-v1"):
        """Remember the Space name and try to connect to it."""
        self.space_name = space_name
        self.client = None
        self._initialize_client()

    def _initialize_client(self):
        """Create the ``gradio_client.Client``; leave ``self.client`` as None on failure."""
        try:
            self.client = Client(self.space_name)
            logging.info(f"Gradio Client initialized for Space: {self.space_name}")
        except Exception as e:
            logging.error(f"Failed to initialize Gradio Client for {self.space_name}: {e}")

    @staticmethod
    def _result_to_text(result) -> str:
        """Reduce the value returned by ``Client.predict`` to the OCR text.

        Strings are returned unchanged, sequences contribute their first
        element, dicts their ``'text'`` entry (or their string form when the
        key is absent). ``None`` and empty sequences yield ``""`` so that the
        caller can detect "nothing extracted" instead of receiving the
        literal text ``"None"`` or ``"[]"``.
        """
        if result is None:
            return ""
        if isinstance(result, str):
            return result
        # Multi-output endpoints come back as a tuple, not a list, so both
        # sequence types are accepted here.
        if isinstance(result, (list, tuple)):
            if not result or result[0] is None:
                return ""
            return str(result[0])
        if isinstance(result, dict):
            # If it's structured, we might need to stringify or handle it elsewhere
            # For OCR we expect a string
            return result.get('text', str(result))
        return str(result)

    def extract_text(self, image_path: str) -> str:
        """Run remote OCR on ``image_path``.

        Returns:
            The extracted text, or ``""`` when the client is unavailable, the
            remote call fails, or the Space returns nothing.
        """
        if not self.client:
            logging.error("Gradio Client not initialized.")
            return ""

        logging.info(f"Gradio OCR: Starting extraction for {os.path.basename(image_path)}")
        try:
            # According to the user snippet, the input is 'image' and output is a string?
            # Or structured data. The snippet used /proses_intelijen
            result = self.client.predict(
                image=handle_file(image_path),
                api_name="/proses_intelijen"
            )
            text = self._result_to_text(result)
        except Exception as e:
            logging.error(f"Gradio OCR extraction failed: {e}")
            return ""

        # Logged for every result shape, not only for the fall-through case.
        if text:
            logging.info("Gradio OCR: Successfully extracted text.")
        return text

    def process(self, image_path: str) -> str:
        """Engine entry point; identical to ``extract_text``."""
        return self.extract_text(image_path)
|
core/ner_engine.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Dict
|
| 3 |
+
from .base import BaseNER
|
| 4 |
+
|
| 5 |
+
class NEREngine(BaseNER):
    """Entity extractor backed by a pretrained GLiNER model.

    The model is loaded in the constructor; when loading fails the error is
    logged and every extraction returns an empty dict.
    """

    # Labels used when the caller does not supply any.
    DEFAULT_LABELS = ("Name", "Designation", "Company", "Contact", "Address", "Email", "Link")

    def __init__(self, model_name="urchade/gliner_mediumv2.1"):
        """Remember the model name and try to load the model."""
        self.model_name = model_name
        self.model = None
        self._initialize_model()

    def _initialize_model(self):
        """Load the GLiNER model; leave ``self.model`` as None on failure."""
        logging.info(f"Initializing NER model: {self.model_name}")
        try:
            # Imported lazily so that a missing/broken model package only
            # disables NER instead of breaking module import.
            from backup.model import GLiNER
            self.model = GLiNER.from_pretrained(self.model_name)
            logging.info(f"NER model '{self.model_name}' loaded successfully.")
        except Exception as e:
            logging.error(f"Failed to load NER model: {e}. NER extraction will be unavailable.")

    def extract_entities(self, text: str, labels: List[str] = None, threshold: float = 0.3) -> Dict[str, List[str]]:
        """Extract entities from ``text`` and group them by label.

        Args:
            text: Text to analyse.
            labels: Labels to look for; ``None`` selects ``DEFAULT_LABELS``.
            threshold: Value forwarded as ``threshold`` to
                ``predict_entities`` (default 0.3, the previously fixed value).

        Returns:
            A dict with one key per requested label mapping to the list of
            matched entity texts (possibly empty). An empty dict is returned
            when ``text`` is empty, the model is unavailable, or prediction
            raises.
        """
        if not text:
            logging.warning("NER: Received empty text for extraction.")
            return {}

        if not self.model:
            logging.error("NER: Model not initialized. Skipping extraction.")
            return {}

        if labels is None:
            labels = list(self.DEFAULT_LABELS)

        logging.info(f"NER: Extracting entities for {len(text)} characters of text.")
        try:
            entities = self.model.predict_entities(text, labels, threshold=threshold)
            structured_data = {label: [] for label in labels}
            for ent in entities:
                label = ent["label"]
                # Entities carrying a label that was not requested are dropped.
                if label in structured_data:
                    structured_data[label].append(ent["text"])

            non_empty_tags = sum(1 for v in structured_data.values() if v)
            logging.info(f"NER: Extracted entities for {non_empty_tags} labels.")
            return structured_data
        except Exception as e:
            logging.error(f"NER: Extraction pipeline crashed: {e}")
            return {}

    def process(self, text: str) -> Dict[str, List[str]]:
        """Engine entry point; extracts the default labels from ``text``."""
        return self.extract_entities(text)
|
core/ocr_engine.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import cv2
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image, ImageEnhance
|
| 6 |
+
from .base import BaseOCR
|
| 7 |
+
from .gradio_ocr import GradioOCREngine
|
| 8 |
+
|
| 9 |
+
class OCREngine(BaseOCR):
    """Tiered OCR: a local engine (PaddleOCR or EasyOCR) with a Gradio fallback.

    ``engine_type`` selects the local engine. If PaddleOCR cannot be
    initialized the engine switches to EasyOCR; the remote Gradio engine is
    used whenever the local engine yields no text.
    """

    def __init__(self, engine_type='paddle'):
        """Store the requested engine type and initialize the engines."""
        self.engine_type = engine_type
        self.ocr = None
        self.gradio_fallback = None
        self._initialize_engine()

    def _initialize_engine(self):
        """Set up the Gradio fallback and the local engine, logging failures."""
        logging.info(f"Initializing OCR engine: {self.engine_type}")

        # Pre-emptive Gradio initialization as it's the most reliable fallback
        try:
            self.gradio_fallback = GradioOCREngine()
        except Exception as e:
            logging.error(f"Failed to pre-initialize Gradio fallback: {e}")

        if self.engine_type == 'paddle':
            try:
                from paddleocr import PaddleOCR
                self.ocr = PaddleOCR(use_angle_cls=False, lang='en', show_log=False)
                logging.info("PaddleOCR engine initialized successfully.")
            except Exception as e:
                logging.warning(f"Failed to initialize PaddleOCR: {e}. Switching to EasyOCR fallback.")
                self.engine_type = 'easyocr'

        # Deliberately a second ``if``: it also runs after the switch above.
        if self.engine_type == 'easyocr':
            try:
                import easyocr
                self.ocr = easyocr.Reader(['en'])
                logging.info("EasyOCR engine initialized successfully.")
            except Exception as e:
                logging.error(f"Failed to initialize EasyOCR: {e}. OCR will be partially unavailable.")
                self.ocr = None

    def preprocess_image(self, image_path, scale=2):
        """Upscale, denoise, sharpen and contrast-enhance an image for OCR.

        Args:
            image_path: Path of the image to read.
            scale: Factor applied to both width and height.

        Returns:
            The processed image as a BGR array, or ``None`` when the image
            cannot be read or processing raises.
        """
        try:
            image = cv2.imread(image_path)
            if image is None:
                logging.error(f"Image not found or unreadable: {image_path}")
                return None

            # Upscale
            height, width = image.shape[:2]
            image = cv2.resize(image, (width * scale, height * scale), interpolation=cv2.INTER_CUBIC)

            # Denoise
            image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)

            # Sharpen
            kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
            image = cv2.filter2D(image, -1, kernel)

            # Enhance Contrast (PIL works in RGB, OpenCV in BGR)
            pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            enhancer = ImageEnhance.Contrast(pil_img)
            enhanced_image = enhancer.enhance(1.5)

            logging.debug(f"Preprocessing completed for {image_path}")
            return cv2.cvtColor(np.array(enhanced_image), cv2.COLOR_RGB2BGR)
        except Exception as e:
            logging.error(f"Error during image preprocessing for {image_path}: {e}")
            return None

    def _run_local_ocr(self, image_path: str) -> str:
        """Run the configured local engine; return ``""`` when it yields nothing."""
        if not self.ocr or self.engine_type not in ('paddle', 'easyocr'):
            return ""

        use_paddle = self.engine_type == 'paddle'
        try:
            processed_img = self.preprocess_image(image_path)
            if processed_img is None:
                return ""

            if use_paddle:
                results = self.ocr.ocr(processed_img)
                if results and results[0]:
                    return " ".join([line[1][0] for line in results[0]])
                return ""

            results = self.ocr.readtext(processed_img)
            return " ".join([res[1] for res in results])
        except Exception as e:
            engine_label = "PaddleOCR" if use_paddle else "EasyOCR"
            logging.error(f"{engine_label} crashed: {e}")
            return ""

    def extract_text(self, image_path: str) -> str:
        """Extract text from ``image_path`` using the tiered strategy.

        1. Primary local engine (Paddle/EasyOCR).
        2. Gradio remote fallback when the local engine returned no text.

        Returns:
            The extracted text, or ``""`` when every method failed.
        """
        logging.info(f"Starting text extraction for: {os.path.basename(image_path)}")

        # 1. Local OCR
        extracted_text = self._run_local_ocr(image_path)
        # Track which tier produced the text so the log below is truthful.
        source = self.engine_type

        # 2. Gradio Fallback if Local failed
        if not extracted_text and self.gradio_fallback:
            logging.info("Local OCR failed or returned empty. Trying Gradio OCR fallback...")
            extracted_text = self.gradio_fallback.extract_text(image_path)
            source = 'Gradio'

        if extracted_text:
            logging.info(f"OCR extracted {len(extracted_text)} characters using {source}.")
        else:
            logging.error("All OCR methods failed to extract text.")

        return extracted_text

    def process(self, image_path: str) -> str:
        """Engine entry point; identical to ``extract_text``."""
        return self.extract_text(image_path)
|
core/vlm_engine.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import base64
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import requests
|
| 6 |
+
import cv2
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
from .base import BaseVLM
|
| 9 |
+
|
| 10 |
+
class GroqVLMEngine(BaseVLM):
    """Vision-language extraction through Groq's chat-completions endpoint.

    The API key is read from the ``GROQ_API_KEY`` environment variable; when
    it is missing a warning is logged and every extraction returns ``{}``.
    """

    def __init__(self, model="meta-llama/llama-4-scout-17b-16e-instruct"):
        """Read the API key from the environment and store the model name."""
        self.api_key = os.getenv("GROQ_API_KEY")
        self.url = "https://api.groq.com/openai/v1/chat/completions"
        self.model = model
        if not self.api_key:
            logging.warning("GROQ_API_KEY missing from environment. VLM extraction will be skipped.")

    def image_to_base64(self, image_path: str) -> str:
        """Return the image re-encoded as JPEG and base64-encoded.

        Returns:
            The base64 string, or ``""`` when the image cannot be read or
            encoded.
        """
        try:
            img = cv2.imread(image_path)
            if img is None:
                logging.error(f"VLM: Image not found at {image_path}")
                return ""
            # imencode reports failure through its first return value.
            encoded_ok, buffer = cv2.imencode(".jpg", img)
            if not encoded_ok:
                logging.error(f"VLM: Failed to JPEG-encode image at {image_path}")
                return ""
            return base64.b64encode(buffer).decode("utf-8")
        except Exception as e:
            logging.error(f"VLM: Error converting image to base64: {e}")
            return ""

    def extract_structured_data(self, image_path: str, prompt: str) -> Dict[str, Any]:
        """Send the image and ``prompt`` to the model and parse its JSON reply.

        Returns:
            The decoded JSON object, or ``{}`` when the key is missing, the
            image cannot be encoded, the request fails or times out, the API
            answers with a non-200 status, or the reply is not a JSON object.
        """
        if not self.api_key:
            return {}

        logging.info(f"VLM: Starting extraction for {os.path.basename(image_path)} using {self.model}")
        base64_image = self.image_to_base64(image_path)
        if not base64_image:
            return {}

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a strict information extraction engine for business cards. Return only valid JSON. Do not include any other text."
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "response_format": {"type": "json_object"},
            "temperature": 0.1
        }

        try:
            resp = requests.post(self.url, headers=headers, json=payload, timeout=60)
            if resp.status_code != 200:
                logging.error(f"VLM API Error: {resp.status_code} - {resp.text}")
                return {}

            content = resp.json()["choices"][0]["message"]["content"]
            data = json.loads(content)
            # Callers are promised a dict; reject any other JSON value.
            if not isinstance(data, dict):
                logging.error(f"VLM: Expected a JSON object but received {type(data).__name__}.")
                return {}
            logging.info(f"VLM: Successfully extracted structured data from {os.path.basename(image_path)}")
            return data
        except requests.exceptions.Timeout:
            logging.error("VLM: Request timed out.")
            return {}
        except Exception as e:
            logging.error(f"VLM: Unexpected error: {e}")
            return {}

    def process(self, image_path: str) -> Dict[str, Any]:
        """Engine entry point: extract the business-card fields from the image."""
        prompt = """
        Extract structured text from this business card and return ONLY valid JSON.
        Fields: Name, Designation, Company, Contact, Address, Email, Link.
        Every value must be a JSON array. If not found, use [].
        """
        return self.extract_structured_data(image_path, prompt)
|
requirements.txt
CHANGED
|
@@ -1,16 +1,18 @@
|
|
| 1 |
-
Flask
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
Flask
|
| 2 |
+
python-dotenv
|
| 3 |
+
Pillow
|
| 4 |
+
opencv-python
|
| 5 |
+
numpy
|
| 6 |
+
paddle-bfloat
|
| 7 |
+
paddlepaddle>=2.6.0
|
| 8 |
+
paddleocr>=2.7.0
|
| 9 |
+
gradio_client
|
| 10 |
+
easyocr
|
| 11 |
+
langchain
|
| 12 |
+
langchain-community
|
| 13 |
+
torch
|
| 14 |
+
transformers
|
| 15 |
+
flair
|
| 16 |
+
tqdm
|
| 17 |
+
gunicorn
|
| 18 |
+
requests
|
static/uploads/IN_Standard-Visiting-Cards_Overview.png
ADDED
|
templates/index.html
CHANGED
|
@@ -1,284 +1,236 @@
|
|
| 1 |
-
<!DOCTYPE html>
|
| 2 |
-
<html lang="en">
|
| 3 |
-
|
| 4 |
-
<head>
|
| 5 |
-
<meta charset="UTF-8" />
|
| 6 |
-
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
-
<title>AI Data Extractor</title>
|
| 8 |
-
<link href="https://
|
| 9 |
-
<
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
background-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
padding:
|
| 65 |
-
border-radius:
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
border:
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
}
|
| 131 |
-
|
| 132 |
-
.
|
| 133 |
-
background
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
.
|
| 143 |
-
|
| 144 |
-
padding:
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
</
|
| 160 |
-
|
| 161 |
-
<
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
<
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
<
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
};
|
| 224 |
-
|
| 225 |
-
// Flash message auto-hide
|
| 226 |
-
setTimeout(
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
processingMessage.remove(); // Remove the processing message
|
| 238 |
-
}
|
| 239 |
-
|
| 240 |
-
// Re-enable tab buttons and remove disabled class
|
| 241 |
-
const buttons = document.querySelectorAll('.tab button');
|
| 242 |
-
buttons.forEach(button => {
|
| 243 |
-
button.removeAttribute('disabled');
|
| 244 |
-
button.classList.remove('disabled'); // Remove disabled class
|
| 245 |
-
});
|
| 246 |
-
}, 3000); // Adjust timing based on your upload duration
|
| 247 |
-
|
| 248 |
-
// Function to open links in the same tab
|
| 249 |
-
/**
 * Navigate the current tab to `url` and mark `element` as the active tab
 * button.
 *
 * @param {string} url - address assigned to window.location.href.
 * @param {Element} element - the button that was clicked.
 */
function openLink(url, element) {
    // Start navigation in the same tab.
    window.location.href = url;

    // Clear the "active" marker from every tab button, then set it on the
    // one that was clicked.
    for (const tabButton of document.querySelectorAll('.tab button')) {
        tabButton.classList.remove('active');
    }
    element.classList.add('active');
}
|
| 259 |
-
//Refreshing the cookie
|
| 260 |
-
/**
 * Write a cookie scoped to path "/".
 *
 * @param {string} name - cookie name.
 * @param {string} value - cookie value; a falsy value is stored as "".
 * @param {number} [days] - lifetime in days; when falsy no "expires"
 *     attribute is written.
 */
function setCookie(name, value, days) {
    const MS_PER_DAY = 24 * 60 * 60 * 1000;

    // Only add an "expires" attribute when a lifetime was requested.
    const expiryPart = days
        ? "; expires=" + new Date(Date.now() + (days * MS_PER_DAY)).toUTCString()
        : "";

    document.cookie = name + "=" + (value || "") + expiryPart + "; path=/";
}
|
| 269 |
-
|
| 270 |
-
/**
 * Remove the cookie called `name` on path "/" by rewriting it with an empty
 * value and Max-Age=0.
 *
 * @param {string} name - cookie name.
 */
function deleteCookie(name) {
    const expiredCookie = `${name}=; Max-Age=0; path=/;`;
    document.cookie = expiredCookie;
}
|
| 273 |
-
|
| 274 |
-
// Set the cookie (you can comment this out after testing)
|
| 275 |
-
setCookie('myCookie', 'myValue', 1); // Sets a cookie for demonstration
|
| 276 |
-
|
| 277 |
-
// Automatically delete the cookie when the page is loaded or refreshed
|
| 278 |
-
// On every page load (including refresh) drop the demo cookie again.
window.onload = function handlePageLoad() {
    const cookieName = 'myCookie';
    deleteCookie(cookieName);
};
|
| 281 |
-
</script>
|
| 282 |
-
</body>
|
| 283 |
-
|
| 284 |
-
</html>
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8" />
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>AI Data Extractor - Visiting Card</title>
|
| 8 |
+
<link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600&display=swap" rel="stylesheet">
|
| 9 |
+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet" />
|
| 10 |
+
<style>
|
| 11 |
+
:root {
|
| 12 |
+
--primary: #ee4410;
|
| 13 |
+
--secondary: #ff9f0a;
|
| 14 |
+
--bg-dark: #0a0a0c;
|
| 15 |
+
--card-bg: rgba(30, 30, 35, 0.7);
|
| 16 |
+
--text-glow: 0 0 10px rgba(238, 68, 16, 0.5);
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
body {
|
| 20 |
+
background-color: var(--bg-dark);
|
| 21 |
+
font-family: 'Outfit', sans-serif;
|
| 22 |
+
color: #f5f5f7;
|
| 23 |
+
overflow-x: hidden;
|
| 24 |
+
background: radial-gradient(circle at 50% 50%, #1a1a1f 0%, #0a0a0c 100%);
|
| 25 |
+
min-height: 100vh;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
.glass-card {
|
| 29 |
+
background: var(--card-bg);
|
| 30 |
+
backdrop-filter: blur(12px);
|
| 31 |
+
border: 1px solid rgba(255, 255, 255, 0.1);
|
| 32 |
+
border-radius: 20px;
|
| 33 |
+
box-shadow: 0 15px 35px rgba(0, 0, 0, 0.4);
|
| 34 |
+
padding: 2.5rem;
|
| 35 |
+
margin-top: 2rem;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
.premium-title {
|
| 39 |
+
background: linear-gradient(135deg, #fff 0%, #aaa 100%);
|
| 40 |
+
-webkit-background-clip: text;
|
| 41 |
+
background-clip: text;
|
| 42 |
+
-webkit-text-fill-color: transparent;
|
| 43 |
+
font-weight: 600;
|
| 44 |
+
letter-spacing: -1px;
|
| 45 |
+
text-shadow: var(--text-glow);
|
| 46 |
+
margin-bottom: 2rem;
|
| 47 |
+
text-align: center;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
.top-bar {
|
| 51 |
+
background: rgba(20, 20, 25, 0.8);
|
| 52 |
+
backdrop-filter: blur(10px);
|
| 53 |
+
padding: 1rem 2rem;
|
| 54 |
+
border-bottom: 1px solid rgba(255, 255, 255, 0.05);
|
| 55 |
+
position: sticky;
|
| 56 |
+
top: 0;
|
| 57 |
+
z-index: 1000;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
.tab-btn {
|
| 61 |
+
background: transparent;
|
| 62 |
+
border: 1px solid rgba(255, 255, 255, 0.1);
|
| 63 |
+
color: #8e8e93;
|
| 64 |
+
padding: 0.6rem 1.2rem;
|
| 65 |
+
border-radius: 30px;
|
| 66 |
+
margin-right: 10px;
|
| 67 |
+
transition: all 0.3s ease;
|
| 68 |
+
text-decoration: none;
|
| 69 |
+
font-size: 0.9rem;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.tab-btn:hover, .tab-btn.active {
|
| 73 |
+
border-color: var(--primary);
|
| 74 |
+
color: #fff;
|
| 75 |
+
background: rgba(238, 68, 16, 0.1);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
.tab-btn.active {
|
| 79 |
+
box-shadow: 0 0 15px rgba(238, 68, 16, 0.2);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
.upload-area {
|
| 83 |
+
border: 2px dashed rgba(255, 255, 255, 0.1);
|
| 84 |
+
border-radius: 15px;
|
| 85 |
+
padding: 3rem;
|
| 86 |
+
text-align: center;
|
| 87 |
+
transition: all 0.3s ease;
|
| 88 |
+
cursor: pointer;
|
| 89 |
+
margin-bottom: 2rem;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
.upload-area:hover {
|
| 93 |
+
border-color: var(--primary);
|
| 94 |
+
background: rgba(238, 68, 16, 0.05);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
.btn-premium {
|
| 98 |
+
background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%);
|
| 99 |
+
border: none;
|
| 100 |
+
color: white;
|
| 101 |
+
padding: 0.8rem 2rem;
|
| 102 |
+
border-radius: 30px;
|
| 103 |
+
font-weight: 600;
|
| 104 |
+
box-shadow: 0 5px 15px rgba(238, 68, 16, 0.3);
|
| 105 |
+
transition: all 0.3s ease;
|
| 106 |
+
width: 100%;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
.btn-premium:hover:not(:disabled) {
|
| 110 |
+
transform: scale(1.02);
|
| 111 |
+
box-shadow: 0 8px 25px rgba(238, 68, 16, 0.5);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
.btn-premium:disabled {
|
| 115 |
+
opacity: 0.6;
|
| 116 |
+
cursor: not-allowed;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
.loader {
|
| 120 |
+
display: none;
|
| 121 |
+
width: 40px;
|
| 122 |
+
height: 40px;
|
| 123 |
+
border: 3px solid rgba(255,255,255,0.1);
|
| 124 |
+
border-top: 3px solid var(--primary);
|
| 125 |
+
border-radius: 50%;
|
| 126 |
+
animation: spin 1s linear infinite;
|
| 127 |
+
margin: 2rem auto;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
|
| 131 |
+
|
| 132 |
+
.alert-premium {
|
| 133 |
+
background: rgba(40, 40, 45, 0.9);
|
| 134 |
+
border: 1px solid var(--primary);
|
| 135 |
+
color: #fff;
|
| 136 |
+
border-radius: 12px;
|
| 137 |
+
padding: 1rem;
|
| 138 |
+
margin-top: 1.5rem;
|
| 139 |
+
text-align: center;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
.file-list {
|
| 143 |
+
margin-top: 1.5rem;
|
| 144 |
+
padding: 1rem;
|
| 145 |
+
background: rgba(0, 0, 0, 0.2);
|
| 146 |
+
border-radius: 10px;
|
| 147 |
+
font-size: 0.9rem;
|
| 148 |
+
}
|
| 149 |
+
</style>
|
| 150 |
+
</head>
|
| 151 |
+
|
| 152 |
+
<body>
|
| 153 |
+
<div class="top-bar d-flex justify-content-between align-items-center">
|
| 154 |
+
<h4 class="mb-0 text-white" style="font-weight: 600; letter-spacing: -0.5px;">AI EXtractor</h4>
|
| 155 |
+
<div class="d-none d-md-flex">
|
| 156 |
+
<a href="#" class="tab-btn active">Visiting Card</a>
|
| 157 |
+
<a href="https://webashalarforml-resumeextractor2.hf.space/" class="tab-btn">Resume Detail</a>
|
| 158 |
+
</div>
|
| 159 |
+
</div>
|
| 160 |
+
|
| 161 |
+
<div class="container py-5">
|
| 162 |
+
<div class="row justify-content-center">
|
| 163 |
+
<div class="col-lg-6">
|
| 164 |
+
<div class="glass-card">
|
| 165 |
+
<h1 class="premium-title">Card Scanner <small style="display: block; font-size: 0.4em; letter-spacing: 2px; color: var(--primary); margin-top: 5px;">POWERED BY GROQ VLM</small></h1>
|
| 166 |
+
|
| 167 |
+
<form id="uploadForm" action="{{ url_for('upload_file') }}" method="POST" enctype="multipart/form-data">
|
| 168 |
+
<div class="upload-area" onclick="document.getElementById('fileInput').click()">
|
| 169 |
+
<svg xmlns="http://www.w3.org/2000/svg" width="48" height="48" fill="currentColor" class="bi bi-cloud-upload text-muted mb-3" viewBox="0 0 16 16">
|
| 170 |
+
<path fill-rule="evenodd" d="M4.406 3.342A5.53 5.53 0 0 1 8 2c2.69 0 4.923 2 5.166 4.579C14.758 6.804 16 8.137 16 9.773 16 11.569 14.502 13 12.687 13H3.781C1.708 13 0 11.366 0 9.318c0-1.763 1.266-3.223 2.942-3.593.143-.863.698-1.723 1.464-2.383zm.653.757c-.757.653-1.153 1.44-1.153 2.056v.448l-.445.049C2.064 6.805 1 7.952 1 9.318 1 10.785 2.23 12 3.781 12h8.906C13.98 12 15 10.988 15 9.773c0-1.216-1.02-2.228-2.313-2.228h-.5v-.5C12.188 4.825 10.328 3 8 3a4.53 4.53 0 0 0-2.941 1.1z"/>
|
| 171 |
+
<path fill-rule="evenodd" d="M7.646 5.146a.5.5 0 0 1 .708 0l2 2a.5.5 0 0 1-.708.708L8.5 6.707V10.5a.5.5 0 0 1-1 0V6.707L6.354 7.854a.5.5 0 1 1-.708-.708l2-2z"/>
|
| 172 |
+
</svg>
|
| 173 |
+
<p class="mb-0 text-muted">Click or Drag & Drop Business Cards</p>
|
| 174 |
+
<input type="file" name="files" id="fileInput" multiple style="display: none;" required onchange="updateFileList(this)" />
|
| 175 |
+
</div>
|
| 176 |
+
|
| 177 |
+
<div id="fileList" class="file-list" style="display: none;"></div>
|
| 178 |
+
|
| 179 |
+
<button type="submit" id="submitBtn" class="btn-premium mt-3">Start Extraction</button>
|
| 180 |
+
</form>
|
| 181 |
+
|
| 182 |
+
<div class="loader" id="loader"></div>
|
| 183 |
+
<p id="loadingMsg" class="text-center text-muted small mt-2" style="display: none;">Analyzing images with AI engine...</p>
|
| 184 |
+
|
| 185 |
+
{% if session.get('uploaded_files') %}
|
| 186 |
+
<div class="mt-4 pt-3 border-top border-secondary">
|
| 187 |
+
<div class="d-flex justify-content-between align-items-center">
|
| 188 |
+
<span class="small text-muted">Ready to process: {{ session.get('uploaded_files')|length }} files</span>
|
| 189 |
+
<a href="{{ url_for('reset_upload') }}" class="text-danger small text-decoration-none">Clear All</a>
|
| 190 |
+
</div>
|
| 191 |
+
</div>
|
| 192 |
+
{% endif %}
|
| 193 |
+
|
| 194 |
+
{% with messages = get_flashed_messages() %}
|
| 195 |
+
{% if messages %}
|
| 196 |
+
<div class="alert-premium" id="flashMessage">
|
| 197 |
+
{{ messages[0] }}
|
| 198 |
+
</div>
|
| 199 |
+
{% endif %}
|
| 200 |
+
{% endwith %}
|
| 201 |
+
</div>
|
| 202 |
+
</div>
|
| 203 |
+
</div>
|
| 204 |
+
</div>
|
| 205 |
+
|
| 206 |
+
<script>
|
| 207 |
+
/**
 * Show the names of the files currently chosen in `input` inside the
 * #fileList panel, or hide the panel when nothing is selected.
 *
 * File names are inserted as text nodes instead of being concatenated into
 * innerHTML, so a name that contains markup (for example
 * "<img src=x onerror=...>.png") is displayed literally rather than parsed
 * as HTML.
 *
 * @param {HTMLInputElement} input - the file input whose selection changed.
 */
function updateFileList(input) {
    const list = document.getElementById('fileList');
    const files = Array.from(input.files || []);

    // Drop whatever was rendered for the previous selection.
    list.textContent = '';

    if (files.length === 0) {
        list.style.display = 'none';
        return;
    }

    const heading = document.createElement('strong');
    heading.textContent = 'Selected:';
    list.appendChild(heading);

    // One line per file: a <br> followed by the name as plain text.
    files.forEach(file => {
        list.appendChild(document.createElement('br'));
        list.appendChild(document.createTextNode(file.name));
    });

    list.style.display = 'block';
}
|
| 217 |
+
|
| 218 |
+
// When the upload form is submitted, show the spinner and its caption and
// lock the submit button so the request cannot be sent twice.
document.getElementById('uploadForm').onsubmit = function() {
    ['loader', 'loadingMsg'].forEach(id => {
        document.getElementById(id).style.display = 'block';
    });

    const submitButton = document.getElementById('submitBtn');
    submitButton.disabled = true;
    submitButton.innerText = 'Processing...';
};
|
| 224 |
+
|
| 225 |
+
// Flash message auto-hide
|
| 226 |
+
setTimeout(function hideFlashMessage() {
    const banner = document.getElementById('flashMessage');
    if (!banner) {
        return; // no flash message was rendered for this request
    }

    // Fade out over one second, then take the element out of the DOM.
    banner.style.transition = 'opacity 1s ease';
    banner.style.opacity = '0';
    setTimeout(function removeBanner() {
        banner.remove();
    }, 1000);
}, 4000);
|
| 234 |
+
</script>
|
| 235 |
+
</body>
|
| 236 |
+
</html>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
templates/result.html
CHANGED
|
@@ -1,248 +1,326 @@
|
|
| 1 |
-
<!DOCTYPE html>
|
| 2 |
-
<html lang="en">
|
| 3 |
-
|
| 4 |
-
<head>
|
| 5 |
-
<meta charset="UTF-8" />
|
| 6 |
-
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
-
<title>Processed Results</title>
|
| 8 |
-
<link href="https://
|
| 9 |
-
<
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
padding
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
border-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
border:
|
| 77 |
-
padding:
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
}
|
| 86 |
-
|
| 87 |
-
.
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
</
|
| 247 |
-
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8" />
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>Processed Results</title>
|
| 8 |
+
<link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600&display=swap" rel="stylesheet">
|
| 9 |
+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet" />
|
| 10 |
+
<style>
|
| 11 |
+
:root {
|
| 12 |
+
--primary: #ee4410;
|
| 13 |
+
--secondary: #ff9f0a;
|
| 14 |
+
--bg-dark: #0a0a0c;
|
| 15 |
+
--card-bg: rgba(30, 30, 35, 0.7);
|
| 16 |
+
--text-glow: 0 0 10px rgba(238, 68, 16, 0.5);
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
body {
|
| 20 |
+
background-color: var(--bg-dark);
|
| 21 |
+
font-family: 'Outfit', sans-serif;
|
| 22 |
+
color: #f5f5f7;
|
| 23 |
+
overflow-x: hidden;
|
| 24 |
+
background: radial-gradient(circle at 50% 50%, #1a1a1f 0%, #0a0a0c 100%);
|
| 25 |
+
min-height: 100vh;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
.glass-card {
|
| 29 |
+
background: var(--card-bg);
|
| 30 |
+
backdrop-filter: blur(12px);
|
| 31 |
+
border: 1px solid rgba(255, 255, 255, 0.1);
|
| 32 |
+
border-radius: 20px;
|
| 33 |
+
box-shadow: 0 15px 35px rgba(0, 0, 0, 0.4);
|
| 34 |
+
transition: all 0.4s cubic-bezier(0.175, 0.885, 0.32, 1.275);
|
| 35 |
+
padding: 2rem;
|
| 36 |
+
margin-bottom: 2rem;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
.glass-card:hover {
|
| 40 |
+
transform: translateY(-5px);
|
| 41 |
+
border-color: rgba(238, 68, 16, 0.3);
|
| 42 |
+
box-shadow: 0 20px 45px rgba(238, 68, 16, 0.1);
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
.premium-title {
|
| 46 |
+
background: linear-gradient(135deg, #fff 0%, #aaa 100%);
|
| 47 |
+
-webkit-background-clip: text;
|
| 48 |
+
background-clip: text;
|
| 49 |
+
-webkit-text-fill-color: transparent;
|
| 50 |
+
font-weight: 600;
|
| 51 |
+
letter-spacing: -1px;
|
| 52 |
+
text-shadow: var(--text-glow);
|
| 53 |
+
margin-bottom: 1.5rem;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
.section-title {
|
| 57 |
+
color: var(--primary);
|
| 58 |
+
font-size: 1.1rem;
|
| 59 |
+
text-transform: uppercase;
|
| 60 |
+
letter-spacing: 2px;
|
| 61 |
+
margin-bottom: 1.5rem;
|
| 62 |
+
display: flex;
|
| 63 |
+
align-items: center;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
.section-title::after {
|
| 67 |
+
content: '';
|
| 68 |
+
flex: 1;
|
| 69 |
+
height: 1px;
|
| 70 |
+
background: linear-gradient(90deg, var(--primary), transparent);
|
| 71 |
+
margin-left: 1rem;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
.data-item {
|
| 75 |
+
margin-bottom: 1.5rem;
|
| 76 |
+
border-left: 3px solid var(--primary);
|
| 77 |
+
padding-left: 1rem;
|
| 78 |
+
animation: fadeIn 0.6s ease-out forwards;
|
| 79 |
+
opacity: 0;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
@keyframes fadeIn {
|
| 83 |
+
from { opacity: 0; transform: translateX(-10px); }
|
| 84 |
+
to { opacity: 1; transform: translateX(0); }
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
.data-label {
|
| 88 |
+
color: #8e8e93;
|
| 89 |
+
font-size: 0.85rem;
|
| 90 |
+
margin-bottom: 0.2rem;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
.data-value {
|
| 94 |
+
font-size: 1.1rem;
|
| 95 |
+
color: #fff;
|
| 96 |
+
list-style: none;
|
| 97 |
+
padding: 0;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
.data-value li {
|
| 101 |
+
margin-bottom: 0.4rem;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
.result-img-container {
|
| 105 |
+
border-radius: 15px;
|
| 106 |
+
overflow: hidden;
|
| 107 |
+
border: 1px solid rgba(255, 255, 255, 0.1);
|
| 108 |
+
position: relative;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
.result-img-container img {
|
| 112 |
+
width: 100%;
|
| 113 |
+
height: auto;
|
| 114 |
+
display: block;
|
| 115 |
+
transition: transform 0.5s ease;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
.result-img-container:hover img {
|
| 119 |
+
transform: scale(1.05);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
.btn-premium {
|
| 123 |
+
background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%);
|
| 124 |
+
border: none;
|
| 125 |
+
color: white;
|
| 126 |
+
padding: 0.8rem 2rem;
|
| 127 |
+
border-radius: 30px;
|
| 128 |
+
font-weight: 600;
|
| 129 |
+
box-shadow: 0 5px 15px rgba(238, 68, 16, 0.3);
|
| 130 |
+
text-decoration: none;
|
| 131 |
+
transition: all 0.3s ease;
|
| 132 |
+
display: inline-block;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.btn-premium:hover {
|
| 136 |
+
transform: scale(1.05);
|
| 137 |
+
box-shadow: 0 8px 25px rgba(238, 68, 16, 0.5);
|
| 138 |
+
color: #fff;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
.alert-premium {
|
| 142 |
+
background: rgba(40, 40, 45, 0.9);
|
| 143 |
+
border: 1px solid var(--primary);
|
| 144 |
+
color: #fff;
|
| 145 |
+
border-radius: 12px;
|
| 146 |
+
padding: 1rem;
|
| 147 |
+
animation: slideDown 0.5s cubic-bezier(0.19, 1, 0.22, 1);
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
@keyframes slideDown {
|
| 151 |
+
from { transform: translateY(-50px); opacity: 0; }
|
| 152 |
+
to { transform: translateY(0); opacity: 1; }
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
.debug-panel {
|
| 156 |
+
margin-top: 4rem;
|
| 157 |
+
padding: 2rem;
|
| 158 |
+
background: rgba(0, 0, 0, 0.3);
|
| 159 |
+
border-top: 1px solid rgba(255, 255, 255, 0.05);
|
| 160 |
+
font-family: monospace;
|
| 161 |
+
font-size: 0.8rem;
|
| 162 |
+
color: #666;
|
| 163 |
+
}
|
| 164 |
+
</style>
|
| 165 |
+
</style>
|
| 166 |
+
</head>
|
| 167 |
+
|
| 168 |
+
<body>
|
| 169 |
+
<div class="container py-5">
|
| 170 |
+
{% with messages = get_flashed_messages() %}
|
| 171 |
+
{% if messages %}
|
| 172 |
+
<div class="alert-premium mb-4" id="flash-message">
|
| 173 |
+
{{ messages[0] }}
|
| 174 |
+
</div>
|
| 175 |
+
{% endif %}
|
| 176 |
+
{% endwith %}
|
| 177 |
+
|
| 178 |
+
<div class="d-flex align-items-center justify-content-between mb-5">
|
| 179 |
+
<h1 class="premium-title mb-0">Extraction Analysis <small style="font-size: 0.4em; color: var(--primary); letter-spacing: 2px; font-weight: 400; display: block; text-align: left;">v2.1 GOLD</small></h1>
|
| 180 |
+
<a href="{{ url_for('reset_upload') }}" class="btn-premium" style="width: auto;">Process New Image</a>
|
| 181 |
+
</div>
|
| 182 |
+
|
| 183 |
+
{% if data %}
|
| 184 |
+
<div class="row g-4">
|
| 185 |
+
<!-- Source Image Column -->
|
| 186 |
+
<div class="col-lg-5">
|
| 187 |
+
<div class="glass-card">
|
| 188 |
+
<h3 class="section-title">Source Image</h3>
|
| 189 |
+
{% if Img %}
|
| 190 |
+
{% for filename, result_path in Img.items() %}
|
| 191 |
+
<div class="result-img-container mb-3">
|
| 192 |
+
<img src="{{ result_path }}" alt="Analyzed Document" />
|
| 193 |
+
</div>
|
| 194 |
+
<p class="text-muted small">File: {{ filename | basename }}</p>
|
| 195 |
+
{% endfor %}
|
| 196 |
+
{% else %}
|
| 197 |
+
<div class="p-4 text-center text-muted">No image path available</div>
|
| 198 |
+
{% endif %}
|
| 199 |
+
</div>
|
| 200 |
+
</div>
|
| 201 |
+
|
| 202 |
+
<!-- Extracted Details Column -->
|
| 203 |
+
<div class="col-lg-7">
|
| 204 |
+
<div class="glass-card">
|
| 205 |
+
<h3 class="section-title">Verified Details</h3>
|
| 206 |
+
|
| 207 |
+
<div class="row">
|
| 208 |
+
<!-- Left Data Sub-col -->
|
| 209 |
+
<div class="col-md-6">
|
| 210 |
+
{% if data.name %}
|
| 211 |
+
<div class="data-item" style="animation-delay: 0.1s">
|
| 212 |
+
<div class="data-label">Full Name</div>
|
| 213 |
+
<ul class="data-value">
|
| 214 |
+
{% for val in data.name %}<li>{{ val }}</li>{% endfor %}
|
| 215 |
+
</ul>
|
| 216 |
+
</div>
|
| 217 |
+
{% endif %}
|
| 218 |
+
|
| 219 |
+
{% if data.Designation %}
|
| 220 |
+
<div class="data-item" style="animation-delay: 0.2s">
|
| 221 |
+
<div class="data-label">Designation</div>
|
| 222 |
+
<ul class="data-value">
|
| 223 |
+
{% for val in data.Designation %}<li>{{ val }}</li>{% endfor %}
|
| 224 |
+
</ul>
|
| 225 |
+
</div>
|
| 226 |
+
{% endif %}
|
| 227 |
+
|
| 228 |
+
{% if data.Company %}
|
| 229 |
+
<div class="data-item" style="animation-delay: 0.3s">
|
| 230 |
+
<div class="data-label">Organization</div>
|
| 231 |
+
<ul class="data-value">
|
| 232 |
+
{% for val in data.Company %}<li>{{ val }}</li>{% endfor %}
|
| 233 |
+
</ul>
|
| 234 |
+
</div>
|
| 235 |
+
{% endif %}
|
| 236 |
+
</div>
|
| 237 |
+
|
| 238 |
+
<!-- Right Data Sub-col -->
|
| 239 |
+
<div class="col-md-6">
|
| 240 |
+
{% if data.contact_number %}
|
| 241 |
+
<div class="data-item" style="animation-delay: 0.4s">
|
| 242 |
+
<div class="data-label">Phone Numbers</div>
|
| 243 |
+
<ul class="data-value">
|
| 244 |
+
{% for val in data.contact_number %}<li>{{ val }}</li>{% endfor %}
|
| 245 |
+
</ul>
|
| 246 |
+
</div>
|
| 247 |
+
{% endif %}
|
| 248 |
+
|
| 249 |
+
{% if data.email %}
|
| 250 |
+
<div class="data-item" style="animation-delay: 0.5s">
|
| 251 |
+
<div class="data-label">Email Addresses</div>
|
| 252 |
+
<ul class="data-value">
|
| 253 |
+
{% for val in data.email %}<li>{{ val }}</li>{% endfor %}
|
| 254 |
+
</ul>
|
| 255 |
+
</div>
|
| 256 |
+
{% endif %}
|
| 257 |
+
|
| 258 |
+
{% if data.Location %}
|
| 259 |
+
<div class="data-item" style="animation-delay: 0.6s">
|
| 260 |
+
<div class="data-label">Address</div>
|
| 261 |
+
<ul class="data-value">
|
| 262 |
+
{% for val in data.Location %}<li>{{ val }}</li>{% endfor %}
|
| 263 |
+
</ul>
|
| 264 |
+
</div>
|
| 265 |
+
{% endif %}
|
| 266 |
+
|
| 267 |
+
{% if data.Link %}
|
| 268 |
+
<div class="data-item" style="animation-delay: 0.7s">
|
| 269 |
+
<div class="data-label">Social/Web Links</div>
|
| 270 |
+
<ul class="data-value">
|
| 271 |
+
{% for val in data.Link %}<li>{{ val }}</li>{% endfor %}
|
| 272 |
+
</ul>
|
| 273 |
+
</div>
|
| 274 |
+
{% endif %}
|
| 275 |
+
</div>
|
| 276 |
+
</div>
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
{% if data.status_message %}
|
| 280 |
+
<div class="mt-4 pt-3 border-top border-secondary text-end">
|
| 281 |
+
<span class="badge rounded-pill bg-dark text-muted" style="font-size: 0.65rem; border: 1px solid rgba(255,255,255,0.05)">{{ data.status_message }}</span>
|
| 282 |
+
</div>
|
| 283 |
+
{% endif %}
|
| 284 |
+
</div>
|
| 285 |
+
</div>
|
| 286 |
+
</div>
|
| 287 |
+
|
| 288 |
+
<!-- Debug / Raw Output Panel -->
|
| 289 |
+
<div class="debug-panel rounded">
|
| 290 |
+
<h5 class="mb-3" style="color: #444; font-size: 0.9rem">SYSTEM_LOG :: RAW_EXTRACTION_BUFFER</h5>
|
| 291 |
+
{% for filename, raw in data.extracted_text.items() %}
|
| 292 |
+
<div class="mb-4">
|
| 293 |
+
<div class="mb-1 text-primary">>> {{ filename | basename }}</div>
|
| 294 |
+
<div class="p-3 bg-black rounded" style="color: #0f0; opacity: 0.8; font-size: 0.85rem">
|
| 295 |
+
{{ raw }}
|
| 296 |
+
</div>
|
| 297 |
+
</div>
|
| 298 |
+
{% endfor %}
|
| 299 |
+
</div>
|
| 300 |
+
|
| 301 |
+
{% else %}
|
| 302 |
+
<div class="text-center glass-card py-5">
|
| 303 |
+
<h2 class="premium-title">Waiting for Data...</h2>
|
| 304 |
+
<p class="text-muted">No analysis results found in session.</p>
|
| 305 |
+
<a href="{{ url_for('index') }}" class="btn-premium mt-3">Back to Upload</a>
|
| 306 |
+
</div>
|
| 307 |
+
{% endif %}
|
| 308 |
+
</div>
|
| 309 |
+
|
| 310 |
+
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.bundle.min.js"></script>
|
| 311 |
+
<script>
|
| 312 |
+
// Auto-remove flash messages
|
| 313 |
+
setTimeout(() => {
    const banner = document.getElementById("flash-message");
    if (!banner) {
        return; // nothing was flashed on this page
    }

    // Fade and slide the banner upwards, then remove it once the
    // 0.8s transition has had time to finish.
    banner.style.transition = "all 0.8s ease";
    banner.style.opacity = 0;
    banner.style.transform = "translateY(-20px)";
    setTimeout(() => {
        banner.remove();
    }, 800);
}, 4000);
|
| 322 |
+
</script>
|
| 323 |
+
</body>
|
| 324 |
+
</html>
|
| 325 |
+
|
| 326 |
+
</html>
|
utility/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
utility/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (27 kB). View file
|
|
|
utility/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
utility/utils.py
CHANGED
|
@@ -1,700 +1,132 @@
|
|
| 1 |
-
# libraries
|
| 2 |
import os
|
| 3 |
-
from huggingface_hub import InferenceClient
|
| 4 |
-
from dotenv import load_dotenv
|
| 5 |
import json
|
| 6 |
import re
|
| 7 |
-
#import easyocr
|
| 8 |
-
from PIL import Image, ImageEnhance, ImageDraw
|
| 9 |
-
import cv2
|
| 10 |
-
import numpy as np
|
| 11 |
-
from paddleocr import PaddleOCR
|
| 12 |
import logging
|
| 13 |
-
from
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
return upscaled_image
|
| 65 |
-
|
| 66 |
-
# Function to denoise the image (reduce noise)
|
| 67 |
-
def reduce_noise(image):
    """Return a denoised copy of a colour image.

    Applies OpenCV's non-local-means colour denoiser with fixed settings.
    """
    # Positional arguments after ``dst`` follow the OpenCV signature:
    # h, hColor, templateWindowSize, searchWindowSize.
    luminance_strength = 10
    colour_strength = 10
    template_window = 7
    search_window = 21
    return cv2.fastNlMeansDenoisingColored(
        image,
        None,
        luminance_strength,
        colour_strength,
        template_window,
        search_window,
    )
|
| 69 |
-
|
| 70 |
-
# Function to sharpen the image
|
| 71 |
-
def sharpen_image(image):
    """Return ``image`` convolved with a fixed 3x3 sharpening kernel."""
    # Centre weight 5 with -1 on the four direct neighbours.
    sharpen_kernel = np.array([
        [0, -1, 0],
        [-1, 5, -1],
        [0, -1, 0],
    ])
    # ddepth=-1 keeps the output depth the same as the input.
    return cv2.filter2D(image, -1, sharpen_kernel)
|
| 77 |
-
|
| 78 |
-
# Function to increase contrast and enhance details without changing color
|
| 79 |
-
def enhance_image(image):
|
| 80 |
-
pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
| 81 |
-
enhancer = ImageEnhance.Contrast(pil_img)
|
| 82 |
-
enhanced_image = enhancer.enhance(1.5)
|
| 83 |
-
enhanced_image_bgr = cv2.cvtColor(np.array(enhanced_image), cv2.COLOR_RGB2BGR)
|
| 84 |
-
return enhanced_image_bgr
|
| 85 |
-
|
| 86 |
-
# Complete function to process image
|
| 87 |
-
def process_image(image_path, scale=2):
|
| 88 |
-
# Load the image
|
| 89 |
-
image = load_image(image_path)
|
| 90 |
-
|
| 91 |
-
# Upscale the image
|
| 92 |
-
upscaled_image = upscale_image(image, scale)
|
| 93 |
-
|
| 94 |
-
# Reduce noise
|
| 95 |
-
denoised_image = reduce_noise(upscaled_image)
|
| 96 |
-
|
| 97 |
-
# Sharpen the image
|
| 98 |
-
sharpened_image = sharpen_image(denoised_image)
|
| 99 |
-
|
| 100 |
-
# Enhance the image contrast and details without changing color
|
| 101 |
-
final_image = enhance_image(sharpened_image)
|
| 102 |
-
|
| 103 |
-
return final_image
|
| 104 |
-
|
| 105 |
-
# Function for OCR with PaddleOCR, returning both text and bounding boxes
|
| 106 |
-
def ocr_with_paddle(img):
|
| 107 |
-
final_text = ''
|
| 108 |
-
boxes = []
|
| 109 |
-
|
| 110 |
-
# Initialize PaddleOCR
|
| 111 |
-
# In /app/utility/utils.py
|
| 112 |
-
ocr = PaddleOCR(
|
| 113 |
-
use_angle_cls=True,
|
| 114 |
-
lang='en',
|
| 115 |
-
enable_mkldnn=False, # <--- Add this line to disable the failing optimization
|
| 116 |
-
use_gpu=False # Ensure this is False if you are on a CPU-only container
|
| 117 |
-
)
|
| 118 |
-
# ocr = PaddleOCR(
|
| 119 |
-
# lang='en',
|
| 120 |
-
# use_angle_cls=True,
|
| 121 |
-
# det_model_dir=os.path.join(os.environ['PADDLEOCR_HOME'], 'whl/det'),
|
| 122 |
-
# rec_model_dir=os.path.join(os.environ['PADDLEOCR_HOME'], 'whl/rec/en/en_PP-OCRv4_rec_infer'),
|
| 123 |
-
# cls_model_dir=os.path.join(os.environ['PADDLEOCR_HOME'], 'whl/cls/ch_ppocr_mobile_v2.0_cls_infer')
|
| 124 |
-
# )
|
| 125 |
-
# ocr = PaddleOCR(
|
| 126 |
-
# use_angle_cls=True,
|
| 127 |
-
# lang='en',
|
| 128 |
-
# det_model_dir='/app/paddleocr_models/whl/det/ch_ppocr_mobile_v2.0_det_infer',
|
| 129 |
-
# rec_model_dir='/app/paddleocr_models/whl/rec/ch_ppocr_mobile_v2.0_rec_infer',
|
| 130 |
-
# cls_model_dir='/app/paddleocr_models/whl/cls/ch_ppocr_mobile_v2.0_cls_infer'
|
| 131 |
-
# )
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
# Check if img is a file path or an image array
|
| 135 |
-
if isinstance(img, str):
|
| 136 |
-
img = cv2.imread(img)
|
| 137 |
-
|
| 138 |
-
# Perform OCR
|
| 139 |
-
result = ocr.ocr(img)
|
| 140 |
-
|
| 141 |
-
# Iterate through the OCR result
|
| 142 |
-
for line in result[0]:
|
| 143 |
-
# Check how many values are returned (2 or 3) and unpack accordingly
|
| 144 |
-
if len(line) == 3:
|
| 145 |
-
box, (text, confidence), _ = line # When 3 values are returned
|
| 146 |
-
elif len(line) == 2:
|
| 147 |
-
box, (text, confidence) = line # When only 2 values are returned
|
| 148 |
-
|
| 149 |
-
# Store the recognized text and bounding boxes
|
| 150 |
-
final_text += ' ' + text # Extract the text from the tuple
|
| 151 |
-
boxes.append(box)
|
| 152 |
-
|
| 153 |
-
# Draw the bounding box
|
| 154 |
-
points = [(int(point[0]), int(point[1])) for point in box]
|
| 155 |
-
cv2.polylines(img, [np.array(points)], isClosed=True, color=(0, 255, 0), thickness=2)
|
| 156 |
-
|
| 157 |
-
# Store the image with bounding boxes in a variable
|
| 158 |
-
img_with_boxes = img
|
| 159 |
-
|
| 160 |
-
return final_text, img_with_boxes
|
| 161 |
-
|
| 162 |
-
def extract_text_from_images(image_paths):
|
| 163 |
-
all_extracted_texts = {}
|
| 164 |
-
all_extracted_imgs = {}
|
| 165 |
-
for image_path in image_paths:
|
| 166 |
-
try:
|
| 167 |
-
# Enhance the image before OCR
|
| 168 |
-
enhanced_image = process_image(image_path, scale=2)
|
| 169 |
-
|
| 170 |
-
# Perform OCR on the enhanced image and get boxes
|
| 171 |
-
result, img_with_boxes = ocr_with_paddle(enhanced_image)
|
| 172 |
-
|
| 173 |
-
# Draw bounding boxes on the processed image
|
| 174 |
-
img_result = Image.fromarray(enhanced_image)
|
| 175 |
-
#img_with_boxes = draw_boxes(img_result, boxes)
|
| 176 |
-
|
| 177 |
-
# genrating unique id to save the images
|
| 178 |
-
# Get the current date and time
|
| 179 |
-
current_time = datetime.now()
|
| 180 |
-
|
| 181 |
-
# Format it as a string to create a unique ID
|
| 182 |
-
unique_id = current_time.strftime("%Y%m%d%H%M%S%f")
|
| 183 |
-
|
| 184 |
-
#print(unique_id)
|
| 185 |
-
|
| 186 |
-
# Save the image with boxes
|
| 187 |
-
result_image_path = os.path.join(RESULT_FOLDER, f'result_{unique_id}_{os.path.basename(image_path)}')
|
| 188 |
-
#img_with_boxes.save(result_image_path)
|
| 189 |
-
cv2.imwrite(result_image_path, img_with_boxes)
|
| 190 |
-
|
| 191 |
-
# Store the text and image result paths
|
| 192 |
-
all_extracted_texts[image_path] = result
|
| 193 |
-
all_extracted_imgs[image_path] = result_image_path
|
| 194 |
-
except ValueError as ve:
|
| 195 |
-
print(f"Error processing image {image_path}: {ve}")
|
| 196 |
-
continue # Continue to the next image if there's an error
|
| 197 |
-
|
| 198 |
-
# Convert to JSON-compatible structure
|
| 199 |
-
all_extracted_imgs_json = {str(k): str(v) for k, v in all_extracted_imgs.items()}
|
| 200 |
-
return all_extracted_texts, all_extracted_imgs_json
|
| 201 |
-
|
| 202 |
-
# Function to call the Gemma model and process the output as Json
|
| 203 |
-
# def Data_Extractor(data, client=client):
|
| 204 |
-
# text = f'''Act as a Text extractor for the following text given in text: {data}
|
| 205 |
-
# extract text in the following output JSON string:
|
| 206 |
-
# {{
|
| 207 |
-
# "Name": ["Identify and Extract All the person's name from the text."],
|
| 208 |
-
# "Designation": ["Extract All the designation or job title mentioned in the text."],
|
| 209 |
-
# "Company": ["Extract All the company or organization name if mentioned."],
|
| 210 |
-
# "Contact": ["Extract All phone number, including country codes if present."],
|
| 211 |
-
# "Address": ["Extract All the full postal address or location mentioned in the text."],
|
| 212 |
-
# "Email": ["Identify and Extract All valid email addresses mentioned in the text else 'Not found'."],
|
| 213 |
-
# "Link": ["Identify and Extract any website URLs or social media links present in the text."]
|
| 214 |
-
# }}
|
| 215 |
-
# Output:
|
| 216 |
-
# '''
|
| 217 |
-
|
| 218 |
-
# # Call the API for inference
|
| 219 |
-
# response = client.text_generation(text, max_new_tokens=1000)#, temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.2)
|
| 220 |
-
|
| 221 |
-
# print("parse in text ---:",response)
|
| 222 |
-
|
| 223 |
-
# # Convert the response text to JSON
|
| 224 |
-
# try:
|
| 225 |
-
# json_data = json.loads(response)
|
| 226 |
-
# print("Json_data-------------->",json_data)
|
| 227 |
-
# return json_data
|
| 228 |
-
# except json.JSONDecodeError as e:
|
| 229 |
-
# return {"error": f"Error decoding JSON: {e}"}
|
| 230 |
-
def Data_Extractor(data):
|
| 231 |
-
url = "https://api.groq.com/openai/v1/chat/completions"
|
| 232 |
-
|
| 233 |
-
headers = {
|
| 234 |
-
"Content-Type": "application/json",
|
| 235 |
-
"Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
|
| 236 |
}
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
""
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
"
|
| 274 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
}
|
| 276 |
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
return {"error": "Invalid JSON from model", "raw": content}
|
| 302 |
-
|
| 303 |
-
# For have text compatible to the llm
|
| 304 |
-
def json_to_llm_str(textJson):
|
| 305 |
-
str=''
|
| 306 |
-
for file,item in textJson.items():
|
| 307 |
-
str+=item + ' '
|
| 308 |
-
return str
|
| 309 |
-
|
| 310 |
-
# Define the RE for extracting the contact details like number, mail , portfolio, website etc
|
| 311 |
-
def extract_contact_details(text):
|
| 312 |
-
# Regex patterns
|
| 313 |
-
# Phone numbers with at least 5 digits in any segment
|
| 314 |
-
combined_phone_regex = re.compile(r'''
|
| 315 |
-
(?:
|
| 316 |
-
#(?:(?:\+91[-.\s]?)?\d{5}[-.\s]?\d{5})|(?:\+?\d{1,3})?[-.\s()]?\d{5,}[-.\s()]?\d{5,}[-.\s()]?\d{1,9} | /^[\.-)( ]*([0-9]{3})[\.-)( ]*([0-9]{3})[\.-)( ]*([0-9]{4})$/ |
|
| 317 |
-
\+1\s\(\d{3}\)\s\d{3}-\d{4} | # USA/Canada Intl +1 (XXX) XXX-XXXX
|
| 318 |
-
\(\d{3}\)\s\d{3}-\d{4} | # USA/Canada STD (XXX) XXX-XXXX
|
| 319 |
-
\(\d{3}\)\s\d{3}\s\d{4} | # USA/Canada (XXX) XXX XXXX
|
| 320 |
-
\(\d{3}\)\s\d{3}\s\d{3} | # USA/Canada (XXX) XXX XXX
|
| 321 |
-
\+1\d{10} | # +1 XXXXXXXXXX
|
| 322 |
-
\d{10} | # XXXXXXXXXX
|
| 323 |
-
\+44\s\d{4}\s\d{6} | # UK Intl +44 XXXX XXXXXX
|
| 324 |
-
\+44\s\d{3}\s\d{3}\s\d{4} | # UK Intl +44 XXX XXX XXXX
|
| 325 |
-
0\d{4}\s\d{6} | # UK STD 0XXXX XXXXXX
|
| 326 |
-
0\d{3}\s\d{3}\s\d{4} | # UK STD 0XXX XXX XXXX
|
| 327 |
-
\+44\d{10} | # +44 XXXXXXXXXX
|
| 328 |
-
0\d{10} | # 0XXXXXXXXXX
|
| 329 |
-
\+61\s\d\s\d{4}\s\d{4} | # Australia Intl +61 X XXXX XXXX
|
| 330 |
-
0\d\s\d{4}\s\d{4} | # Australia STD 0X XXXX XXXX
|
| 331 |
-
\+61\d{9} | # +61 XXXXXXXXX
|
| 332 |
-
0\d{9} | # 0XXXXXXXXX
|
| 333 |
-
\+91\s\d{5}-\d{5} | # India Intl +91 XXXXX-XXXXX
|
| 334 |
-
\+91\s\d{4}-\d{6} | # India Intl +91 XXXX-XXXXXX
|
| 335 |
-
\+91\s\d{10} | # India Intl +91 XXXXXXXXXX
|
| 336 |
-
\+91\s\d{3}\s\d{3}\s\d{4} | # India Intl +91 XXX XXX XXXX
|
| 337 |
-
\+91\s\d{3}-\d{3}-\d{4} | # India Intl +91 XXX-XXX-XXXX
|
| 338 |
-
\+91\s\d{2}\s\d{4}\s\d{4} | # India Intl +91 XX XXXX XXXX
|
| 339 |
-
\+91\s\d{2}-\d{4}-\d{4} | # India Intl +91 XX-XXXX-XXXX
|
| 340 |
-
\+91\s\d{5}\s\d{5} | # India Intl +91 XXXXX XXXXX
|
| 341 |
-
\d{5}\s\d{5} | # India XXXXX XXXXX
|
| 342 |
-
\d{5}-\d{5} | # India XXXXX-XXXXX
|
| 343 |
-
0\d{2}-\d{7} | # India STD 0XX-XXXXXXX
|
| 344 |
-
\+91\d{10} | # +91 XXXXXXXXXX
|
| 345 |
-
\d{10} | # XXXXXXXXXX # Here is the regex to handle all possible combination of the contact
|
| 346 |
-
\d{6}-\d{4} | # XXXXXX-XXXX
|
| 347 |
-
\d{4}-\d{6} | # XXXX-XXXXXX
|
| 348 |
-
\d{3}\s\d{3}\s\d{4} | # XXX XXX XXXX
|
| 349 |
-
\d{3}-\d{3}-\d{4} | # XXX-XXX-XXXX
|
| 350 |
-
\d{4}\s\d{3}\s\d{3} | # XXXX XXX XXX
|
| 351 |
-
\d{4}-\d{3}-\d{3} | # XXXX-XXX-XXX #-----
|
| 352 |
-
\+49\s\d{4}\s\d{8} | # Germany Intl +49 XXXX XXXXXXXX
|
| 353 |
-
\+49\s\d{3}\s\d{7} | # Germany Intl +49 XXX XXXXXXX
|
| 354 |
-
0\d{3}\s\d{8} | # Germany STD 0XXX XXXXXXXX
|
| 355 |
-
\+49\d{12} | # +49 XXXXXXXXXXXX
|
| 356 |
-
\+49\d{10} | # +49 XXXXXXXXXX
|
| 357 |
-
0\d{11} | # 0XXXXXXXXXXX
|
| 358 |
-
\+86\s\d{3}\s\d{4}\s\d{4} | # China Intl +86 XXX XXXX XXXX
|
| 359 |
-
0\d{3}\s\d{4}\s\d{4} | # China STD 0XXX XXXX XXXX
|
| 360 |
-
\+86\d{11} | # +86 XXXXXXXXXXX
|
| 361 |
-
\+81\s\d\s\d{4}\s\d{4} | # Japan Intl +81 X XXXX XXXX
|
| 362 |
-
\+81\s\d{2}\s\d{4}\s\d{4} | # Japan Intl +81 XX XXXX XXXX
|
| 363 |
-
0\d\s\d{4}\s\d{4} | # Japan STD 0X XXXX XXXX
|
| 364 |
-
\+81\d{10} | # +81 XXXXXXXXXX
|
| 365 |
-
\+81\d{9} | # +81 XXXXXXXXX
|
| 366 |
-
0\d{9} | # 0XXXXXXXXX
|
| 367 |
-
\+55\s\d{2}\s\d{5}-\d{4} | # Brazil Intl +55 XX XXXXX-XXXX
|
| 368 |
-
\+55\s\d{2}\s\d{4}-\d{4} | # Brazil Intl +55 XX XXXX-XXXX
|
| 369 |
-
0\d{2}\s\d{4}\s\d{4} | # Brazil STD 0XX XXXX XXXX
|
| 370 |
-
\+55\d{11} | # +55 XXXXXXXXXXX
|
| 371 |
-
\+55\d{10} | # +55 XXXXXXXXXX
|
| 372 |
-
0\d{10} | # 0XXXXXXXXXX
|
| 373 |
-
\+33\s\d\s\d{2}\s\d{2}\s\d{2}\s\d{2} | # France Intl +33 X XX XX XX XX
|
| 374 |
-
0\d\s\d{2}\s\d{2}\s\d{2}\s\d{2} | # France STD 0X XX XX XX XX
|
| 375 |
-
\+33\d{9} | # +33 XXXXXXXXX
|
| 376 |
-
0\d{9} | # 0XXXXXXXXX
|
| 377 |
-
\+7\s\d{3}\s\d{3}-\d{2}-\d{2} | # Russia Intl +7 XXX XXX-XX-XX
|
| 378 |
-
8\s\d{3}\s\d{3}-\d{2}-\d{2} | # Russia STD 8 XXX XXX-XX-XX
|
| 379 |
-
\+7\d{10} | # +7 XXXXXXXXXX
|
| 380 |
-
8\d{10} | # 8 XXXXXXXXXX
|
| 381 |
-
\+27\s\d{2}\s\d{3}\s\d{4} | # South Africa Intl +27 XX XXX XXXX
|
| 382 |
-
0\d{2}\s\d{3}\s\d{4} | # South Africa STD 0XX XXX XXXX
|
| 383 |
-
\+27\d{9} | # +27 XXXXXXXXX
|
| 384 |
-
0\d{9} | # 0XXXXXXXXX
|
| 385 |
-
\+52\s\d{3}\s\d{3}\s\d{4} | # Mexico Intl +52 XXX XXX XXXX
|
| 386 |
-
\+52\s\d{2}\s\d{4}\s\d{4} | # Mexico Intl +52 XX XXXX XXXX
|
| 387 |
-
01\s\d{3}\s\d{4} | # Mexico STD 01 XXX XXXX
|
| 388 |
-
\+52\d{10} | # +52 XXXXXXXXXX
|
| 389 |
-
01\d{7} | # 01 XXXXXXX
|
| 390 |
-
\+234\s\d{3}\s\d{3}\s\d{4} | # Nigeria Intl +234 XXX XXX XXXX
|
| 391 |
-
0\d{3}\s\d{3}\s\d{4} | # Nigeria STD 0XXX XXX XXXX
|
| 392 |
-
\+234\d{10} | # +234 XXXXXXXXXX
|
| 393 |
-
0\d{10} | # 0XXXXXXXXXX
|
| 394 |
-
\+971\s\d\s\d{3}\s\d{4} | # UAE Intl +971 X XXX XXXX
|
| 395 |
-
0\d\s\d{3}\s\d{4} | # UAE STD 0X XXX XXXX
|
| 396 |
-
\+971\d{8} | # +971 XXXXXXXX
|
| 397 |
-
0\d{8} | # 0XXXXXXXX
|
| 398 |
-
\+54\s9\s\d{3}\s\d{3}\s\d{4} | # Argentina Intl +54 9 XXX XXX XXXX
|
| 399 |
-
\+54\s\d{1}\s\d{4}\s\d{4} | # Argentina Intl +54 X XXXX XXXX
|
| 400 |
-
0\d{3}\s\d{4} | # Argentina STD 0XXX XXXX
|
| 401 |
-
\+54\d{10} | # +54 9 XXXXXXXXXX
|
| 402 |
-
\+54\d{9} | # +54 XXXXXXXXX
|
| 403 |
-
0\d{7} | # 0XXXXXXX
|
| 404 |
-
\+966\s\d\s\d{3}\s\d{4} | # Saudi Intl +966 X XXX XXXX
|
| 405 |
-
0\d\s\d{3}\s\d{4} | # Saudi STD 0X XXX XXXX
|
| 406 |
-
\+966\d{8} | # +966 XXXXXXXX
|
| 407 |
-
0\d{8} | # 0XXXXXXXX
|
| 408 |
-
\+1\d{10} | # +1 XXXXXXXXXX
|
| 409 |
-
\+1\s\d{3}\s\d{3}\s\d{4} | # +1 XXX XXX XXXX
|
| 410 |
-
\d{5}\s\d{5} | # XXXXX XXXXX
|
| 411 |
-
\d{10} | # XXXXXXXXXX
|
| 412 |
-
\+44\d{10} | # +44 XXXXXXXXXX
|
| 413 |
-
0\d{10} | # 0XXXXXXXXXX
|
| 414 |
-
\+61\d{9} | # +61 XXXXXXXXX
|
| 415 |
-
0\d{9} | # 0XXXXXXXXX
|
| 416 |
-
\+91\d{10} | # +91 XXXXXXXXXX
|
| 417 |
-
\+49\d{12} | # +49 XXXXXXXXXXXX
|
| 418 |
-
\+49\d{10} | # +49 XXXXXXXXXX
|
| 419 |
-
0\d{11} | # 0XXXXXXXXXXX
|
| 420 |
-
\+86\d{11} | # +86 XXXXXXXXXXX
|
| 421 |
-
\+81\d{10} | # +81 XXXXXXXXXX
|
| 422 |
-
\+81\d{9} | # +81 XXXXXXXXX
|
| 423 |
-
0\d{9} | # 0XXXXXXXXX
|
| 424 |
-
\+55\d{11} | # +55 XXXXXXXXXXX
|
| 425 |
-
\+55\d{10} | # +55 XXXXXXXXXX
|
| 426 |
-
0\d{10} | # 0XXXXXXXXXX
|
| 427 |
-
\+33\d{9} | # +33 XXXXXXXXX
|
| 428 |
-
0\d{9} | # 0XXXXXXXXX
|
| 429 |
-
\+7\d{10} | # +7 XXXXXXXXXX
|
| 430 |
-
8\d{10} | # 8 XXXXXXXXXX
|
| 431 |
-
\+27\d{9} | # +27 XXXXXXXXX
|
| 432 |
-
0\d{9} | # 0XXXXXXXXX (South Africa STD)
|
| 433 |
-
\+52\d{10} | # +52 XXXXXXXXXX
|
| 434 |
-
01\d{7} | # 01 XXXXXXX
|
| 435 |
-
\+234\d{10} | # +234 XXXXXXXXXX
|
| 436 |
-
0\d{10} | # 0XXXXXXXXXX
|
| 437 |
-
\+971\d{8} | # +971 XXXXXXXX
|
| 438 |
-
0\d{8} | # 0XXXXXXXX
|
| 439 |
-
\+54\s9\s\d{10} | # +54 9 XXXXXXXXXX
|
| 440 |
-
\+54\d{9} | # +54 XXXXXXXXX
|
| 441 |
-
0\d{7} | # 0XXXXXXX
|
| 442 |
-
\+966\d{8} | # +966 XXXXXXXX
|
| 443 |
-
0\d{8} # 0XXXXXXXX
|
| 444 |
-
\+\d{3}-\d{3}-\d{4}
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
''',re.VERBOSE)
|
| 448 |
-
|
| 449 |
-
# Email regex
|
| 450 |
email_regex = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
|
| 451 |
-
|
| 452 |
-
# URL and links regex, updated to avoid conflicts with email domains
|
| 453 |
-
link_regex = re.compile(r'\b(?:https?:\/\/)?(?:www\.)[a-zA-Z0-9-]+\.(?:com|co\.in|co|io|org|net|edu|gov|mil|int|uk|us|in|de|au|app|tech|xyz|info|biz|fr|dev)\b')
|
| 454 |
-
|
| 455 |
-
# Find all matches in the text
|
| 456 |
-
phone_numbers = [num for num in combined_phone_regex.findall(text) if len(num) >= 5]
|
| 457 |
-
|
| 458 |
-
emails = email_regex.findall(text)
|
| 459 |
-
|
| 460 |
-
links_RE = [link for link in link_regex.findall(text) if len(link)>=11]
|
| 461 |
-
|
| 462 |
-
# Remove profile links that might conflict with emails
|
| 463 |
-
links_RE = [link for link in links_RE if not any(email in link for email in emails)]
|
| 464 |
|
| 465 |
return {
|
| 466 |
-
"
|
| 467 |
-
"
|
| 468 |
-
"links_RE": links_RE
|
| 469 |
-
}
|
| 470 |
-
|
| 471 |
-
# preprocessing the data
|
| 472 |
-
def process_extracted_text(extracted_text):
|
| 473 |
-
# Load JSON data
|
| 474 |
-
data = json.dumps(extracted_text, indent=4)
|
| 475 |
-
data = json.loads(data)
|
| 476 |
-
|
| 477 |
-
# Create a single dictionary to hold combined results
|
| 478 |
-
combined_results = {
|
| 479 |
-
"phone_numbers": [],
|
| 480 |
-
"emails": [],
|
| 481 |
-
"links_RE": []
|
| 482 |
}
|
| 483 |
-
|
| 484 |
-
# Process each text entry
|
| 485 |
-
for filename, text in data.items():
|
| 486 |
-
contact_details = extract_contact_details(text)
|
| 487 |
-
# Extend combined results with the details from this file
|
| 488 |
-
combined_results["phone_numbers"].extend(contact_details["phone_numbers"])
|
| 489 |
-
combined_results["emails"].extend(contact_details["emails"])
|
| 490 |
-
combined_results["links_RE"].extend(contact_details["links_RE"])
|
| 491 |
-
|
| 492 |
-
# Convert the combined results to JSON
|
| 493 |
-
#combined_results_json = json.dumps(combined_results, indent=4)
|
| 494 |
-
combined_results_json = combined_results
|
| 495 |
-
|
| 496 |
-
# Print the final JSON results
|
| 497 |
-
print("Combined contact details in JSON format:")
|
| 498 |
-
print(combined_results_json)
|
| 499 |
-
|
| 500 |
-
return combined_results_json
|
| 501 |
-
|
| 502 |
-
# Function to remove duplicates (case-insensitive) from each list in the dictionary
|
| 503 |
-
def remove_duplicates_case_insensitive(data_dict):
|
| 504 |
-
for key, value_list in data_dict.items():
|
| 505 |
-
seen = set()
|
| 506 |
-
unique_list = []
|
| 507 |
-
|
| 508 |
-
for item in value_list:
|
| 509 |
-
if item.lower() not in seen:
|
| 510 |
-
unique_list.append(item) # Add original item (preserving its case)
|
| 511 |
-
seen.add(item.lower()) # Track lowercase version
|
| 512 |
-
|
| 513 |
-
# Update the dictionary with unique values
|
| 514 |
-
data_dict[key] = unique_list
|
| 515 |
-
return data_dict
|
| 516 |
-
|
| 517 |
-
# # Process the model output for parsed result
|
| 518 |
-
# def process_resume_data(LLMdata,cont_data,extracted_text):
|
| 519 |
-
|
| 520 |
-
# # # Removing duplicate emails
|
| 521 |
-
# # unique_emails = []
|
| 522 |
-
# # for email in cont_data['emails']:
|
| 523 |
-
# # if not any(email.lower() == existing_email.lower() for existing_email in LLMdata['Email']):
|
| 524 |
-
# # unique_emails.append(email)
|
| 525 |
-
|
| 526 |
-
# # # Removing duplicate links (case insensitive)
|
| 527 |
-
# # unique_links = []
|
| 528 |
-
# # for link in cont_data['links_RE']:
|
| 529 |
-
# # if not any(link.lower() == existing_link.lower() for existing_link in LLMdata['Link']):
|
| 530 |
-
# # unique_links.append(link)
|
| 531 |
-
|
| 532 |
-
# # # Removing duplicate phone numbers
|
| 533 |
-
# # normalized_contact = [num[-10:] for num in LLMdata['Contact']]
|
| 534 |
-
# # unique_numbers = []
|
| 535 |
-
# # for num in cont_data['phone_numbers']:
|
| 536 |
-
# # if num[-10:] not in normalized_contact:
|
| 537 |
-
# # unique_numbers.append(num)
|
| 538 |
-
|
| 539 |
-
# # # Add unique emails, links, and phone numbers to the original LLMdata
|
| 540 |
-
# # LLMdata['Email'] += unique_emails
|
| 541 |
-
# # LLMdata['Link'] += unique_links
|
| 542 |
-
# # LLMdata['Contact'] += unique_numbers
|
| 543 |
-
# # Ensure keys exist (CRITICAL FIX)
|
| 544 |
-
# LLMdata['Email'] = LLMdata.get('Email', []) or []
|
| 545 |
-
# LLMdata['Link'] = LLMdata.get('Link', []) or []
|
| 546 |
-
# LLMdata['Contact'] = LLMdata.get('Contact', []) or []
|
| 547 |
-
|
| 548 |
-
# # Removing duplicate emails
|
| 549 |
-
# unique_emails = []
|
| 550 |
-
# for email in cont_data.get('emails', []):
|
| 551 |
-
# if not any(email.lower() == str(existing_email).lower() for existing_email in LLMdata['Email']):
|
| 552 |
-
# unique_emails.append(email)
|
| 553 |
-
|
| 554 |
-
# # Removing duplicate links
|
| 555 |
-
# unique_links = []
|
| 556 |
-
# for link in cont_data.get('links_RE', []):
|
| 557 |
-
# if not any(link.lower() == str(existing_link).lower() for existing_link in LLMdata['Link']):
|
| 558 |
-
# unique_links.append(link)
|
| 559 |
-
|
| 560 |
-
# # Normalize existing contacts safely
|
| 561 |
-
# normalized_contact = [
|
| 562 |
-
# str(num)[-10:] for num in LLMdata['Contact'] if num
|
| 563 |
-
# ]
|
| 564 |
-
|
| 565 |
-
# # Removing duplicate phone numbers
|
| 566 |
-
# unique_numbers = []
|
| 567 |
-
# for num in cont_data.get('phone_numbers', []):
|
| 568 |
-
# if str(num)[-10:] not in normalized_contact:
|
| 569 |
-
# unique_numbers.append(num)
|
| 570 |
-
|
| 571 |
-
# # Merge safely
|
| 572 |
-
# LLMdata['Email'].extend(unique_emails)
|
| 573 |
-
# LLMdata['Link'].extend(unique_links)
|
| 574 |
-
# LLMdata['Contact'].extend(unique_numbers)
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
# # Apply the function to the data
|
| 578 |
-
# LLMdata=remove_duplicates_case_insensitive(LLMdata)
|
| 579 |
-
|
| 580 |
-
# # Initialize the processed data dictionary
|
| 581 |
-
# processed_data = {
|
| 582 |
-
# "name": [],
|
| 583 |
-
# "contact_number": [],
|
| 584 |
-
# "Designation":[],
|
| 585 |
-
# "email": [],
|
| 586 |
-
# "Location": [],
|
| 587 |
-
# "Link": [],
|
| 588 |
-
# "Company":[],
|
| 589 |
-
# "extracted_text": extracted_text
|
| 590 |
-
# }
|
| 591 |
-
# #LLM
|
| 592 |
-
|
| 593 |
-
# processed_data['name'].extend(LLMdata.get('Name', None))
|
| 594 |
-
# #processed_data['contact_number'].extend(LLMdata.get('Contact', []))
|
| 595 |
-
# processed_data['Designation'].extend(LLMdata.get('Designation', []))
|
| 596 |
-
# #processed_data['email'].extend(LLMdata.get("Email", []))
|
| 597 |
-
# processed_data['Location'].extend(LLMdata.get('Address', []))
|
| 598 |
-
# #processed_data['Link'].extend(LLMdata.get('Link', []))
|
| 599 |
-
# processed_data['Company'].extend(LLMdata.get('Company', []))
|
| 600 |
-
|
| 601 |
-
# #Contact
|
| 602 |
-
# #processed_data['email'].extend(cont_data.get("emails", []))
|
| 603 |
-
# #processed_data['contact_number'].extend(cont_data.get("phone_numbers", []))
|
| 604 |
-
# #processed_data['Link'].extend(cont_data.get("links_RE", []))
|
| 605 |
-
|
| 606 |
-
# #New_merge_data
|
| 607 |
-
# processed_data['email'].extend(LLMdata['Email'])
|
| 608 |
-
# processed_data['contact_number'].extend(LLMdata['Contact'])
|
| 609 |
-
# processed_data['Link'].extend(LLMdata['Link'])
|
| 610 |
-
|
| 611 |
-
# #to remove not found fields
|
| 612 |
-
# # List of keys to check for 'Not found'
|
| 613 |
-
# keys_to_check = ["name", "contact_number", "Designation", "email", "Location", "Link", "Company"]
|
| 614 |
-
|
| 615 |
-
# # Replace 'Not found' with an empty list for each key
|
| 616 |
-
# for key in keys_to_check:
|
| 617 |
-
# if processed_data[key] == ['Not found'] or processed_data[key] == ['not found']:
|
| 618 |
-
# processed_data[key] = []
|
| 619 |
-
|
| 620 |
-
# return processed_data
|
| 621 |
-
def process_resume_data(LLMdata, cont_data, extracted_text):
|
| 622 |
-
|
| 623 |
-
# -------------------------------
|
| 624 |
-
# ✅ STEP 1: Normalize LLM Schema
|
| 625 |
-
# -------------------------------
|
| 626 |
-
expected_keys = ["Name", "Designation", "Company", "Contact", "Address", "Email", "Link"]
|
| 627 |
-
|
| 628 |
-
for key in expected_keys:
|
| 629 |
-
if key not in LLMdata or LLMdata[key] is None:
|
| 630 |
-
LLMdata[key] = []
|
| 631 |
-
elif not isinstance(LLMdata[key], list):
|
| 632 |
-
LLMdata[key] = [LLMdata[key]]
|
| 633 |
-
|
| 634 |
-
# -------------------------------
|
| 635 |
-
# ✅ STEP 2: Normalize cont_data
|
| 636 |
-
# -------------------------------
|
| 637 |
-
cont_data = cont_data or {}
|
| 638 |
-
cont_data.setdefault("emails", [])
|
| 639 |
-
cont_data.setdefault("phone_numbers", [])
|
| 640 |
-
cont_data.setdefault("links_RE", [])
|
| 641 |
-
|
| 642 |
-
# -------------------------------
|
| 643 |
-
# ✅ STEP 3: Normalize existing contacts
|
| 644 |
-
# -------------------------------
|
| 645 |
-
normalized_llm_numbers = {
|
| 646 |
-
str(num)[-10:] for num in LLMdata["Contact"] if num
|
| 647 |
-
}
|
| 648 |
-
|
| 649 |
-
# -------------------------------
|
| 650 |
-
# ✅ STEP 4: Merge Emails
|
| 651 |
-
# -------------------------------
|
| 652 |
-
for email in cont_data["emails"]:
|
| 653 |
-
if not any(email.lower() == str(e).lower() for e in LLMdata["Email"]):
|
| 654 |
-
LLMdata["Email"].append(email)
|
| 655 |
-
|
| 656 |
-
# -------------------------------
|
| 657 |
-
# ✅ STEP 5: Merge Links
|
| 658 |
-
# -------------------------------
|
| 659 |
-
for link in cont_data["links_RE"]:
|
| 660 |
-
if not any(link.lower() == str(l).lower() for l in LLMdata["Link"]):
|
| 661 |
-
LLMdata["Link"].append(link)
|
| 662 |
-
|
| 663 |
-
# -------------------------------
|
| 664 |
-
# ✅ STEP 6: Merge Phone Numbers
|
| 665 |
-
# -------------------------------
|
| 666 |
-
for num in cont_data["phone_numbers"]:
|
| 667 |
-
norm = str(num)[-10:]
|
| 668 |
-
if norm not in normalized_llm_numbers:
|
| 669 |
-
LLMdata["Contact"].append(num)
|
| 670 |
-
normalized_llm_numbers.add(norm)
|
| 671 |
-
|
| 672 |
-
# -------------------------------
|
| 673 |
-
# ✅ STEP 7: Remove duplicates (case-insensitive)
|
| 674 |
-
# -------------------------------
|
| 675 |
-
LLMdata = remove_duplicates_case_insensitive(LLMdata)
|
| 676 |
-
|
| 677 |
-
# -------------------------------
|
| 678 |
-
# ✅ STEP 8: Build final structure
|
| 679 |
-
# -------------------------------
|
| 680 |
-
processed_data = {
|
| 681 |
-
"name": LLMdata["Name"],
|
| 682 |
-
"contact_number": LLMdata["Contact"],
|
| 683 |
-
"Designation": LLMdata["Designation"],
|
| 684 |
-
"email": LLMdata["Email"],
|
| 685 |
-
"Location": LLMdata["Address"],
|
| 686 |
-
"Link": LLMdata["Link"],
|
| 687 |
-
"Company": LLMdata["Company"],
|
| 688 |
-
"extracted_text": extracted_text
|
| 689 |
-
}
|
| 690 |
-
|
| 691 |
-
# -------------------------------
|
| 692 |
-
# ✅ STEP 9: Clean "Not found"
|
| 693 |
-
# -------------------------------
|
| 694 |
-
for key in ["name", "contact_number", "Designation", "email", "Location", "Link", "Company"]:
|
| 695 |
-
processed_data[key] = [
|
| 696 |
-
v for v in processed_data[key]
|
| 697 |
-
if str(v).lower() != "not found"
|
| 698 |
-
]
|
| 699 |
-
|
| 700 |
-
return processed_data
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import logging
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
# Ensure langchain is available for paddlex/paddleocr
|
| 7 |
+
try:
|
| 8 |
+
import langchain
|
| 9 |
+
import langchain_community
|
| 10 |
+
except ImportError:
|
| 11 |
+
logging.warning("LangChain modules not found. PaddleOCR might fail.")
|
| 12 |
+
|
| 13 |
+
from core.ocr_engine import OCREngine
|
| 14 |
+
from core.vlm_engine import GroqVLMEngine
|
| 15 |
+
from core.ner_engine import NEREngine
|
| 16 |
+
|
| 17 |
+
# Global instances (Lazy load)
|
| 18 |
+
_ocr = None
|
| 19 |
+
_vlm = None
|
| 20 |
+
_ner = None
|
| 21 |
+
|
| 22 |
+
def get_ocr():
    """Return the shared OCR engine, building it on first use.

    The instance is kept in the module-level ``_ocr`` variable so that
    subsequent calls reuse it instead of constructing a new ``OCREngine``.
    """
    global _ocr
    if _ocr:
        return _ocr
    _ocr = OCREngine()
    return _ocr
|
| 27 |
+
|
| 28 |
+
def get_vlm():
    """Return the shared Groq VLM engine, building it on first use.

    The instance is kept in the module-level ``_vlm`` variable so that
    subsequent calls reuse it instead of constructing a new ``GroqVLMEngine``.
    """
    global _vlm
    if _vlm:
        return _vlm
    _vlm = GroqVLMEngine()
    return _vlm
|
| 33 |
+
|
| 34 |
+
def get_ner():
    """Return the shared NER engine, building it on first use.

    The instance is kept in the module-level ``_ner`` variable so that
    subsequent calls reuse it instead of constructing a new ``NEREngine``.
    """
    global _ner
    if _ner:
        return _ner
    _ner = NEREngine()
    return _ner
|
| 39 |
+
|
| 40 |
+
def process_image_pipeline(image_paths: List[str]) -> Dict[str, Any]:
    """Extract contact fields from each image, VLM first, OCR+NER as fallback.

    For every path the VLM engine is tried first. When it returns a truthy
    payload, that payload is merged into the accumulated results and its JSON
    dump is stored as the image's raw text. Otherwise the OCR engine is asked
    for text; if it yields any, the NER engine is run on it and whatever it
    returns is merged.

    Args:
        image_paths: Paths of the images to process, handled in order.

    Returns:
        The accumulated result dict after ``cleanup_results``: the seven list
        fields, ``extracted_text`` (path -> raw text / JSON dump) and
        ``status_message``.
    """
    logging.info(f"Pipeline: Starting processing for {len(image_paths)} images.")
    vlm = get_vlm()
    ocr = get_ocr()
    ner = get_ner()

    list_fields = (
        "name",
        "contact_number",
        "Designation",
        "email",
        "Location",
        "Link",
        "Company",
    )
    final_results = {field: [] for field in list_fields}
    final_results["extracted_text"] = {}
    final_results["status_message"] = "Primary: Groq VLM"

    raw_by_path = {}

    for path in image_paths:
        img_name = os.path.basename(path)

        # 1. Primary: VLM
        logging.info(f"Pipeline: Attempting VLM extraction for {img_name}")
        vlm_data = vlm.process(path)
        if vlm_data:
            merge_structured_data(final_results, vlm_data)
            raw_by_path[path] = json.dumps(vlm_data)
            logging.info(f"Pipeline: VLM success for {img_name}")
            continue

        # 2. Fallback: OCR + NER
        logging.warning(f"Pipeline: VLM failed or skipped for {img_name}. Falling back to OCR+NER.")
        raw_text = ocr.extract_text(path)
        # Recorded even when OCR returned nothing, so every path has an entry.
        raw_by_path[path] = raw_text
        if not raw_text:
            logging.error(f"Pipeline: Both VLM and OCR failed for {img_name}")
            continue

        logging.info(f"Pipeline: OCR success for {img_name}, attempting NER.")
        ner_data = ner.extract_entities(raw_text)
        if ner_data:
            merge_structured_data(final_results, ner_data)
            logging.info(f"Pipeline: NER success for {img_name}")
        else:
            logging.warning(f"Pipeline: NER failed to extract entities for {img_name}")
        # The status flips as soon as OCR produced text for any image,
        # whether or not NER found entities in it.
        final_results["status_message"] = "Fallback: OCR+NER"

    final_results["extracted_text"] = raw_by_path
    cleaned = cleanup_results(final_results)
    populated = sum(1 for v in cleaned.values() if isinstance(v, list) and v)
    logging.info(f"Pipeline: Completed. Extracted data for {populated} fields.")
    return cleaned
|
| 90 |
+
|
| 91 |
+
def merge_structured_data(main_data: Dict, new_data: Dict):
    """Merge one engine's extracted fields into the accumulated result dict.

    Keys of ``new_data`` are resolved to keys of ``main_data`` in two steps:
    first through the engine-name mapping below (``"Name"`` -> ``"name"``,
    ``"Address"`` -> ``"Location"``, ...), matched on the capitalised key;
    then, if that fails, by a case-insensitive match against the list-valued
    keys ``main_data`` already has. Keys that resolve to nothing, or to a
    slot that is not a list, are ignored.

    Args:
        main_data: Accumulator that is mutated in place.
        new_data: Field -> value (or list of values) produced by an engine.

    Returns:
        None. ``main_data`` is updated in place.
    """
    mapping = {
        "Name": "name",
        "Contact": "contact_number",
        "Designation": "Designation",
        "Email": "email",
        "Address": "Location",
        "Link": "Link",
        "Company": "Company",
    }

    # Case-insensitive index of the accumulator's own list slots, so a payload
    # already using accumulator names (e.g. "Location") is not dropped.
    list_slots = {
        str(name).lower(): name
        for name, slot in main_data.items()
        if isinstance(slot, list)
    }

    for key, val in new_data.items():
        key_str = str(key)
        canonical_key = mapping.get(key_str.capitalize())
        if canonical_key is None:
            canonical_key = list_slots.get(key_str.lower())
        if canonical_key is None:
            continue

        target = main_data.get(canonical_key)
        # Never call list methods on non-list slots such as "status_message"
        # (str) or "extracted_text" (dict).
        if not isinstance(target, list):
            continue

        if isinstance(val, list):
            target.extend(val)
        elif val:
            target.append(val)
|
| 109 |
+
|
| 110 |
+
def cleanup_results(results: Dict) -> Dict:
    """Normalise every list-valued field of ``results`` in place.

    Each item is converted to ``str`` and stripped. Items that are empty or
    equal (case-insensitively) to a placeholder such as ``"not found"`` are
    dropped, and case-insensitive duplicates are removed, keeping the first
    occurrence. Non-list values are left untouched.

    Args:
        results: Field -> value mapping; mutated in place.

    Returns:
        The same ``results`` object.
    """
    placeholders = {"", "not found", "none", "null", "[]"}

    for field, values in results.items():
        if not isinstance(values, list):
            continue

        # Lower-cased text -> first spelling seen; dict order keeps input order.
        kept = {}
        for raw in values:
            text = str(raw).strip()
            folded = text.lower()
            if folded in placeholders or folded in kept:
                continue
            kept[folded] = text

        results[field] = list(kept.values())

    return results
|
| 123 |
+
|
| 124 |
+
def extract_contact_details(text: str) -> Dict[str, List[str]]:
    """Extract e-mail addresses and phone numbers from ``text`` with regexes.

    Acts as a regex fallback for extra safety alongside the model-based
    extraction.

    Args:
        text: Raw text to scan.

    Returns:
        A dict with two keys: ``"emails"`` and ``"phone_numbers"``, each a
        list of the full matched substrings in order of appearance.
    """
    email_regex = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
    # The optional country-code prefix must be a NON-capturing group: with a
    # capturing group, ``findall`` returns only the group's text (e.g. '+1 '
    # or '') instead of the whole phone number.
    phone_regex = re.compile(
        r'(?:\+?\d{1,3}[-.\s()]?)?\(?\d{3,5}\)?[-.\s()]?\d{3,5}[-.\s()]?\d{3,5}'
    )

    return {
        "emails": email_regex.findall(text),
        "phone_numbers": phone_regex.findall(text),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|