WebashalarForML committed on
Commit
fad436e
·
verified ·
1 Parent(s): 5194ace

Upload 42 files

Browse files
Files changed (40) hide show
  1. app.py +69 -137
  2. backup/__pycache__/model.cpython-310.pyc +0 -0
  3. backup/__pycache__/save_load.cpython-310.pyc +0 -0
  4. backup/__pycache__/train.cpython-310.pyc +0 -0
  5. backup/backup.py +58 -58
  6. backup/model.py +412 -412
  7. backup/modules/__pycache__/base.cpython-310.pyc +0 -0
  8. backup/modules/__pycache__/evaluator.cpython-310.pyc +0 -0
  9. backup/modules/__pycache__/layers.cpython-310.pyc +0 -0
  10. backup/modules/__pycache__/run_evaluation.cpython-310.pyc +0 -0
  11. backup/modules/__pycache__/span_rep.cpython-310.pyc +0 -0
  12. backup/modules/__pycache__/token_rep.cpython-310.pyc +0 -0
  13. backup/modules/base.py +150 -150
  14. backup/modules/data_proc.py +73 -73
  15. backup/modules/evaluator.py +152 -152
  16. backup/modules/layers.py +28 -28
  17. backup/modules/run_evaluation.py +188 -188
  18. backup/modules/span_rep.py +369 -369
  19. backup/modules/token_rep.py +54 -54
  20. backup/requirements.txt +5 -5
  21. backup/save_load.py +20 -20
  22. backup/train.py +132 -132
  23. core/__pycache__/base.cpython-310.pyc +0 -0
  24. core/__pycache__/gradio_ocr.cpython-310.pyc +0 -0
  25. core/__pycache__/ner_engine.cpython-310.pyc +0 -0
  26. core/__pycache__/ocr_engine.cpython-310.pyc +0 -0
  27. core/__pycache__/vlm_engine.cpython-310.pyc +0 -0
  28. core/base.py +22 -0
  29. core/gradio_ocr.py +50 -0
  30. core/ner_engine.py +49 -0
  31. core/ocr_engine.py +114 -0
  32. core/vlm_engine.py +91 -0
  33. requirements.txt +18 -16
  34. static/uploads/IN_Standard-Visiting-Cards_Overview.png +0 -0
  35. templates/index.html +236 -284
  36. templates/result.html +326 -248
  37. utility/__pycache__/utils.cpython-310.pyc +0 -0
  38. utility/__pycache__/utils.cpython-312.pyc +0 -0
  39. utility/__pycache__/utils.cpython-313.pyc +0 -0
  40. utility/utils.py +120 -688
app.py CHANGED
@@ -1,186 +1,118 @@
1
# libraries
from flask import Flask, render_template, request, redirect, url_for, flash, session, send_from_directory
import os
import logging
from utility.utils import extract_text_from_images, Data_Extractor, json_to_llm_str, process_extracted_text, process_resume_data
from backup.backup import NER_Model
from paddleocr import PaddleOCR

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    handlers=[
        logging.StreamHandler()  # Remove FileHandler and log only to the console
    ]
)

# Flask App
app = Flask(__name__)
app.secret_key = 'your_secret_key'  # NOTE(review): hard-coded secret; should come from the environment
app.config['UPLOAD_FOLDER'] = 'uploads/'
app.config['RESULT_FOLDER'] = 'results/'

# NOTE(review): these module-level constants ('static/...') differ from the
# app.config values above ('uploads/', 'results/'); routes below save/serve
# via app.config, so two distinct directory pairs get created — confirm intended.
UPLOAD_FOLDER = 'static/uploads/'
RESULT_FOLDER = 'static/results/'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)

# Ensure the configured upload/result folders exist as well.
if not os.path.exists(app.config['UPLOAD_FOLDER']):
    os.makedirs(app.config['UPLOAD_FOLDER'])

if not os.path.exists(app.config['RESULT_FOLDER']):
    os.makedirs(app.config['RESULT_FOLDER'])

# Set the PaddleOCR home directory to a writable location
os.environ['PADDLEOCR_HOME'] = '/tmp/.paddleocr'

# Check if PaddleOCR home directory is writable
if not os.path.exists('/tmp/.paddleocr'):
    os.makedirs('/tmp/.paddleocr', exist_ok=True)
    logging.info("Created PaddleOCR home directory.")
else:
    logging.info("PaddleOCR home directory exists.")
43
 
44
@app.route('/')
def index():
    """Render the landing page listing files uploaded in this session."""
    uploaded_files = session.get('uploaded_files', [])
    logging.info(f"Accessed index page, uploaded files: {uploaded_files}")
    return render_template(
        'index.html',
        uploaded_files=uploaded_files,
    )
49
 
50
@app.route('/upload', methods=['POST'])
def upload_file():
    """Save every file posted under the 'files' field into the configured
    upload folder, append their names to the session list, then hand off
    directly to process_file() (no intermediate redirect)."""
    if 'files' not in request.files:
        flash('No file part')
        logging.warning("No file part found in the request")
        return redirect(request.url)

    files = request.files.getlist('files')
    if not files or all(file.filename == '' for file in files):
        flash('No selected files')
        logging.warning("No files selected for upload")
        return redirect(request.url)

    # Accumulate onto any names already stored from earlier uploads.
    uploaded_files = session.get('uploaded_files', [])
    for file in files:
        if file:
            # NOTE(review): client-supplied filename is used verbatim; consider
            # werkzeug's secure_filename to block path traversal — TODO confirm.
            filename = file.filename
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(file_path)
            print(f"file path --->{file_path}")
            uploaded_files.append(filename)
            logging.info(f"Uploaded file: (unknown) at {file_path}")

    session['uploaded_files'] = uploaded_files
    flash('Files successfully uploaded')
    logging.info(f"Files successfully uploaded: {uploaded_files}")
    return process_file()
77
 
78
@app.route('/remove_file', methods=['POST'])
def remove_file():
    """Delete every uploaded file recorded in the session, then clear the list.

    Bug fix: the original tested the undefined local `uploaded_file` — which
    resolved to the `uploaded_file` view function at module scope and was
    therefore always truthy — instead of the `uploaded_files` list fetched
    from the session. Also restores the filename in the removal log message,
    whose placeholder had been lost.
    """
    uploaded_files = session.get('uploaded_files', [])
    if uploaded_files:
        for filename in uploaded_files:
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            if os.path.exists(file_path):
                os.remove(file_path)
                logging.info(f"Removed file: {filename}")
            else:
                logging.warning(f"File not found for removal: {file_path}")  # More specific log

        session.pop('uploaded_files', None)
        flash('Files successfully removed')
        logging.info("All uploaded files removed")
    else:
        flash('No file to remove.')
        logging.warning("File not found for removal")
    return redirect(url_for('index'))
97
-
98
@app.route('/reset_upload')
def reset_upload():
    """Reset the uploaded file and the processed data.

    Bug fix: as in remove_file, the original checked the undefined name
    `uploaded_file` (always truthy — it resolves to the view function)
    instead of the `uploaded_files` session list, so the empty-session
    branch was unreachable.
    """
    uploaded_files = session.get('uploaded_files', [])
    if uploaded_files:
        for filename in uploaded_files:
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            if os.path.exists(file_path):
                os.remove(file_path)
                logging.info(f"Removed file: {filename}")
            else:
                logging.warning(f"File not found for removal: {file_path}")  # More specific log

        session.pop('uploaded_files', None)
        flash('Files successfully removed')
        logging.info("All uploaded files removed")
    else:
        flash('No file to remove.')
        logging.warning("File not found for removal")
    return redirect(url_for('index'))
118
-
119
@app.route('/process', methods=['GET','POST'])
def process_file():
    """OCR the session's uploaded files and extract structured data.

    Primary path: OCR -> json_to_llm_str -> Data_Extractor (LLM). On any
    exception, falls back to the local NER_Model. Either way the combined
    result is stored in the session and the user is redirected to /result.
    """
    uploaded_files = session.get('uploaded_files', [])
    if not uploaded_files:
        flash('No files selected for processing')
        logging.warning("No files selected for processing")
        return redirect(url_for('index'))

    file_paths = [os.path.join(app.config['UPLOAD_FOLDER'], filename) for filename in uploaded_files]
    logging.info(f"Processing files: {file_paths}")

    extracted_text = {}
    processed_Img = {}

    try:
        extracted_text, processed_Img = extract_text_from_images(file_paths)
        logging.info(f"Extracted text: {extracted_text}")
        logging.info(f"Processed images: {processed_Img}")

        llmText = json_to_llm_str(extracted_text)
        logging.info(f"LLM text: {llmText}")

        LLMdata = Data_Extractor(llmText)
        print("llm data--------->",llmText)
        logging.info(f"LLM data: {LLMdata}")

    except Exception as e:
        # Fallback path: any failure above (OCR or LLM) lands here and the
        # local GLiNER-based backup model is used instead.
        logging.error(f"Error during LLM processing: {e}")
        logging.info("Running backup model...")

        LLMdata = {}
        # NOTE(review): OCR is re-run here even if it already succeeded in the
        # try block — presumably to guarantee fresh state; confirm intended.
        extracted_text, processed_Img = extract_text_from_images(file_paths)
        logging.info(f"Extracted text(Backup): {extracted_text}")
        logging.info(f"Processed images(Backup): {processed_Img}")
        if extracted_text:
            text = json_to_llm_str(extracted_text)
            LLMdata = NER_Model(text)
            logging.info(f"NER model data: {LLMdata}")
        else:
            logging.warning("No extracted text available for backup model")

    cont_data = process_extracted_text(extracted_text)
    logging.info(f"Contextual data: {cont_data}")

    processed_data = process_resume_data(LLMdata, cont_data, extracted_text)
    logging.info(f"Processed data: {processed_data}")

    session['processed_data'] = processed_data
    session['processed_Img'] = processed_Img
    flash('Data processed and analyzed successfully')
    logging.info("Data processed and analyzed successfully")
    return redirect(url_for('result'))
171
-
172
@app.route('/result')
def result():
    """Show the result page built from the session's processed data and images."""
    processed_data = session.get('processed_data', {})
    processed_Img = session.get('processed_Img', {})
    logging.info(f"Displaying results: Data - {processed_data}, Images - {processed_Img}")
    return render_template(
        'result.html',
        data=processed_data,
        Img=processed_Img,
    )
 
178
 
179
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve a previously uploaded file from the configured upload folder."""
    logging.info(f"Serving file: (unknown)")
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)
 
 
 
 
 
 
 
 
 
183
 
184
if __name__ == '__main__':
    # Dev entry point: Flask debug mode (interactive traceback + reloader).
    logging.info("Starting Flask app")
    app.run(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import logging
from flask import Flask, render_template, request, redirect, url_for, flash, session, send_from_directory
from utility.utils import process_image_pipeline
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler()])

app = Flask(__name__)
# Secret comes from the environment; falls back to a non-secret default.
app.secret_key = os.getenv('SECRET_KEY', 'default_secret_key')
app.config['UPLOAD_FOLDER'] = 'static/uploads/'
app.config['RESULT_FOLDER'] = 'static/results/'

# Ensure both working directories exist before any request arrives.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
20
 
21
@app.template_filter('basename')
def basename_filter(path):
    """Jinja filter: return only the final path component (e.g. 'a/b.png' -> 'b.png')."""
    name = os.path.basename(path)
    return name
 
 
 
24
 
25
@app.route('/')
def index():
    """Landing page: list the files uploaded in the current session."""
    return render_template(
        'index.html',
        uploaded_files=session.get('uploaded_files', []),
    )
29
 
30
@app.route('/upload', methods=['POST'])
def upload_file():
    """Persist posted files into the upload folder and chain into processing.

    Redirects back to the form when no files were posted; otherwise replaces
    the session's file list with the new batch and delegates to process_file().

    Fix: the per-file log line was an f-string with no placeholder
    ("Successfully saved (unknown)") — the filename interpolation is restored.
    """
    logging.info("Request: /upload received")
    if 'files' not in request.files:
        logging.warning("Upload: No file part in request")
        flash('No file part')
        return redirect(request.url)

    files = request.files.getlist('files')
    if not files or all(file.filename == '' for file in files):
        logging.warning("Upload: No files selected")
        flash('No selected files')
        return redirect(request.url)

    uploaded_files = []
    for file in files:
        if file:
            # NOTE(review): client filename used verbatim; consider
            # werkzeug.utils.secure_filename to block path traversal.
            filename = file.filename
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(file_path)
            uploaded_files.append(filename)
            logging.info(f"Upload: Successfully saved {filename}")

    session['uploaded_files'] = uploaded_files
    return process_file()
55
 
56
@app.route('/process', methods=['GET', 'POST'])
def process_file():
    """Run the extraction pipeline over the session's uploaded files.

    On success stores the pipeline output (`processed_data`) and a
    filename -> upload-path map (`processed_Img`) in the session and
    redirects to /result. On any exception, logs the full traceback and
    bounces back to the index with the error flashed.
    """
    logging.info("Request: /process started")
    uploaded_files = session.get('uploaded_files', [])
    if not uploaded_files:
        logging.warning("Process: No files in session")
        flash('No files selected for processing')
        return redirect(url_for('index'))

    file_paths = [os.path.join(app.config['UPLOAD_FOLDER'], f) for f in uploaded_files]

    try:
        logging.info(f"Process: Sending {len(file_paths)} files to pipeline")
        processed_data = process_image_pipeline(file_paths)

        # Format images for result.html
        processed_Img = {f: os.path.join(app.config['UPLOAD_FOLDER'], f) for f in uploaded_files}

        session['processed_data'] = processed_data
        session['processed_Img'] = processed_Img

        logging.info("Process: Pipeline completed successfully")
        flash('Data processed successfully')
        return redirect(url_for('result'))
    except Exception as e:
        # logging.exception records the traceback in addition to the message.
        logging.exception(f"Process: Critical failure: {e}")
        flash(f'Processing error: {str(e)}')
        return redirect(url_for('index'))
84
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
@app.route('/result')
def result():
    """Render the results page, or bounce home when nothing has been processed."""
    data = session.get('processed_data', {})
    if not data:
        return redirect(url_for('index'))
    return render_template(
        'result.html',
        data=data,
        Img=session.get('processed_Img', {}),
    )
92
 
93
@app.route('/reset_upload')
def reset_upload():
    """Delete the uploaded files from disk and wipe all pipeline state from the session."""
    upload_dir = app.config['UPLOAD_FOLDER']
    for name in session.get('uploaded_files', []):
        target = os.path.join(upload_dir, name)
        if os.path.exists(target):
            os.remove(target)

    # Drop every piece of per-session pipeline state.
    for key in ('uploaded_files', 'processed_data', 'processed_Img'):
        session.pop(key, None)
    flash('System reset successful.')
    return redirect(url_for('index'))
106
 
107
if __name__ == '__main__':
    from utility.utils import get_ocr, get_ner
    logging.info("Core: Pre-initializing engines (this may take a minute)...")
    # Trigger lazy load at startup to avoid reloader issues and request timeouts
    try:
        get_ocr()
        get_ner()
        logging.info("Core: Engines pre-initialized successfully.")
    except Exception as e:
        # Startup continues on warm-up failure; presumably the engines are
        # loaded lazily on first use instead — TODO confirm in utility.utils.
        logging.error(f"Core: Failed to pre-initialize engines: {e}")

    # Reloader disabled so the pre-initialized engines are not loaded twice.
    app.run(debug=True, use_reloader=False, port=int(os.getenv('PORT', 5000)))
backup/__pycache__/model.cpython-310.pyc ADDED
Binary file (9.97 kB). View file
 
backup/__pycache__/save_load.cpython-310.pyc ADDED
Binary file (765 Bytes). View file
 
backup/__pycache__/train.cpython-310.pyc ADDED
Binary file (2.95 kB). View file
 
backup/backup.py CHANGED
@@ -1,59 +1,59 @@
1
from .model import GLiNER

# Initialize GLiNER with the base model
model = GLiNER.from_pretrained("urchade/gliner_mediumv2.1")

# Sample text for entity prediction
text = """
lenskart m: (0)9428002330 Lenskart Store,Surat m: (0)9723817060) e:lenskartsurat@gmail.com Store Address UG-4.Ascon City.Opp.Maheshwari Bhavan,Citylight,Surat-395007"""

def NER_Model(text):
    """Predict entities in *text* with GLiNER and bucket them into the
    fixed card-field dictionary used downstream.

    Returns a dict of lists keyed Name/Contact/Designation/Address/Link/
    Company/Email, plus the raw input under 'extracted_text'.
    """
    labels = ["Person", "Mail", "Number", "Address", "Organization","Designation","Link"]

    # Perform entity prediction
    entities = model.predict_entities(text, labels, threshold=0.5)

    # Initialize the processed data dictionary
    processed_data = {
        "Name": [],
        "Contact": [],
        "Designation": [],
        "Address": [],
        "Link": [],
        "Company": [],
        "Email": [],
        "extracted_text": "",
    }

    for entity in entities:

        print(entity["text"], "=>", entity["label"])

        #loading the data into json
        if entity["label"]==labels[0]:
            processed_data['Name'].extend([entity["text"]])

        if entity["label"]==labels[1]:
            processed_data['Email'].extend([entity["text"]])

        if entity["label"]==labels[2]:
            processed_data['Contact'].extend([entity["text"]])

        if entity["label"]==labels[3]:
            processed_data['Address'].extend([entity["text"]])

        if entity["label"]==labels[4]:
            processed_data['Company'].extend([entity["text"]])

        if entity["label"]==labels[5]:
            processed_data['Designation'].extend([entity["text"]])

        if entity["label"]==labels[6]:
            processed_data['Link'].extend([entity["text"]])


    # Join address fragments into one string; wrap raw text in a list.
    processed_data['Address']=[', '.join(processed_data['Address'])]
    processed_data['extracted_text']=[text]

    return processed_data
 
1
from .model import GLiNER

# Initialize GLiNER with the base model
model = GLiNER.from_pretrained("urchade/gliner_mediumv2.1")

# Sample text for entity prediction
text = """
lenskart m: (0)9428002330 Lenskart Store,Surat m: (0)9723817060) e:lenskartsurat@gmail.com Store Address UG-4.Ascon City.Opp.Maheshwari Bhavan,Citylight,Surat-395007"""

def NER_Model(text):
    """Predict entities in *text* with GLiNER and group them into the
    card-field dictionary consumed by the processing pipeline.

    Returns a dict of lists keyed Name/Contact/Designation/Address/Link/
    Company/Email, plus the raw input wrapped in a list under 'extracted_text'.
    """
    labels = ["Person", "Mail", "Number", "Address", "Organization", "Designation", "Link"]

    # Perform entity prediction
    entities = model.predict_entities(text, labels, threshold=0.3)

    processed_data = {
        "Name": [],
        "Contact": [],
        "Designation": [],
        "Address": [],
        "Link": [],
        "Company": [],
        "Email": [],
        "extracted_text": "",
    }

    # GLiNER label -> output field it fills.
    label_to_field = {
        "Person": "Name",
        "Mail": "Email",
        "Number": "Contact",
        "Address": "Address",
        "Organization": "Company",
        "Designation": "Designation",
        "Link": "Link",
    }

    for entity in entities:
        print(entity["text"], "=>", entity["label"])
        field = label_to_field.get(entity["label"])
        if field is not None:
            processed_data[field].append(entity["text"])

    # Collapse address fragments into a single comma-joined string and
    # carry the raw text along as a one-element list.
    processed_data['Address'] = [', '.join(processed_data['Address'])]
    processed_data['extracted_text'] = [text]

    return processed_data
backup/model.py CHANGED
@@ -1,412 +1,412 @@
1
- import argparse
2
- import json
3
- from pathlib import Path
4
- import re
5
- from typing import Dict, Optional, Union
6
- import torch
7
- import torch.nn.functional as F
8
- from .modules.layers import LstmSeq2SeqEncoder
9
- from .modules.base import InstructBase
10
- from .modules.evaluator import Evaluator, greedy_search
11
- from .modules.span_rep import SpanRepLayer
12
- from .modules.token_rep import TokenRepLayer
13
- from torch import nn
14
- from torch.nn.utils.rnn import pad_sequence
15
- from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
16
- from huggingface_hub.utils import HfHubHTTPError
17
-
18
-
19
-
20
- class GLiNER(InstructBase, PyTorchModelHubMixin):
21
- def __init__(self, config):
22
- super().__init__(config)
23
-
24
- self.config = config
25
-
26
- # [ENT] token
27
- self.entity_token = "<<ENT>>"
28
- self.sep_token = "<<SEP>>"
29
-
30
- # usually a pretrained bidirectional transformer, returns first subtoken representation
31
- self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
32
- subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
33
- add_tokens=[self.entity_token, self.sep_token])
34
-
35
- # hierarchical representation of tokens
36
- self.rnn = LstmSeq2SeqEncoder(
37
- input_size=config.hidden_size,
38
- hidden_size=config.hidden_size // 2,
39
- num_layers=1,
40
- bidirectional=True,
41
- )
42
-
43
- # span representation
44
- self.span_rep_layer = SpanRepLayer(
45
- span_mode=config.span_mode,
46
- hidden_size=config.hidden_size,
47
- max_width=config.max_width,
48
- dropout=config.dropout,
49
- )
50
-
51
- # prompt representation (FFN)
52
- self.prompt_rep_layer = nn.Sequential(
53
- nn.Linear(config.hidden_size, config.hidden_size * 4),
54
- nn.Dropout(config.dropout),
55
- nn.ReLU(),
56
- nn.Linear(config.hidden_size * 4, config.hidden_size)
57
- )
58
-
59
- def compute_score_train(self, x):
60
- span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
61
-
62
- new_length = x['seq_length'].clone()
63
- new_tokens = []
64
- all_len_prompt = []
65
- num_classes_all = []
66
-
67
- # add prompt to the tokens
68
- for i in range(len(x['tokens'])):
69
- all_types_i = list(x['classes_to_id'][i].keys())
70
- # multiple entity types in all_types. Prompt is appended at the start of tokens
71
- entity_prompt = []
72
- num_classes_all.append(len(all_types_i))
73
- # add enity types to prompt
74
- for entity_type in all_types_i:
75
- entity_prompt.append(self.entity_token) # [ENT] token
76
- entity_prompt.append(entity_type) # entity type
77
- entity_prompt.append(self.sep_token) # [SEP] token
78
-
79
- # prompt format:
80
- # [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]
81
-
82
- # add prompt to the tokens
83
- tokens_p = entity_prompt + x['tokens'][i]
84
-
85
- # input format:
86
- # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n
87
-
88
- # update length of the sequence (add prompt length to the original length)
89
- new_length[i] = new_length[i] + len(entity_prompt)
90
- # update tokens
91
- new_tokens.append(tokens_p)
92
- # store prompt length
93
- all_len_prompt.append(len(entity_prompt))
94
-
95
- # create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise)
96
- max_num_classes = max(num_classes_all)
97
- entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
98
- x['span_mask'].device)
99
- entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
100
- x['span_mask'].device) # [batch_size, max_num_classes]
101
-
102
- # compute all token representations
103
- bert_output = self.token_rep_layer(new_tokens, new_length)
104
- word_rep_w_prompt = bert_output["embeddings"] # embeddings for all tokens (with prompt)
105
- mask_w_prompt = bert_output["mask"] # mask for all tokens (with prompt)
106
-
107
- # get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
108
- word_rep = [] # word representation (after [SEP])
109
- mask = [] # mask (after [SEP])
110
- entity_type_rep = [] # entity type representation (before [SEP])
111
- for i in range(len(x['tokens'])):
112
- prompt_entity_length = all_len_prompt[i] # length of prompt for this example
113
- # get word representation (after [SEP])
114
- word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
115
- # get mask (after [SEP])
116
- mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
117
-
118
- # get entity type representation (before [SEP])
119
- entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1] # remove [SEP]
120
- entity_rep = entity_rep[0::2] # it means that we take every second element starting from the second one
121
- entity_type_rep.append(entity_rep)
122
-
123
- # padding for word_rep, mask and entity_type_rep
124
- word_rep = pad_sequence(word_rep, batch_first=True) # [batch_size, seq_len, hidden_size]
125
- mask = pad_sequence(mask, batch_first=True) # [batch_size, seq_len]
126
- entity_type_rep = pad_sequence(entity_type_rep, batch_first=True) # [batch_size, len_types, hidden_size]
127
-
128
- # compute span representation
129
- word_rep = self.rnn(word_rep, mask)
130
- span_rep = self.span_rep_layer(word_rep, span_idx)
131
-
132
- # compute final entity type representation (FFN)
133
- entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
134
- num_classes = entity_type_rep.shape[1] # number of entity types
135
-
136
- # similarity score
137
- scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
138
-
139
- return scores, num_classes, entity_type_mask
140
-
141
- def forward(self, x):
142
- # compute span representation
143
- scores, num_classes, entity_type_mask = self.compute_score_train(x)
144
- batch_size = scores.shape[0]
145
-
146
- # loss for filtering classifier
147
- logits_label = scores.view(-1, num_classes)
148
- labels = x["span_label"].view(-1) # (batch_size * num_spans)
149
- mask_label = labels != -1 # (batch_size * num_spans)
150
- labels.masked_fill_(~mask_label, 0) # Set the labels of padding tokens to 0
151
-
152
- # one-hot encoding
153
- labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
154
- labels_one_hot.scatter_(1, labels.unsqueeze(1), 1) # Set the corresponding index to 1
155
- labels_one_hot = labels_one_hot[:, 1:] # Remove the first column
156
- # Shape of labels_one_hot: (batch_size * num_spans, num_classes)
157
-
158
- # compute loss (without reduction)
159
- all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
160
- reduction='none')
161
- # mask loss using entity_type_mask (B, C)
162
- masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
163
- all_losses = masked_loss.view(-1, num_classes)
164
- # expand mask_label to all_losses
165
- mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
166
- # put lower loss for in label_one_hot (2 for positive, 1 for negative)
167
- weight_c = labels_one_hot + 1
168
- # apply mask
169
- all_losses = all_losses * mask_label.float() * weight_c
170
- return all_losses.sum()
171
-
172
- def compute_score_eval(self, x, device):
173
- # check if classes_to_id is dict
174
- assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"
175
-
176
- span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)
177
-
178
- all_types = list(x['classes_to_id'].keys())
179
- # multiple entity types in all_types. Prompt is appended at the start of tokens
180
- entity_prompt = []
181
-
182
- # add enity types to prompt
183
- for entity_type in all_types:
184
- entity_prompt.append(self.entity_token)
185
- entity_prompt.append(entity_type)
186
-
187
- entity_prompt.append(self.sep_token)
188
-
189
- prompt_entity_length = len(entity_prompt)
190
-
191
- # add prompt
192
- tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
193
- seq_length_p = x['seq_length'] + prompt_entity_length
194
-
195
- out = self.token_rep_layer(tokens_p, seq_length_p)
196
-
197
- word_rep_w_prompt = out["embeddings"]
198
- mask_w_prompt = out["mask"]
199
-
200
- # remove prompt
201
- word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
202
- mask = mask_w_prompt[:, prompt_entity_length:]
203
-
204
- # get_entity_type_rep
205
- entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
206
- # extract [ENT] tokens (which are at even positions in entity_type_rep)
207
- entity_type_rep = entity_type_rep[:, 0::2, :]
208
-
209
- entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
210
-
211
- word_rep = self.rnn(word_rep, mask)
212
-
213
- span_rep = self.span_rep_layer(word_rep, span_idx)
214
-
215
- local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
216
-
217
- return local_scores
218
-
219
- @torch.no_grad()
220
- def predict(self, x, flat_ner=False, threshold=0.5):
221
- self.eval()
222
- local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
223
- spans = []
224
- for i, _ in enumerate(x["tokens"]):
225
- local_i = local_scores[i]
226
- wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)]
227
- span_i = []
228
- for s, k, c in zip(*wh_i):
229
- if s + k < len(x["tokens"][i]):
230
- span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
231
- span_i = greedy_search(span_i, flat_ner)
232
- spans.append(span_i)
233
- return spans
234
-
235
- def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
236
- tokens = []
237
- start_token_idx_to_text_idx = []
238
- end_token_idx_to_text_idx = []
239
- for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
240
- tokens.append(match.group())
241
- start_token_idx_to_text_idx.append(match.start())
242
- end_token_idx_to_text_idx.append(match.end())
243
-
244
- input_x = {"tokenized_text": tokens, "ner": None}
245
- x = self.collate_fn([input_x], labels)
246
- output = self.predict(x, flat_ner=flat_ner, threshold=threshold)
247
-
248
- entities = []
249
- for start_token_idx, end_token_idx, ent_type in output[0]:
250
- start_text_idx = start_token_idx_to_text_idx[start_token_idx]
251
- end_text_idx = end_token_idx_to_text_idx[end_token_idx]
252
- entities.append({
253
- "start": start_token_idx_to_text_idx[start_token_idx],
254
- "end": end_token_idx_to_text_idx[end_token_idx],
255
- "text": text[start_text_idx:end_text_idx],
256
- "label": ent_type,
257
- })
258
- return entities
259
-
260
- def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
261
- self.eval()
262
- data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
263
- device = next(self.parameters()).device
264
- all_preds = []
265
- all_trues = []
266
- for x in data_loader:
267
- for k, v in x.items():
268
- if isinstance(v, torch.Tensor):
269
- x[k] = v.to(device)
270
- batch_predictions = self.predict(x, flat_ner, threshold)
271
- all_preds.extend(batch_predictions)
272
- all_trues.extend(x["entities"])
273
- evaluator = Evaluator(all_trues, all_preds)
274
- out, f1 = evaluator.evaluate()
275
- return out, f1
276
-
277
- @classmethod
278
- def _from_pretrained(
279
- cls,
280
- *,
281
- model_id: str,
282
- revision: Optional[str],
283
- cache_dir: Optional[Union[str, Path]],
284
- force_download: bool,
285
- proxies: Optional[Dict],
286
- resume_download: bool,
287
- local_files_only: bool,
288
- token: Union[str, bool, None],
289
- map_location: str = "cpu",
290
- strict: bool = False,
291
- **model_kwargs,
292
- ):
293
- # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
294
- filenames = ["gliner_base.pt", "gliner_multi.pt"]
295
- for filename in filenames:
296
- model_file = Path(model_id) / filename
297
- if not model_file.exists():
298
- try:
299
- model_file = hf_hub_download(
300
- repo_id=model_id,
301
- filename=filename,
302
- revision=revision,
303
- cache_dir=cache_dir,
304
- force_download=force_download,
305
- proxies=proxies,
306
- resume_download=resume_download,
307
- token=token,
308
- local_files_only=local_files_only,
309
- )
310
- except HfHubHTTPError:
311
- continue
312
- dict_load = torch.load(model_file, map_location=torch.device(map_location))
313
- config = dict_load["config"]
314
- state_dict = dict_load["model_weights"]
315
- config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
316
- model = cls(config)
317
- model.load_state_dict(state_dict, strict=strict, assign=True)
318
- # Required to update flair's internals as well:
319
- model.to(map_location)
320
- return model
321
-
322
- # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
323
- from .train import load_config_as_namespace
324
-
325
- model_file = Path(model_id) / "pytorch_model.bin"
326
- if not model_file.exists():
327
- model_file = hf_hub_download(
328
- repo_id=model_id,
329
- filename="pytorch_model.bin",
330
- revision=revision,
331
- cache_dir=cache_dir,
332
- force_download=force_download,
333
- proxies=proxies,
334
- resume_download=resume_download,
335
- token=token,
336
- local_files_only=local_files_only,
337
- )
338
- config_file = Path(model_id) / "gliner_config.json"
339
- if not config_file.exists():
340
- config_file = hf_hub_download(
341
- repo_id=model_id,
342
- filename="gliner_config.json",
343
- revision=revision,
344
- cache_dir=cache_dir,
345
- force_download=force_download,
346
- proxies=proxies,
347
- resume_download=resume_download,
348
- token=token,
349
- local_files_only=local_files_only,
350
- )
351
- config = load_config_as_namespace(config_file)
352
- model = cls(config)
353
- state_dict = torch.load(model_file, map_location=torch.device(map_location))
354
- model.load_state_dict(state_dict, strict=strict, assign=True)
355
- model.to(map_location)
356
- return model
357
-
358
    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
                Defaults to ``self.config`` when not given.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

        Returns:
            The URL returned by ``push_to_hub`` when pushing, otherwise ``None``.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save model weights/files
        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        # save config (if provided); fall back to the model's own config
        if config is None:
            config = self.config
        if config is not None:
            # argparse.Namespace is not JSON-serializable; convert to a plain dict first
            if isinstance(config, argparse.Namespace):
                config = vars(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None
406
-
407
    def to(self, device):
        """Move the model to *device* and keep flair's module-level device in sync.

        flair reads a global ``flair.device`` when building embeddings, so moving
        only the torch parameters would leave flair on the old device.
        """
        super().to(device)
        import flair

        flair.device = device
        return self
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import re
5
+ from typing import Dict, Optional, Union
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from .modules.layers import LstmSeq2SeqEncoder
9
+ from .modules.base import InstructBase
10
+ from .modules.evaluator import Evaluator, greedy_search
11
+ from .modules.span_rep import SpanRepLayer
12
+ from .modules.token_rep import TokenRepLayer
13
+ from torch import nn
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
16
+ from huggingface_hub.utils import HfHubHTTPError
17
+
18
+
19
+
20
class GLiNER(InstructBase, PyTorchModelHubMixin):
    """Open-type (zero-shot) span-based NER model.

    Entity-type names are injected into the input as a textual prompt
    (``<<ENT>> type_1 <<ENT>> type_2 ... <<SEP>> token_1 token_2 ...``)
    and every candidate span is scored against each type's prompt embedding
    via an einsum dot product.
    """

    def __init__(self, config):
        super().__init__(config)

        self.config = config

        # special prompt tokens: one [ENT] marker precedes each entity-type name,
        # a single [SEP] separates the prompt from the actual text
        self.entity_token = "<<ENT>>"
        self.sep_token = "<<SEP>>"

        # usually a pretrained bidirectional transformer, returns first subtoken representation
        self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
                                             subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
                                             add_tokens=[self.entity_token, self.sep_token])

        # hierarchical representation of tokens
        self.rnn = LstmSeq2SeqEncoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=1,
            bidirectional=True,
        )

        # span representation
        self.span_rep_layer = SpanRepLayer(
            span_mode=config.span_mode,
            hidden_size=config.hidden_size,
            max_width=config.max_width,
            dropout=config.dropout,
        )

        # prompt representation (FFN)
        self.prompt_rep_layer = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size)
        )

    def compute_score_train(self, x):
        """Score every candidate span against every entity type for a training batch.

        Each example carries its own label set (``x['classes_to_id'][i]``), so the
        prompt is built per example and prompt lengths differ across the batch.

        Returns:
            scores: tensor shaped [batch, seq_len, max_width, num_classes] (einsum below)
            num_classes: padded number of entity types in the batch
            entity_type_mask: [batch, num_classes] bool, False for padded type slots
        """
        span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)

        new_length = x['seq_length'].clone()
        new_tokens = []
        all_len_prompt = []
        num_classes_all = []

        # add prompt to the tokens
        for i in range(len(x['tokens'])):
            all_types_i = list(x['classes_to_id'][i].keys())
            # multiple entity types in all_types. Prompt is appended at the start of tokens
            entity_prompt = []
            num_classes_all.append(len(all_types_i))
            # add entity types to prompt
            for entity_type in all_types_i:
                entity_prompt.append(self.entity_token)  # [ENT] token
                entity_prompt.append(entity_type)  # entity type
            entity_prompt.append(self.sep_token)  # [SEP] token

            # prompt format:
            # [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]

            # add prompt to the tokens
            tokens_p = entity_prompt + x['tokens'][i]

            # input format:
            # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n

            # update length of the sequence (add prompt length to the original length)
            new_length[i] = new_length[i] + len(entity_prompt)
            # update tokens
            new_tokens.append(tokens_p)
            # store prompt length
            all_len_prompt.append(len(entity_prompt))

        # create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise)
        max_num_classes = max(num_classes_all)
        entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
            x['span_mask'].device)
        entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
            x['span_mask'].device)  # [batch_size, max_num_classes]

        # compute all token representations
        bert_output = self.token_rep_layer(new_tokens, new_length)
        word_rep_w_prompt = bert_output["embeddings"]  # embeddings for all tokens (with prompt)
        mask_w_prompt = bert_output["mask"]  # mask for all tokens (with prompt)

        # get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
        word_rep = []  # word representation (after [SEP])
        mask = []  # mask (after [SEP])
        entity_type_rep = []  # entity type representation (before [SEP])
        for i in range(len(x['tokens'])):
            prompt_entity_length = all_len_prompt[i]  # length of prompt for this example
            # get word representation (after [SEP])
            word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
            # get mask (after [SEP])
            mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])

            # get entity type representation (before [SEP])
            entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1]  # remove [SEP]
            # [ENT] markers sit at even positions (0, 2, 4, ...); keep only those
            entity_rep = entity_rep[0::2]
            entity_type_rep.append(entity_rep)

        # padding for word_rep, mask and entity_type_rep
        word_rep = pad_sequence(word_rep, batch_first=True)  # [batch_size, seq_len, hidden_size]
        mask = pad_sequence(mask, batch_first=True)  # [batch_size, seq_len]
        entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)  # [batch_size, len_types, hidden_size]

        # compute span representation
        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        # compute final entity type representation (FFN)
        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)
        num_classes = entity_type_rep.shape[1]  # number of entity types

        # similarity score
        scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return scores, num_classes, entity_type_mask

    def forward(self, x):
        """Compute the (summed) weighted multi-label BCE loss for a training batch."""
        # compute span representation
        scores, num_classes, entity_type_mask = self.compute_score_train(x)
        batch_size = scores.shape[0]

        # loss for filtering classifier
        logits_label = scores.view(-1, num_classes)
        labels = x["span_label"].view(-1)  # (batch_size * num_spans)
        mask_label = labels != -1  # (batch_size * num_spans)
        labels.masked_fill_(~mask_label, 0)  # Set the labels of padding tokens to 0

        # one-hot encoding (index 0 = null class, dropped below)
        labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)  # Set the corresponding index to 1
        labels_one_hot = labels_one_hot[:, 1:]  # Remove the first column
        # Shape of labels_one_hot: (batch_size * num_spans, num_classes)

        # compute loss (without reduction)
        all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
                                                        reduction='none')
        # mask loss using entity_type_mask (B, C) so padded type slots contribute nothing
        masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
        all_losses = masked_loss.view(-1, num_classes)
        # expand mask_label to all_losses
        mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
        # weight positive cells twice as much as negative ones (2 for positive, 1 for negative)
        weight_c = labels_one_hot + 1
        # apply mask
        all_losses = all_losses * mask_label.float() * weight_c
        return all_losses.sum()

    def compute_score_eval(self, x, device):
        """Score spans for inference; the whole batch shares one fixed label set.

        Unlike training, ``x['classes_to_id']`` is a single dict, so one prompt is
        prepended to every example and the prompt length is uniform.
        """
        # check if classes_to_id is dict
        assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

        span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

        all_types = list(x['classes_to_id'].keys())
        # multiple entity types in all_types. Prompt is appended at the start of tokens
        entity_prompt = []

        # add entity types to prompt
        for entity_type in all_types:
            entity_prompt.append(self.entity_token)
            entity_prompt.append(entity_type)

        entity_prompt.append(self.sep_token)

        prompt_entity_length = len(entity_prompt)

        # add prompt
        tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
        seq_length_p = x['seq_length'] + prompt_entity_length

        out = self.token_rep_layer(tokens_p, seq_length_p)

        word_rep_w_prompt = out["embeddings"]
        mask_w_prompt = out["mask"]

        # remove prompt
        word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
        mask = mask_w_prompt[:, prompt_entity_length:]

        # get_entity_type_rep
        entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
        # extract [ENT] tokens (which are at even positions in entity_type_rep)
        entity_type_rep = entity_type_rep[:, 0::2, :]

        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)

        word_rep = self.rnn(word_rep, mask)

        span_rep = self.span_rep_layer(word_rep, span_idx)

        local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return local_scores

    @torch.no_grad()
    def predict(self, x, flat_ner=False, threshold=0.5):
        """Decode entity spans for a collated batch.

        Spans whose sigmoid score exceeds *threshold* are collected and then
        filtered by ``greedy_search`` (non-overlapping if ``flat_ner``).
        Returns one list of spans per example.
        """
        self.eval()
        local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
        spans = []
        for i, _ in enumerate(x["tokens"]):
            local_i = local_scores[i]
            # NOTE: the comprehension variable shadows the outer `i`; it is not used afterwards
            wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)]
            span_i = []
            for s, k, c in zip(*wh_i):
                # drop spans that extend beyond the example's real token length
                if s + k < len(x["tokens"][i]):
                    span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
            span_i = greedy_search(span_i, flat_ner)
            spans.append(span_i)
        return spans

    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        """Run NER on a raw string and return character-offset entity dicts.

        Each returned dict has ``start``/``end`` character offsets into *text*,
        the matched ``text`` substring, and the predicted ``label``.
        """
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        # word-level tokenization keeping hyphen/underscore compounds as one token
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())

        input_x = {"tokenized_text": tokens, "ner": None}
        x = self.collate_fn([input_x], labels)
        output = self.predict(x, flat_ner=flat_ner, threshold=threshold)

        entities = []
        # NOTE(review): assumes greedy_search yields (start, end, label) triples
        # (the score emitted by `predict` has been dropped) — confirm in evaluator.py
        for start_token_idx, end_token_idx, ent_type in output[0]:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_token_idx_to_text_idx[start_token_idx],
                "end": end_token_idx_to_text_idx[end_token_idx],
                "text": text[start_text_idx:end_text_idx],
                "label": ent_type,
            })
        return entities

    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        """Evaluate on *test_data* and return the Evaluator's (report, f1)."""
        self.eval()
        data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
        device = next(self.parameters()).device
        all_preds = []
        all_trues = []
        for x in data_loader:
            # move only the tensor fields of the collated batch to the model device
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            batch_predictions = self.predict(x, flat_ner, threshold)
            all_preds.extend(batch_predictions)
            all_trues.extend(x["entities"])
        evaluator = Evaluator(all_trues, all_preds)
        out, f1 = evaluator.evaluate()
        return out, f1

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load a checkpoint from a local path or the Hugging Face Hub.

        Tries the legacy single-file checkpoints ("gliner_base.pt" /
        "gliner_multi.pt", which bundle config + weights) first, then falls
        back to the newer "pytorch_model.bin" + "gliner_config.json" layout.
        """
        # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
        filenames = ["gliner_base.pt", "gliner_multi.pt"]
        for filename in filenames:
            model_file = Path(model_id) / filename
            if not model_file.exists():
                try:
                    model_file = hf_hub_download(
                        repo_id=model_id,
                        filename=filename,
                        revision=revision,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        resume_download=resume_download,
                        token=token,
                        local_files_only=local_files_only,
                    )
                except HfHubHTTPError:
                    # this legacy filename is not in the repo; try the next one
                    continue
            dict_load = torch.load(model_file, map_location=torch.device(map_location))
            config = dict_load["config"]
            state_dict = dict_load["model_weights"]
            # legacy checkpoints did not store the backbone name; infer it from the filename
            config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
            model = cls(config)
            model.load_state_dict(state_dict, strict=strict, assign=True)
            # Required to update flair's internals as well:
            model.to(map_location)
            return model

        # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
        from .train import load_config_as_namespace

        model_file = Path(model_id) / "pytorch_model.bin"
        if not model_file.exists():
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="pytorch_model.bin",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config_file = Path(model_id) / "gliner_config.json"
        if not config_file.exists():
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="gliner_config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config = load_config_as_namespace(config_file)
        model = cls(config)
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict, assign=True)
        model.to(map_location)
        return model

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
                Defaults to ``self.config`` when not given.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save model weights/files
        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        # save config (if provided); fall back to the model's own config
        if config is None:
            config = self.config
        if config is not None:
            # argparse.Namespace is not JSON-serializable; convert to a plain dict first
            if isinstance(config, argparse.Namespace):
                config = vars(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None

    def to(self, device):
        """Move the model to *device* and keep flair's module-level device in sync."""
        super().to(device)
        import flair

        flair.device = device
        return self
backup/modules/__pycache__/base.cpython-310.pyc ADDED
Binary file (5.06 kB). View file
 
backup/modules/__pycache__/evaluator.cpython-310.pyc ADDED
Binary file (4.44 kB). View file
 
backup/modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (1.23 kB). View file
 
backup/modules/__pycache__/run_evaluation.cpython-310.pyc ADDED
Binary file (4.31 kB). View file
 
backup/modules/__pycache__/span_rep.cpython-310.pyc ADDED
Binary file (9.62 kB). View file
 
backup/modules/__pycache__/token_rep.cpython-310.pyc ADDED
Binary file (2.4 kB). View file
 
backup/modules/base.py CHANGED
@@ -1,150 +1,150 @@
1
- from collections import defaultdict
2
- from typing import List, Tuple, Dict
3
-
4
- import torch
5
- from torch import nn
6
- from torch.nn.utils.rnn import pad_sequence
7
- from torch.utils.data import DataLoader
8
- import random
9
-
10
-
11
class InstructBase(nn.Module):
    """Shared data-preparation base for the span-based NER model.

    Handles span enumeration up to ``max_width``, label encoding,
    negative entity-type sampling, and batch collation.
    """

    def __init__(self, config):
        super().__init__()
        # maximum span width (in tokens) to enumerate per start position
        self.max_width = config.max_width
        self.base_config = config

    def get_dict(self, spans, classes_to_id):
        """Map (start, end) boundaries to class ids; types not in *classes_to_id* are dropped.

        The defaultdict returns 0 (the null class) for any span not in the gold set.
        """
        dict_tag = defaultdict(int)
        for span in spans:
            if span[2] in classes_to_id:
                dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
        return dict_tag

    def preprocess_spans(self, tokens, ner, classes_to_id):
        """Enumerate all candidate spans for one example and attach integer labels."""

        max_len = self.base_config.max_len

        # truncate overly long examples to max_len tokens
        if len(tokens) > max_len:
            length = max_len
            tokens = tokens[:max_len]
        else:
            length = len(tokens)

        spans_idx = []
        for i in range(length):
            spans_idx.extend([(i, i + j) for j in range(self.max_width)])

        dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)

        # 0 for null labels
        span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
        spans_idx = torch.LongTensor(spans_idx)

        # mask for spans whose end index runs past the (possibly truncated) sequence
        valid_span_mask = spans_idx[:, 1] > length - 1

        # mark invalid positions with -1 so they are excluded downstream
        span_label = span_label.masked_fill(valid_span_mask, -1)

        return {
            'tokens': tokens,
            'span_idx': spans_idx,
            'span_label': span_label,
            'seq_length': length,
            'entities': ner,
        }

    def collate_fn(self, batch_list, entity_types=None):
        """Collate raw examples into one padded batch.

        When *entity_types* is None (training), a per-example label set is built
        by mixing gold types with randomly sampled negative types; otherwise the
        given fixed label set is shared by the whole batch (inference/eval).
        """
        # batch_list: list of dict containing tokens, ner
        if entity_types is None:
            negs = self.get_negatives(batch_list, 100)
            class_to_ids = []
            id_to_classes = []
            for b in batch_list:
                # negs = b["negative"]
                random.shuffle(negs)

                # negs = negs[:sampled_neg]
                max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)

                if max_neg_type_ratio == 0:
                    # no negatives
                    neg_type_ratio = 0
                else:
                    neg_type_ratio = random.randint(0, max_neg_type_ratio)

                if neg_type_ratio == 0:
                    # no negatives
                    negs_i = []
                else:
                    negs_i = negs[:len(b['ner']) * neg_type_ratio]

                # this is the list of all possible entity types (positive and negative)
                types = list(set([el[-1] for el in b['ner']] + negs_i))

                # shuffle (every epoch)
                random.shuffle(types)

                if len(types) != 0:
                    # randomly drop types, keeping a random-size prefix (augmentation)
                    if self.base_config.random_drop:
                        num_ents = random.randint(1, len(types))
                        types = types[:num_ents]

                # maximum number of entities types
                types = types[:int(self.base_config.max_types)]

                # supervised training: a fixed label set overrides the sampled one
                if "label" in b:
                    types = sorted(b["label"])

                # ids start at 1; 0 is reserved for the null class
                class_to_id = {k: v for v, k in enumerate(types, start=1)}
                id_to_class = {k: v for v, k in class_to_id.items()}
                class_to_ids.append(class_to_id)
                id_to_classes.append(id_to_class)

            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i]) for i, b in enumerate(batch_list)
            ]

        else:
            class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
            id_to_classes = {k: v for v, k in class_to_ids.items()}
            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids) for b in batch_list
            ]

        span_idx = pad_sequence(
            [b['span_idx'] for b in batch], batch_first=True, padding_value=0
        )

        span_label = pad_sequence(
            [el['span_label'] for el in batch], batch_first=True, padding_value=-1
        )

        return {
            'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
            'span_idx': span_idx,
            'tokens': [el['tokens'] for el in batch],
            'span_mask': span_label != -1,
            'span_label': span_label,
            'entities': [el['entities'] for el in batch],
            'classes_to_id': class_to_ids,
            'id_to_classes': id_to_classes,
        }

    @staticmethod
    def get_negatives(batch_list, sampled_neg=5):
        """Collect up to *sampled_neg* distinct entity types from the whole batch,
        used as the pool of candidate negative types."""
        ent_types = []
        for b in batch_list:
            types = set([el[-1] for el in b['ner']])
            ent_types.extend(list(types))
        ent_types = list(set(ent_types))
        # sample negatives
        random.shuffle(ent_types)
        return ent_types[:sampled_neg]

    def create_dataloader(self, data, entity_types=None, **kwargs):
        """DataLoader whose collate function applies this model's preprocessing."""
        return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)
 
1
+ from collections import defaultdict
2
+ from typing import List, Tuple, Dict
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn.utils.rnn import pad_sequence
7
+ from torch.utils.data import DataLoader
8
+ import random
9
+
10
+
11
class InstructBase(nn.Module):
    """Shared data-preparation base for the span-based NER model.

    Handles span enumeration up to ``max_width``, label encoding,
    negative entity-type sampling, and batch collation.
    """

    def __init__(self, config):
        super().__init__()
        # maximum span width (in tokens) to enumerate per start position
        self.max_width = config.max_width
        self.base_config = config

    def get_dict(self, spans, classes_to_id):
        """Map (start, end) boundaries to class ids; types not in *classes_to_id* are dropped.

        The defaultdict returns 0 (the null class) for any span not in the gold set.
        """
        dict_tag = defaultdict(int)
        for span in spans:
            if span[2] in classes_to_id:
                dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
        return dict_tag

    def preprocess_spans(self, tokens, ner, classes_to_id):
        """Enumerate all candidate spans for one example and attach integer labels."""

        max_len = self.base_config.max_len

        # truncate overly long examples to max_len tokens
        if len(tokens) > max_len:
            length = max_len
            tokens = tokens[:max_len]
        else:
            length = len(tokens)

        spans_idx = []
        for i in range(length):
            spans_idx.extend([(i, i + j) for j in range(self.max_width)])

        dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)

        # 0 for null labels
        span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
        spans_idx = torch.LongTensor(spans_idx)

        # mask for spans whose end index runs past the (possibly truncated) sequence
        valid_span_mask = spans_idx[:, 1] > length - 1

        # mark invalid positions with -1 so they are excluded downstream
        span_label = span_label.masked_fill(valid_span_mask, -1)

        return {
            'tokens': tokens,
            'span_idx': spans_idx,
            'span_label': span_label,
            'seq_length': length,
            'entities': ner,
        }

    def collate_fn(self, batch_list, entity_types=None):
        """Collate raw examples into one padded batch.

        When *entity_types* is None (training), a per-example label set is built
        by mixing gold types with randomly sampled negative types; otherwise the
        given fixed label set is shared by the whole batch (inference/eval).
        """
        # batch_list: list of dict containing tokens, ner
        if entity_types is None:
            negs = self.get_negatives(batch_list, 100)
            class_to_ids = []
            id_to_classes = []
            for b in batch_list:
                # negs = b["negative"]
                random.shuffle(negs)

                # negs = negs[:sampled_neg]
                max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)

                if max_neg_type_ratio == 0:
                    # no negatives
                    neg_type_ratio = 0
                else:
                    neg_type_ratio = random.randint(0, max_neg_type_ratio)

                if neg_type_ratio == 0:
                    # no negatives
                    negs_i = []
                else:
                    negs_i = negs[:len(b['ner']) * neg_type_ratio]

                # this is the list of all possible entity types (positive and negative)
                types = list(set([el[-1] for el in b['ner']] + negs_i))

                # shuffle (every epoch)
                random.shuffle(types)

                if len(types) != 0:
                    # randomly drop types, keeping a random-size prefix (augmentation)
                    if self.base_config.random_drop:
                        num_ents = random.randint(1, len(types))
                        types = types[:num_ents]

                # maximum number of entities types
                types = types[:int(self.base_config.max_types)]

                # supervised training: a fixed label set overrides the sampled one
                if "label" in b:
                    types = sorted(b["label"])

                # ids start at 1; 0 is reserved for the null class
                class_to_id = {k: v for v, k in enumerate(types, start=1)}
                id_to_class = {k: v for v, k in class_to_id.items()}
                class_to_ids.append(class_to_id)
                id_to_classes.append(id_to_class)

            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i]) for i, b in enumerate(batch_list)
            ]

        else:
            class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
            id_to_classes = {k: v for v, k in class_to_ids.items()}
            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids) for b in batch_list
            ]

        span_idx = pad_sequence(
            [b['span_idx'] for b in batch], batch_first=True, padding_value=0
        )

        span_label = pad_sequence(
            [el['span_label'] for el in batch], batch_first=True, padding_value=-1
        )

        return {
            'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
            'span_idx': span_idx,
            'tokens': [el['tokens'] for el in batch],
            'span_mask': span_label != -1,
            'span_label': span_label,
            'entities': [el['entities'] for el in batch],
            'classes_to_id': class_to_ids,
            'id_to_classes': id_to_classes,
        }

    @staticmethod
    def get_negatives(batch_list, sampled_neg=5):
        """Collect up to *sampled_neg* distinct entity types from the whole batch,
        used as the pool of candidate negative types."""
        ent_types = []
        for b in batch_list:
            types = set([el[-1] for el in b['ner']])
            ent_types.extend(list(types))
        ent_types = list(set(ent_types))
        # sample negatives
        random.shuffle(ent_types)
        return ent_types[:sampled_neg]

    def create_dataloader(self, data, entity_types=None, **kwargs):
        """DataLoader whose collate function applies this model's preprocessing."""
        return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)
backup/modules/data_proc.py CHANGED
@@ -1,73 +1,73 @@
1
# Convert a conversational NER dataset ("What describes X in the text?" /
# assistant answers as list literals) into token-level span annotations.
import json
from tqdm import tqdm
# ast.literal_eval is used to parse the assistant's list-literal answers
import ast, re

path = 'train.json'

with open(path, 'r') as f:
    data = json.load(f)

def tokenize_text(text):
    """Split text into word tokens (keeping hyphen/underscore compounds) and single punctuation marks."""
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)

def extract_entity_spans(entry):
    """Extract (start_token, end_token, type) spans from one conversation entry.

    Returns (entity_spans, tokenized_text). Spans are found by case-insensitive
    sliding-window matching of the mention tokens against the tokenized text.
    """
    text = ""
    len_start = len("What describes ")
    len_end = len(" in the text?")
    entity_types = []
    entity_texts = []

    for c in entry['conversations']:
        if c['from'] == 'human' and c['value'].startswith('Text: '):
            text = c['value'][len('Text: '):]
            tokenized_text = tokenize_text(text)

        if c['from'] == 'human' and c['value'].startswith('What describes '):
            # pull the entity-type name out of the question template
            c_type = c['value'][len_start:-len_end]
            c_type = c_type.replace(' ', '_')
            entity_types.append(c_type)

        elif c['from'] == 'gpt' and c['value'].startswith('['):
            if c['value'] == '[]':
                # no mentions of this type: drop the pending type
                entity_types = entity_types[:-1]
                continue

            texts_ents = ast.literal_eval(c['value'])
            entity_texts.extend(texts_ents)
            # repeat the current type so each extracted mention has a type
            num_repeat = len(texts_ents) - 1
            entity_types.extend([entity_types[-1]] * num_repeat)

    entity_spans = []
    for j, entity_text in enumerate(entity_texts):
        entity_tokens = tokenize_text(entity_text)
        matches = []
        for i in range(len(tokenized_text) - len(entity_tokens) + 1):
            # case-insensitive exact match of the mention token window
            if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
                matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
        if matches:
            entity_spans.extend(matches)

    return entity_spans, tokenized_text

# Usage:
# Replace 'entry' with the specific entry from your JSON data
entry = data[17818]  # sanity-check a single entry before processing everything
entity_spans, tokenized_text = extract_entity_spans(entry)
print("Entity Spans:", entity_spans)
#print("Tokenized Text:", tokenized_text)

# create a dict: {"tokenized_text": tokenized_text, "entity_spans": entity_spans}

all_data = []

for entry in tqdm(data):
    entity_spans, tokenized_text = extract_entity_spans(entry)
    all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})


with open('train_instruct.json', 'w') as f:
    json.dump(all_data, f)
 
1
# Convert a conversational NER dataset ("What describes X in the text?" /
# assistant answers as list literals) into token-level span annotations.
import json
from tqdm import tqdm
# ast.literal_eval is used to parse the assistant's list-literal answers
import ast, re

path = 'train.json'

with open(path, 'r') as f:
    data = json.load(f)

def tokenize_text(text):
    """Split text into word tokens (keeping hyphen/underscore compounds) and single punctuation marks."""
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)

def extract_entity_spans(entry):
    """Extract (start_token, end_token, type) spans from one conversation entry.

    Returns (entity_spans, tokenized_text). Spans are found by case-insensitive
    sliding-window matching of the mention tokens against the tokenized text.
    """
    text = ""
    len_start = len("What describes ")
    len_end = len(" in the text?")
    entity_types = []
    entity_texts = []

    for c in entry['conversations']:
        if c['from'] == 'human' and c['value'].startswith('Text: '):
            text = c['value'][len('Text: '):]
            tokenized_text = tokenize_text(text)

        if c['from'] == 'human' and c['value'].startswith('What describes '):
            # pull the entity-type name out of the question template
            c_type = c['value'][len_start:-len_end]
            c_type = c_type.replace(' ', '_')
            entity_types.append(c_type)

        elif c['from'] == 'gpt' and c['value'].startswith('['):
            if c['value'] == '[]':
                # no mentions of this type: drop the pending type
                entity_types = entity_types[:-1]
                continue

            texts_ents = ast.literal_eval(c['value'])
            entity_texts.extend(texts_ents)
            # repeat the current type so each extracted mention has a type
            num_repeat = len(texts_ents) - 1
            entity_types.extend([entity_types[-1]] * num_repeat)

    entity_spans = []
    for j, entity_text in enumerate(entity_texts):
        entity_tokens = tokenize_text(entity_text)
        matches = []
        for i in range(len(tokenized_text) - len(entity_tokens) + 1):
            # case-insensitive exact match of the mention token window
            if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
                matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
        if matches:
            entity_spans.extend(matches)

    return entity_spans, tokenized_text

# Usage:
# Replace 'entry' with the specific entry from your JSON data
entry = data[17818]  # sanity-check a single entry before processing everything
entity_spans, tokenized_text = extract_entity_spans(entry)
print("Entity Spans:", entity_spans)
#print("Tokenized Text:", tokenized_text)

# create a dict: {"tokenized_text": tokenized_text, "entity_spans": entity_spans}

all_data = []

for entry in tqdm(data):
    entity_spans, tokenized_text = extract_entity_spans(entry)
    all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})


with open('train_instruct.json', 'w') as f:
    json.dump(all_data, f)
backup/modules/evaluator.py CHANGED
@@ -1,152 +1,152 @@
1
- from collections import defaultdict
2
-
3
- import numpy as np
4
- import torch
5
- from seqeval.metrics.v1 import _prf_divide
6
-
7
-
8
- def extract_tp_actual_correct(y_true, y_pred):
9
- entities_true = defaultdict(set)
10
- entities_pred = defaultdict(set)
11
-
12
- for type_name, (start, end), idx in y_true:
13
- entities_true[type_name].add((start, end, idx))
14
- for type_name, (start, end), idx in y_pred:
15
- entities_pred[type_name].add((start, end, idx))
16
-
17
- target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
18
-
19
- tp_sum = np.array([], dtype=np.int32)
20
- pred_sum = np.array([], dtype=np.int32)
21
- true_sum = np.array([], dtype=np.int32)
22
- for type_name in target_names:
23
- entities_true_type = entities_true.get(type_name, set())
24
- entities_pred_type = entities_pred.get(type_name, set())
25
- tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
26
- pred_sum = np.append(pred_sum, len(entities_pred_type))
27
- true_sum = np.append(true_sum, len(entities_true_type))
28
-
29
- return pred_sum, tp_sum, true_sum, target_names
30
-
31
-
32
- def flatten_for_eval(y_true, y_pred):
33
- all_true = []
34
- all_pred = []
35
-
36
- for i, (true, pred) in enumerate(zip(y_true, y_pred)):
37
- all_true.extend([t + [i] for t in true])
38
- all_pred.extend([p + [i] for p in pred])
39
-
40
- return all_true, all_pred
41
-
42
-
43
- def compute_prf(y_true, y_pred, average='micro'):
44
- y_true, y_pred = flatten_for_eval(y_true, y_pred)
45
-
46
- pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)
47
-
48
- if average == 'micro':
49
- tp_sum = np.array([tp_sum.sum()])
50
- pred_sum = np.array([pred_sum.sum()])
51
- true_sum = np.array([true_sum.sum()])
52
-
53
- precision = _prf_divide(
54
- numerator=tp_sum,
55
- denominator=pred_sum,
56
- metric='precision',
57
- modifier='predicted',
58
- average=average,
59
- warn_for=('precision', 'recall', 'f-score'),
60
- zero_division='warn'
61
- )
62
-
63
- recall = _prf_divide(
64
- numerator=tp_sum,
65
- denominator=true_sum,
66
- metric='recall',
67
- modifier='true',
68
- average=average,
69
- warn_for=('precision', 'recall', 'f-score'),
70
- zero_division='warn'
71
- )
72
-
73
- denominator = precision + recall
74
- denominator[denominator == 0.] = 1
75
- f_score = 2 * (precision * recall) / denominator
76
-
77
- return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}
78
-
79
-
80
- class Evaluator:
81
- def __init__(self, all_true, all_outs):
82
- self.all_true = all_true
83
- self.all_outs = all_outs
84
-
85
- def get_entities_fr(self, ents):
86
- all_ents = []
87
- for s, e, lab in ents:
88
- all_ents.append([lab, (s, e)])
89
- return all_ents
90
-
91
- def transform_data(self):
92
- all_true_ent = []
93
- all_outs_ent = []
94
- for i, j in zip(self.all_true, self.all_outs):
95
- e = self.get_entities_fr(i)
96
- all_true_ent.append(e)
97
- e = self.get_entities_fr(j)
98
- all_outs_ent.append(e)
99
- return all_true_ent, all_outs_ent
100
-
101
- @torch.no_grad()
102
- def evaluate(self):
103
- all_true_typed, all_outs_typed = self.transform_data()
104
- precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values()
105
- output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
106
- return output_str, f1
107
-
108
-
109
- def is_nested(idx1, idx2):
110
- # Return True if idx2 is nested inside idx1 or vice versa
111
- return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])
112
-
113
-
114
- def has_overlapping(idx1, idx2):
115
- overlapping = True
116
- if idx1[:2] == idx2[:2]:
117
- return overlapping
118
- if (idx1[0] > idx2[1] or idx2[0] > idx1[1]):
119
- overlapping = False
120
- return overlapping
121
-
122
-
123
- def has_overlapping_nested(idx1, idx2):
124
- # Return True if idx1 and idx2 overlap, but neither is nested inside the other
125
- if idx1[:2] == idx2[:2]:
126
- return True
127
- if ((idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2)) and idx1 != idx2:
128
- return False
129
- else:
130
- return True
131
-
132
-
133
- def greedy_search(spans, flat_ner=True): # start, end, class, score
134
-
135
- if flat_ner:
136
- has_ov = has_overlapping
137
- else:
138
- has_ov = has_overlapping_nested
139
-
140
- new_list = []
141
- span_prob = sorted(spans, key=lambda x: -x[-1])
142
- for i in range(len(spans)):
143
- b = span_prob[i]
144
- flag = False
145
- for new in new_list:
146
- if has_ov(b[:-1], new):
147
- flag = True
148
- break
149
- if not flag:
150
- new_list.append(b[:-1])
151
- new_list = sorted(new_list, key=lambda x: x[0])
152
- return new_list
 
1
+ from collections import defaultdict
2
+
3
+ import numpy as np
4
+ import torch
5
+ from seqeval.metrics.v1 import _prf_divide
6
+
7
+
8
+ def extract_tp_actual_correct(y_true, y_pred):
9
+ entities_true = defaultdict(set)
10
+ entities_pred = defaultdict(set)
11
+
12
+ for type_name, (start, end), idx in y_true:
13
+ entities_true[type_name].add((start, end, idx))
14
+ for type_name, (start, end), idx in y_pred:
15
+ entities_pred[type_name].add((start, end, idx))
16
+
17
+ target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
18
+
19
+ tp_sum = np.array([], dtype=np.int32)
20
+ pred_sum = np.array([], dtype=np.int32)
21
+ true_sum = np.array([], dtype=np.int32)
22
+ for type_name in target_names:
23
+ entities_true_type = entities_true.get(type_name, set())
24
+ entities_pred_type = entities_pred.get(type_name, set())
25
+ tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
26
+ pred_sum = np.append(pred_sum, len(entities_pred_type))
27
+ true_sum = np.append(true_sum, len(entities_true_type))
28
+
29
+ return pred_sum, tp_sum, true_sum, target_names
30
+
31
+
32
+ def flatten_for_eval(y_true, y_pred):
33
+ all_true = []
34
+ all_pred = []
35
+
36
+ for i, (true, pred) in enumerate(zip(y_true, y_pred)):
37
+ all_true.extend([t + [i] for t in true])
38
+ all_pred.extend([p + [i] for p in pred])
39
+
40
+ return all_true, all_pred
41
+
42
+
43
+ def compute_prf(y_true, y_pred, average='micro'):
44
+ y_true, y_pred = flatten_for_eval(y_true, y_pred)
45
+
46
+ pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)
47
+
48
+ if average == 'micro':
49
+ tp_sum = np.array([tp_sum.sum()])
50
+ pred_sum = np.array([pred_sum.sum()])
51
+ true_sum = np.array([true_sum.sum()])
52
+
53
+ precision = _prf_divide(
54
+ numerator=tp_sum,
55
+ denominator=pred_sum,
56
+ metric='precision',
57
+ modifier='predicted',
58
+ average=average,
59
+ warn_for=('precision', 'recall', 'f-score'),
60
+ zero_division='warn'
61
+ )
62
+
63
+ recall = _prf_divide(
64
+ numerator=tp_sum,
65
+ denominator=true_sum,
66
+ metric='recall',
67
+ modifier='true',
68
+ average=average,
69
+ warn_for=('precision', 'recall', 'f-score'),
70
+ zero_division='warn'
71
+ )
72
+
73
+ denominator = precision + recall
74
+ denominator[denominator == 0.] = 1
75
+ f_score = 2 * (precision * recall) / denominator
76
+
77
+ return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}
78
+
79
+
80
+ class Evaluator:
81
+ def __init__(self, all_true, all_outs):
82
+ self.all_true = all_true
83
+ self.all_outs = all_outs
84
+
85
+ def get_entities_fr(self, ents):
86
+ all_ents = []
87
+ for s, e, lab in ents:
88
+ all_ents.append([lab, (s, e)])
89
+ return all_ents
90
+
91
+ def transform_data(self):
92
+ all_true_ent = []
93
+ all_outs_ent = []
94
+ for i, j in zip(self.all_true, self.all_outs):
95
+ e = self.get_entities_fr(i)
96
+ all_true_ent.append(e)
97
+ e = self.get_entities_fr(j)
98
+ all_outs_ent.append(e)
99
+ return all_true_ent, all_outs_ent
100
+
101
+ @torch.no_grad()
102
+ def evaluate(self):
103
+ all_true_typed, all_outs_typed = self.transform_data()
104
+ precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values()
105
+ output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
106
+ return output_str, f1
107
+
108
+
109
+ def is_nested(idx1, idx2):
110
+ # Return True if idx2 is nested inside idx1 or vice versa
111
+ return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])
112
+
113
+
114
+ def has_overlapping(idx1, idx2):
115
+ overlapping = True
116
+ if idx1[:2] == idx2[:2]:
117
+ return overlapping
118
+ if (idx1[0] > idx2[1] or idx2[0] > idx1[1]):
119
+ overlapping = False
120
+ return overlapping
121
+
122
+
123
+ def has_overlapping_nested(idx1, idx2):
124
+ # Return True if idx1 and idx2 overlap, but neither is nested inside the other
125
+ if idx1[:2] == idx2[:2]:
126
+ return True
127
+ if ((idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2)) and idx1 != idx2:
128
+ return False
129
+ else:
130
+ return True
131
+
132
+
133
+ def greedy_search(spans, flat_ner=True): # start, end, class, score
134
+
135
+ if flat_ner:
136
+ has_ov = has_overlapping
137
+ else:
138
+ has_ov = has_overlapping_nested
139
+
140
+ new_list = []
141
+ span_prob = sorted(spans, key=lambda x: -x[-1])
142
+ for i in range(len(spans)):
143
+ b = span_prob[i]
144
+ flag = False
145
+ for new in new_list:
146
+ if has_ov(b[:-1], new):
147
+ flag = True
148
+ break
149
+ if not flag:
150
+ new_list.append(b[:-1])
151
+ new_list = sorted(new_list, key=lambda x: x[0])
152
+ return new_list
backup/modules/layers.py CHANGED
@@ -1,28 +1,28 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
- from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5
-
6
-
7
- class LstmSeq2SeqEncoder(nn.Module):
8
- def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
9
- super(LstmSeq2SeqEncoder, self).__init__()
10
- self.lstm = nn.LSTM(input_size=input_size,
11
- hidden_size=hidden_size,
12
- num_layers=num_layers,
13
- dropout=dropout,
14
- bidirectional=bidirectional,
15
- batch_first=True)
16
-
17
- def forward(self, x, mask, hidden=None):
18
- # Packing the input sequence
19
- lengths = mask.sum(dim=1).cpu()
20
- packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
21
-
22
- # Passing packed sequence through LSTM
23
- packed_output, hidden = self.lstm(packed_x, hidden)
24
-
25
- # Unpacking the output sequence
26
- output, _ = pad_packed_sequence(packed_output, batch_first=True)
27
-
28
- return output
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5
+
6
+
7
+ class LstmSeq2SeqEncoder(nn.Module):
8
+ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
9
+ super(LstmSeq2SeqEncoder, self).__init__()
10
+ self.lstm = nn.LSTM(input_size=input_size,
11
+ hidden_size=hidden_size,
12
+ num_layers=num_layers,
13
+ dropout=dropout,
14
+ bidirectional=bidirectional,
15
+ batch_first=True)
16
+
17
+ def forward(self, x, mask, hidden=None):
18
+ # Packing the input sequence
19
+ lengths = mask.sum(dim=1).cpu()
20
+ packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
21
+
22
+ # Passing packed sequence through LSTM
23
+ packed_output, hidden = self.lstm(packed_x, hidden)
24
+
25
+ # Unpacking the output sequence
26
+ output, _ = pad_packed_sequence(packed_output, batch_first=True)
27
+
28
+ return output
backup/modules/run_evaluation.py CHANGED
@@ -1,188 +1,188 @@
1
- import glob
2
- import json
3
- import os
4
- import os
5
-
6
- import torch
7
- from tqdm import tqdm
8
- import random
9
-
10
-
11
- def open_content(path):
12
- paths = glob.glob(os.path.join(path, "*.json"))
13
- train, dev, test, labels = None, None, None, None
14
- for p in paths:
15
- if "train" in p:
16
- with open(p, "r") as f:
17
- train = json.load(f)
18
- elif "dev" in p:
19
- with open(p, "r") as f:
20
- dev = json.load(f)
21
- elif "test" in p:
22
- with open(p, "r") as f:
23
- test = json.load(f)
24
- elif "labels" in p:
25
- with open(p, "r") as f:
26
- labels = json.load(f)
27
- return train, dev, test, labels
28
-
29
-
30
- def process(data):
31
- words = data['sentence'].split()
32
- entities = [] # List of entities (start, end, type)
33
-
34
- for entity in data['entities']:
35
- start_char, end_char = entity['pos']
36
-
37
- # Initialize variables to keep track of word positions
38
- start_word = None
39
- end_word = None
40
-
41
- # Iterate through words and find the word positions
42
- char_count = 0
43
- for i, word in enumerate(words):
44
- word_length = len(word)
45
- if char_count == start_char:
46
- start_word = i
47
- if char_count + word_length == end_char:
48
- end_word = i
49
- break
50
- char_count += word_length + 1 # Add 1 for the space
51
-
52
- # Append the word positions to the list
53
- entities.append((start_word, end_word, entity['type']))
54
-
55
- # Create a list of word positions for each entity
56
- sample = {
57
- "tokenized_text": words,
58
- "ner": entities
59
- }
60
-
61
- return sample
62
-
63
-
64
- # create dataset
65
- def create_dataset(path):
66
- train, dev, test, labels = open_content(path)
67
- train_dataset = []
68
- dev_dataset = []
69
- test_dataset = []
70
- for data in train:
71
- train_dataset.append(process(data))
72
- for data in dev:
73
- dev_dataset.append(process(data))
74
- for data in test:
75
- test_dataset.append(process(data))
76
- return train_dataset, dev_dataset, test_dataset, labels
77
-
78
-
79
- @torch.no_grad()
80
- def get_for_one_path(path, model):
81
- # load the dataset
82
- _, _, test_dataset, entity_types = create_dataset(path)
83
-
84
- data_name = path.split("/")[-1] # get the name of the dataset
85
-
86
- # check if the dataset is flat_ner
87
- flat_ner = True
88
- if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]):
89
- flat_ner = False
90
-
91
- # evaluate the model
92
- results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
93
- entity_types=entity_types)
94
- return data_name, results, f1
95
-
96
-
97
- def get_for_all_path(model, steps, log_dir, data_paths):
98
- all_paths = glob.glob(f"{data_paths}/*")
99
-
100
- all_paths = sorted(all_paths)
101
-
102
- # move the model to the device
103
- device = next(model.parameters()).device
104
- model.to(device)
105
- # set the model to eval mode
106
- model.eval()
107
-
108
- # log the results
109
- save_path = os.path.join(log_dir, "results.txt")
110
-
111
- with open(save_path, "a") as f:
112
- f.write("##############################################\n")
113
- # write step
114
- f.write("step: " + str(steps) + "\n")
115
-
116
- zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
117
- "CrossNER_politics", "CrossNER_science"]
118
-
119
- zero_shot_benc_results = {}
120
- all_results = {} # without crossNER
121
-
122
- for p in tqdm(all_paths):
123
- if "sample_" not in p:
124
- data_name, results, f1 = get_for_one_path(p, model)
125
- # write to file
126
- with open(save_path, "a") as f:
127
- f.write(data_name + "\n")
128
- f.write(str(results) + "\n")
129
-
130
- if data_name in zero_shot_benc:
131
- zero_shot_benc_results[data_name] = f1
132
- else:
133
- all_results[data_name] = f1
134
-
135
- avg_all = sum(all_results.values()) / len(all_results)
136
- avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results)
137
-
138
- save_path_table = os.path.join(log_dir, "tables.txt")
139
-
140
- # results for all datasets except crossNER
141
- table_bench_all = ""
142
- for k, v in all_results.items():
143
- table_bench_all += f"{k:20}: {v:.1%}\n"
144
- # (20 size aswell for average i.e. :20)
145
- table_bench_all += f"{'Average':20}: {avg_all:.1%}"
146
-
147
- # results for zero-shot benchmark
148
- table_bench_zeroshot = ""
149
- for k, v in zero_shot_benc_results.items():
150
- table_bench_zeroshot += f"{k:20}: {v:.1%}\n"
151
- table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}"
152
-
153
- # write to file
154
- with open(save_path_table, "a") as f:
155
- f.write("##############################################\n")
156
- f.write("step: " + str(steps) + "\n")
157
- f.write("Table for all datasets except crossNER\n")
158
- f.write(table_bench_all + "\n\n")
159
- f.write("Table for zero-shot benchmark\n")
160
- f.write(table_bench_zeroshot + "\n")
161
- f.write("##############################################\n\n")
162
-
163
-
164
- def sample_train_data(data_paths, sample_size=10000):
165
- all_paths = glob.glob(f"{data_paths}/*")
166
-
167
- all_paths = sorted(all_paths)
168
-
169
- # to exclude the zero-shot benchmark datasets
170
- zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
171
- "CrossNER_politics", "CrossNER_science", "ACE 2004"]
172
-
173
- new_train = []
174
- # take 10k samples from each dataset
175
- for p in tqdm(all_paths):
176
- if any([i in p for i in zero_shot_benc]):
177
- continue
178
- train, dev, test, labels = create_dataset(p)
179
-
180
- # add label key to the train data
181
- for i in range(len(train)):
182
- train[i]["label"] = labels
183
-
184
- random.shuffle(train)
185
- train = train[:sample_size]
186
- new_train.extend(train)
187
-
188
- return new_train
 
1
+ import glob
2
+ import json
3
+ import os
4
+ import os
5
+
6
+ import torch
7
+ from tqdm import tqdm
8
+ import random
9
+
10
+
11
+ def open_content(path):
12
+ paths = glob.glob(os.path.join(path, "*.json"))
13
+ train, dev, test, labels = None, None, None, None
14
+ for p in paths:
15
+ if "train" in p:
16
+ with open(p, "r") as f:
17
+ train = json.load(f)
18
+ elif "dev" in p:
19
+ with open(p, "r") as f:
20
+ dev = json.load(f)
21
+ elif "test" in p:
22
+ with open(p, "r") as f:
23
+ test = json.load(f)
24
+ elif "labels" in p:
25
+ with open(p, "r") as f:
26
+ labels = json.load(f)
27
+ return train, dev, test, labels
28
+
29
+
30
+ def process(data):
31
+ words = data['sentence'].split()
32
+ entities = [] # List of entities (start, end, type)
33
+
34
+ for entity in data['entities']:
35
+ start_char, end_char = entity['pos']
36
+
37
+ # Initialize variables to keep track of word positions
38
+ start_word = None
39
+ end_word = None
40
+
41
+ # Iterate through words and find the word positions
42
+ char_count = 0
43
+ for i, word in enumerate(words):
44
+ word_length = len(word)
45
+ if char_count == start_char:
46
+ start_word = i
47
+ if char_count + word_length == end_char:
48
+ end_word = i
49
+ break
50
+ char_count += word_length + 1 # Add 1 for the space
51
+
52
+ # Append the word positions to the list
53
+ entities.append((start_word, end_word, entity['type']))
54
+
55
+ # Create a list of word positions for each entity
56
+ sample = {
57
+ "tokenized_text": words,
58
+ "ner": entities
59
+ }
60
+
61
+ return sample
62
+
63
+
64
+ # create dataset
65
+ def create_dataset(path):
66
+ train, dev, test, labels = open_content(path)
67
+ train_dataset = []
68
+ dev_dataset = []
69
+ test_dataset = []
70
+ for data in train:
71
+ train_dataset.append(process(data))
72
+ for data in dev:
73
+ dev_dataset.append(process(data))
74
+ for data in test:
75
+ test_dataset.append(process(data))
76
+ return train_dataset, dev_dataset, test_dataset, labels
77
+
78
+
79
+ @torch.no_grad()
80
+ def get_for_one_path(path, model):
81
+ # load the dataset
82
+ _, _, test_dataset, entity_types = create_dataset(path)
83
+
84
+ data_name = path.split("/")[-1] # get the name of the dataset
85
+
86
+ # check if the dataset is flat_ner
87
+ flat_ner = True
88
+ if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]):
89
+ flat_ner = False
90
+
91
+ # evaluate the model
92
+ results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
93
+ entity_types=entity_types)
94
+ return data_name, results, f1
95
+
96
+
97
+ def get_for_all_path(model, steps, log_dir, data_paths):
98
+ all_paths = glob.glob(f"{data_paths}/*")
99
+
100
+ all_paths = sorted(all_paths)
101
+
102
+ # move the model to the device
103
+ device = next(model.parameters()).device
104
+ model.to(device)
105
+ # set the model to eval mode
106
+ model.eval()
107
+
108
+ # log the results
109
+ save_path = os.path.join(log_dir, "results.txt")
110
+
111
+ with open(save_path, "a") as f:
112
+ f.write("##############################################\n")
113
+ # write step
114
+ f.write("step: " + str(steps) + "\n")
115
+
116
+ zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
117
+ "CrossNER_politics", "CrossNER_science"]
118
+
119
+ zero_shot_benc_results = {}
120
+ all_results = {} # without crossNER
121
+
122
+ for p in tqdm(all_paths):
123
+ if "sample_" not in p:
124
+ data_name, results, f1 = get_for_one_path(p, model)
125
+ # write to file
126
+ with open(save_path, "a") as f:
127
+ f.write(data_name + "\n")
128
+ f.write(str(results) + "\n")
129
+
130
+ if data_name in zero_shot_benc:
131
+ zero_shot_benc_results[data_name] = f1
132
+ else:
133
+ all_results[data_name] = f1
134
+
135
+ avg_all = sum(all_results.values()) / len(all_results)
136
+ avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results)
137
+
138
+ save_path_table = os.path.join(log_dir, "tables.txt")
139
+
140
+ # results for all datasets except crossNER
141
+ table_bench_all = ""
142
+ for k, v in all_results.items():
143
+ table_bench_all += f"{k:20}: {v:.1%}\n"
144
+ # (20 size aswell for average i.e. :20)
145
+ table_bench_all += f"{'Average':20}: {avg_all:.1%}"
146
+
147
+ # results for zero-shot benchmark
148
+ table_bench_zeroshot = ""
149
+ for k, v in zero_shot_benc_results.items():
150
+ table_bench_zeroshot += f"{k:20}: {v:.1%}\n"
151
+ table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}"
152
+
153
+ # write to file
154
+ with open(save_path_table, "a") as f:
155
+ f.write("##############################################\n")
156
+ f.write("step: " + str(steps) + "\n")
157
+ f.write("Table for all datasets except crossNER\n")
158
+ f.write(table_bench_all + "\n\n")
159
+ f.write("Table for zero-shot benchmark\n")
160
+ f.write(table_bench_zeroshot + "\n")
161
+ f.write("##############################################\n\n")
162
+
163
+
164
+ def sample_train_data(data_paths, sample_size=10000):
165
+ all_paths = glob.glob(f"{data_paths}/*")
166
+
167
+ all_paths = sorted(all_paths)
168
+
169
+ # to exclude the zero-shot benchmark datasets
170
+ zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
171
+ "CrossNER_politics", "CrossNER_science", "ACE 2004"]
172
+
173
+ new_train = []
174
+ # take 10k samples from each dataset
175
+ for p in tqdm(all_paths):
176
+ if any([i in p for i in zero_shot_benc]):
177
+ continue
178
+ train, dev, test, labels = create_dataset(p)
179
+
180
+ # add label key to the train data
181
+ for i in range(len(train)):
182
+ train[i]["label"] = labels
183
+
184
+ random.shuffle(train)
185
+ train = train[:sample_size]
186
+ new_train.extend(train)
187
+
188
+ return new_train
backup/modules/span_rep.py CHANGED
@@ -1,369 +1,369 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
-
5
- def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
6
- """
7
- Creates a projection layer with specified configurations.
8
- """
9
- if out_dim is None:
10
- out_dim = hidden_size
11
-
12
- return nn.Sequential(
13
- nn.Linear(hidden_size, out_dim * 4),
14
- nn.ReLU(),
15
- nn.Dropout(dropout),
16
- nn.Linear(out_dim * 4, out_dim)
17
- )
18
-
19
-
20
- class SpanQuery(nn.Module):
21
-
22
- def __init__(self, hidden_size, max_width, trainable=True):
23
- super().__init__()
24
-
25
- self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
26
-
27
- nn.init.uniform_(self.query_seg, a=-1, b=1)
28
-
29
- if not trainable:
30
- self.query_seg.requires_grad = False
31
-
32
- self.project = nn.Sequential(
33
- nn.Linear(hidden_size, hidden_size),
34
- nn.ReLU()
35
- )
36
-
37
- def forward(self, h, *args):
38
- # h of shape [B, L, D]
39
- # query_seg of shape [D, max_width]
40
-
41
- span_rep = torch.einsum('bld, ds->blsd', h, self.query_seg)
42
-
43
- return self.project(span_rep)
44
-
45
-
46
- class SpanMLP(nn.Module):
47
-
48
- def __init__(self, hidden_size, max_width):
49
- super().__init__()
50
-
51
- self.mlp = nn.Linear(hidden_size, hidden_size * max_width)
52
-
53
- def forward(self, h, *args):
54
- # h of shape [B, L, D]
55
- # query_seg of shape [D, max_width]
56
-
57
- B, L, D = h.size()
58
-
59
- span_rep = self.mlp(h)
60
-
61
- span_rep = span_rep.view(B, L, -1, D)
62
-
63
- return span_rep.relu()
64
-
65
-
66
- class SpanCAT(nn.Module):
67
-
68
- def __init__(self, hidden_size, max_width):
69
- super().__init__()
70
-
71
- self.max_width = max_width
72
-
73
- self.query_seg = nn.Parameter(torch.randn(128, max_width))
74
-
75
- self.project = nn.Sequential(
76
- nn.Linear(hidden_size + 128, hidden_size),
77
- nn.ReLU()
78
- )
79
-
80
- def forward(self, h, *args):
81
- # h of shape [B, L, D]
82
- # query_seg of shape [D, max_width]
83
-
84
- B, L, D = h.size()
85
-
86
- h = h.view(B, L, 1, D).repeat(1, 1, self.max_width, 1)
87
-
88
- q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1)
89
-
90
- span_rep = torch.cat([h, q], dim=-1)
91
-
92
- span_rep = self.project(span_rep)
93
-
94
- return span_rep
95
-
96
-
97
- class SpanConvBlock(nn.Module):
98
- def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
99
- super().__init__()
100
-
101
- if span_mode == 'conv_conv':
102
- self.conv = nn.Conv1d(hidden_size, hidden_size,
103
- kernel_size=kernel_size)
104
-
105
- # initialize the weights
106
- nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
107
-
108
- elif span_mode == 'conv_max':
109
- self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
110
- elif span_mode == 'conv_mean' or span_mode == 'conv_sum':
111
- self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
112
-
113
- self.span_mode = span_mode
114
-
115
- self.pad = kernel_size - 1
116
-
117
- def forward(self, x):
118
-
119
- x = torch.einsum('bld->bdl', x)
120
-
121
- if self.pad > 0:
122
- x = F.pad(x, (0, self.pad), "constant", 0)
123
-
124
- x = self.conv(x)
125
-
126
- if self.span_mode == "conv_sum":
127
- x = x * (self.pad + 1)
128
-
129
- return torch.einsum('bdl->bld', x)
130
-
131
-
132
- class SpanConv(nn.Module):
133
- def __init__(self, hidden_size, max_width, span_mode):
134
- super().__init__()
135
-
136
- kernels = [i + 2 for i in range(max_width - 1)]
137
-
138
- self.convs = nn.ModuleList()
139
-
140
- for kernel in kernels:
141
- self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
142
-
143
- self.project = nn.Sequential(
144
- nn.ReLU(),
145
- nn.Linear(hidden_size, hidden_size)
146
- )
147
-
148
- def forward(self, x, *args):
149
-
150
- span_reps = [x]
151
-
152
- for conv in self.convs:
153
- h = conv(x)
154
- span_reps.append(h)
155
-
156
- span_reps = torch.stack(span_reps, dim=-2)
157
-
158
- return self.project(span_reps)
159
-
160
-
161
- class SpanEndpointsBlock(nn.Module):
162
- def __init__(self, kernel_size):
163
- super().__init__()
164
-
165
- self.kernel_size = kernel_size
166
-
167
- def forward(self, x):
168
- B, L, D = x.size()
169
-
170
- span_idx = torch.LongTensor(
171
- [[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device)
172
-
173
- x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
174
-
175
- # endrep
176
- start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
177
-
178
- start_end_rep = start_end_rep.view(B, L, 2, D)
179
-
180
- return start_end_rep
181
-
182
-
183
- class ConvShare(nn.Module):
184
- def __init__(self, hidden_size, max_width):
185
- super().__init__()
186
-
187
- self.max_width = max_width
188
-
189
- self.conv_weigth = nn.Parameter(
190
- torch.randn(hidden_size, hidden_size, max_width))
191
-
192
- nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')
193
-
194
- self.project = nn.Sequential(
195
- nn.ReLU(),
196
- nn.Linear(hidden_size, hidden_size)
197
- )
198
-
199
- def forward(self, x, *args):
200
- span_reps = []
201
-
202
- x = torch.einsum('bld->bdl', x)
203
-
204
- for i in range(self.max_width):
205
- pad = i
206
- x_i = F.pad(x, (0, pad), "constant", 0)
207
- conv_w = self.conv_weigth[:, :, :i + 1]
208
- out_i = F.conv1d(x_i, conv_w)
209
- span_reps.append(out_i.transpose(-1, -2))
210
-
211
- out = torch.stack(span_reps, dim=-2)
212
-
213
- return self.project(out)
214
-
215
-
216
- def extract_elements(sequence, indices):
217
- B, L, D = sequence.shape
218
- K = indices.shape[1]
219
-
220
- # Expand indices to [B, K, D]
221
- expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)
222
-
223
- # Gather the elements
224
- extracted_elements = torch.gather(sequence, 1, expanded_indices)
225
-
226
- return extracted_elements
227
-
228
-
229
- class SpanMarker(nn.Module):
230
-
231
- def __init__(self, hidden_size, max_width, dropout=0.4):
232
- super().__init__()
233
-
234
- self.max_width = max_width
235
-
236
- self.project_start = nn.Sequential(
237
- nn.Linear(hidden_size, hidden_size * 2, bias=True),
238
- nn.ReLU(),
239
- nn.Dropout(dropout),
240
- nn.Linear(hidden_size * 2, hidden_size, bias=True),
241
- )
242
-
243
- self.project_end = nn.Sequential(
244
- nn.Linear(hidden_size, hidden_size * 2, bias=True),
245
- nn.ReLU(),
246
- nn.Dropout(dropout),
247
- nn.Linear(hidden_size * 2, hidden_size, bias=True),
248
- )
249
-
250
- self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
251
-
252
- def forward(self, h, span_idx):
253
- # h of shape [B, L, D]
254
- # query_seg of shape [D, max_width]
255
-
256
- B, L, D = h.size()
257
-
258
- # project start and end
259
- start_rep = self.project_start(h)
260
- end_rep = self.project_end(h)
261
-
262
- start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
263
- end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
264
-
265
- # concat start and end
266
- cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
267
-
268
- # project
269
- cat = self.out_project(cat)
270
-
271
- # reshape
272
- return cat.view(B, L, self.max_width, D)
273
-
274
-
275
- class SpanMarkerV0(nn.Module):
276
- """
277
- Marks and projects span endpoints using an MLP.
278
- """
279
-
280
- def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
281
- super().__init__()
282
- self.max_width = max_width
283
- self.project_start = create_projection_layer(hidden_size, dropout)
284
- self.project_end = create_projection_layer(hidden_size, dropout)
285
-
286
- self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)
287
-
288
- def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
289
- B, L, D = h.size()
290
-
291
- start_rep = self.project_start(h)
292
- end_rep = self.project_end(h)
293
-
294
- start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
295
- end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
296
-
297
- cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
298
-
299
- return self.out_project(cat).view(B, L, self.max_width, D)
300
-
301
-
302
- class ConvShareV2(nn.Module):
303
- def __init__(self, hidden_size, max_width):
304
- super().__init__()
305
-
306
- self.max_width = max_width
307
-
308
- self.conv_weigth = nn.Parameter(
309
- torch.randn(hidden_size, hidden_size, max_width)
310
- )
311
-
312
- nn.init.xavier_normal_(self.conv_weigth)
313
-
314
- def forward(self, x, *args):
315
- span_reps = []
316
-
317
- x = torch.einsum('bld->bdl', x)
318
-
319
- for i in range(self.max_width):
320
- pad = i
321
- x_i = F.pad(x, (0, pad), "constant", 0)
322
- conv_w = self.conv_weigth[:, :, :i + 1]
323
- out_i = F.conv1d(x_i, conv_w)
324
- span_reps.append(out_i.transpose(-1, -2))
325
-
326
- out = torch.stack(span_reps, dim=-2)
327
-
328
- return out
329
-
330
-
331
- class SpanRepLayer(nn.Module):
332
- """
333
- Various span representation approaches
334
- """
335
-
336
- def __init__(self, hidden_size, max_width, span_mode, **kwargs):
337
- super().__init__()
338
-
339
- if span_mode == 'marker':
340
- self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
341
- elif span_mode == 'markerV0':
342
- self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs)
343
- elif span_mode == 'query':
344
- self.span_rep_layer = SpanQuery(
345
- hidden_size, max_width, trainable=True)
346
- elif span_mode == 'mlp':
347
- self.span_rep_layer = SpanMLP(hidden_size, max_width)
348
- elif span_mode == 'cat':
349
- self.span_rep_layer = SpanCAT(hidden_size, max_width)
350
- elif span_mode == 'conv_conv':
351
- self.span_rep_layer = SpanConv(
352
- hidden_size, max_width, span_mode='conv_conv')
353
- elif span_mode == 'conv_max':
354
- self.span_rep_layer = SpanConv(
355
- hidden_size, max_width, span_mode='conv_max')
356
- elif span_mode == 'conv_mean':
357
- self.span_rep_layer = SpanConv(
358
- hidden_size, max_width, span_mode='conv_mean')
359
- elif span_mode == 'conv_sum':
360
- self.span_rep_layer = SpanConv(
361
- hidden_size, max_width, span_mode='conv_sum')
362
- elif span_mode == 'conv_share':
363
- self.span_rep_layer = ConvShare(hidden_size, max_width)
364
- else:
365
- raise ValueError(f'Unknown span mode {span_mode}')
366
-
367
- def forward(self, x, *args):
368
-
369
- return self.span_rep_layer(x, *args)
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
6
+ """
7
+ Creates a projection layer with specified configurations.
8
+ """
9
+ if out_dim is None:
10
+ out_dim = hidden_size
11
+
12
+ return nn.Sequential(
13
+ nn.Linear(hidden_size, out_dim * 4),
14
+ nn.ReLU(),
15
+ nn.Dropout(dropout),
16
+ nn.Linear(out_dim * 4, out_dim)
17
+ )
18
+
19
+
20
+ class SpanQuery(nn.Module):
21
+
22
+ def __init__(self, hidden_size, max_width, trainable=True):
23
+ super().__init__()
24
+
25
+ self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
26
+
27
+ nn.init.uniform_(self.query_seg, a=-1, b=1)
28
+
29
+ if not trainable:
30
+ self.query_seg.requires_grad = False
31
+
32
+ self.project = nn.Sequential(
33
+ nn.Linear(hidden_size, hidden_size),
34
+ nn.ReLU()
35
+ )
36
+
37
+ def forward(self, h, *args):
38
+ # h of shape [B, L, D]
39
+ # query_seg of shape [D, max_width]
40
+
41
+ span_rep = torch.einsum('bld, ds->blsd', h, self.query_seg)
42
+
43
+ return self.project(span_rep)
44
+
45
+
46
+ class SpanMLP(nn.Module):
47
+
48
+ def __init__(self, hidden_size, max_width):
49
+ super().__init__()
50
+
51
+ self.mlp = nn.Linear(hidden_size, hidden_size * max_width)
52
+
53
+ def forward(self, h, *args):
54
+ # h of shape [B, L, D]
55
+ # query_seg of shape [D, max_width]
56
+
57
+ B, L, D = h.size()
58
+
59
+ span_rep = self.mlp(h)
60
+
61
+ span_rep = span_rep.view(B, L, -1, D)
62
+
63
+ return span_rep.relu()
64
+
65
+
66
+ class SpanCAT(nn.Module):
67
+
68
+ def __init__(self, hidden_size, max_width):
69
+ super().__init__()
70
+
71
+ self.max_width = max_width
72
+
73
+ self.query_seg = nn.Parameter(torch.randn(128, max_width))
74
+
75
+ self.project = nn.Sequential(
76
+ nn.Linear(hidden_size + 128, hidden_size),
77
+ nn.ReLU()
78
+ )
79
+
80
+ def forward(self, h, *args):
81
+ # h of shape [B, L, D]
82
+ # query_seg of shape [D, max_width]
83
+
84
+ B, L, D = h.size()
85
+
86
+ h = h.view(B, L, 1, D).repeat(1, 1, self.max_width, 1)
87
+
88
+ q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1)
89
+
90
+ span_rep = torch.cat([h, q], dim=-1)
91
+
92
+ span_rep = self.project(span_rep)
93
+
94
+ return span_rep
95
+
96
+
97
+ class SpanConvBlock(nn.Module):
98
+ def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
99
+ super().__init__()
100
+
101
+ if span_mode == 'conv_conv':
102
+ self.conv = nn.Conv1d(hidden_size, hidden_size,
103
+ kernel_size=kernel_size)
104
+
105
+ # initialize the weights
106
+ nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
107
+
108
+ elif span_mode == 'conv_max':
109
+ self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
110
+ elif span_mode == 'conv_mean' or span_mode == 'conv_sum':
111
+ self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
112
+
113
+ self.span_mode = span_mode
114
+
115
+ self.pad = kernel_size - 1
116
+
117
+ def forward(self, x):
118
+
119
+ x = torch.einsum('bld->bdl', x)
120
+
121
+ if self.pad > 0:
122
+ x = F.pad(x, (0, self.pad), "constant", 0)
123
+
124
+ x = self.conv(x)
125
+
126
+ if self.span_mode == "conv_sum":
127
+ x = x * (self.pad + 1)
128
+
129
+ return torch.einsum('bdl->bld', x)
130
+
131
+
132
+ class SpanConv(nn.Module):
133
+ def __init__(self, hidden_size, max_width, span_mode):
134
+ super().__init__()
135
+
136
+ kernels = [i + 2 for i in range(max_width - 1)]
137
+
138
+ self.convs = nn.ModuleList()
139
+
140
+ for kernel in kernels:
141
+ self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
142
+
143
+ self.project = nn.Sequential(
144
+ nn.ReLU(),
145
+ nn.Linear(hidden_size, hidden_size)
146
+ )
147
+
148
+ def forward(self, x, *args):
149
+
150
+ span_reps = [x]
151
+
152
+ for conv in self.convs:
153
+ h = conv(x)
154
+ span_reps.append(h)
155
+
156
+ span_reps = torch.stack(span_reps, dim=-2)
157
+
158
+ return self.project(span_reps)
159
+
160
+
161
+ class SpanEndpointsBlock(nn.Module):
162
+ def __init__(self, kernel_size):
163
+ super().__init__()
164
+
165
+ self.kernel_size = kernel_size
166
+
167
+ def forward(self, x):
168
+ B, L, D = x.size()
169
+
170
+ span_idx = torch.LongTensor(
171
+ [[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device)
172
+
173
+ x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
174
+
175
+ # endrep
176
+ start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
177
+
178
+ start_end_rep = start_end_rep.view(B, L, 2, D)
179
+
180
+ return start_end_rep
181
+
182
+
183
+ class ConvShare(nn.Module):
184
+ def __init__(self, hidden_size, max_width):
185
+ super().__init__()
186
+
187
+ self.max_width = max_width
188
+
189
+ self.conv_weigth = nn.Parameter(
190
+ torch.randn(hidden_size, hidden_size, max_width))
191
+
192
+ nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')
193
+
194
+ self.project = nn.Sequential(
195
+ nn.ReLU(),
196
+ nn.Linear(hidden_size, hidden_size)
197
+ )
198
+
199
+ def forward(self, x, *args):
200
+ span_reps = []
201
+
202
+ x = torch.einsum('bld->bdl', x)
203
+
204
+ for i in range(self.max_width):
205
+ pad = i
206
+ x_i = F.pad(x, (0, pad), "constant", 0)
207
+ conv_w = self.conv_weigth[:, :, :i + 1]
208
+ out_i = F.conv1d(x_i, conv_w)
209
+ span_reps.append(out_i.transpose(-1, -2))
210
+
211
+ out = torch.stack(span_reps, dim=-2)
212
+
213
+ return self.project(out)
214
+
215
+
216
+ def extract_elements(sequence, indices):
217
+ B, L, D = sequence.shape
218
+ K = indices.shape[1]
219
+
220
+ # Expand indices to [B, K, D]
221
+ expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)
222
+
223
+ # Gather the elements
224
+ extracted_elements = torch.gather(sequence, 1, expanded_indices)
225
+
226
+ return extracted_elements
227
+
228
+
229
+ class SpanMarker(nn.Module):
230
+
231
+ def __init__(self, hidden_size, max_width, dropout=0.4):
232
+ super().__init__()
233
+
234
+ self.max_width = max_width
235
+
236
+ self.project_start = nn.Sequential(
237
+ nn.Linear(hidden_size, hidden_size * 2, bias=True),
238
+ nn.ReLU(),
239
+ nn.Dropout(dropout),
240
+ nn.Linear(hidden_size * 2, hidden_size, bias=True),
241
+ )
242
+
243
+ self.project_end = nn.Sequential(
244
+ nn.Linear(hidden_size, hidden_size * 2, bias=True),
245
+ nn.ReLU(),
246
+ nn.Dropout(dropout),
247
+ nn.Linear(hidden_size * 2, hidden_size, bias=True),
248
+ )
249
+
250
+ self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
251
+
252
+ def forward(self, h, span_idx):
253
+ # h of shape [B, L, D]
254
+ # query_seg of shape [D, max_width]
255
+
256
+ B, L, D = h.size()
257
+
258
+ # project start and end
259
+ start_rep = self.project_start(h)
260
+ end_rep = self.project_end(h)
261
+
262
+ start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
263
+ end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
264
+
265
+ # concat start and end
266
+ cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
267
+
268
+ # project
269
+ cat = self.out_project(cat)
270
+
271
+ # reshape
272
+ return cat.view(B, L, self.max_width, D)
273
+
274
+
275
+ class SpanMarkerV0(nn.Module):
276
+ """
277
+ Marks and projects span endpoints using an MLP.
278
+ """
279
+
280
+ def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
281
+ super().__init__()
282
+ self.max_width = max_width
283
+ self.project_start = create_projection_layer(hidden_size, dropout)
284
+ self.project_end = create_projection_layer(hidden_size, dropout)
285
+
286
+ self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)
287
+
288
+ def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
289
+ B, L, D = h.size()
290
+
291
+ start_rep = self.project_start(h)
292
+ end_rep = self.project_end(h)
293
+
294
+ start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
295
+ end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
296
+
297
+ cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
298
+
299
+ return self.out_project(cat).view(B, L, self.max_width, D)
300
+
301
+
302
+ class ConvShareV2(nn.Module):
303
+ def __init__(self, hidden_size, max_width):
304
+ super().__init__()
305
+
306
+ self.max_width = max_width
307
+
308
+ self.conv_weigth = nn.Parameter(
309
+ torch.randn(hidden_size, hidden_size, max_width)
310
+ )
311
+
312
+ nn.init.xavier_normal_(self.conv_weigth)
313
+
314
+ def forward(self, x, *args):
315
+ span_reps = []
316
+
317
+ x = torch.einsum('bld->bdl', x)
318
+
319
+ for i in range(self.max_width):
320
+ pad = i
321
+ x_i = F.pad(x, (0, pad), "constant", 0)
322
+ conv_w = self.conv_weigth[:, :, :i + 1]
323
+ out_i = F.conv1d(x_i, conv_w)
324
+ span_reps.append(out_i.transpose(-1, -2))
325
+
326
+ out = torch.stack(span_reps, dim=-2)
327
+
328
+ return out
329
+
330
+
331
+ class SpanRepLayer(nn.Module):
332
+ """
333
+ Various span representation approaches
334
+ """
335
+
336
+ def __init__(self, hidden_size, max_width, span_mode, **kwargs):
337
+ super().__init__()
338
+
339
+ if span_mode == 'marker':
340
+ self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
341
+ elif span_mode == 'markerV0':
342
+ self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs)
343
+ elif span_mode == 'query':
344
+ self.span_rep_layer = SpanQuery(
345
+ hidden_size, max_width, trainable=True)
346
+ elif span_mode == 'mlp':
347
+ self.span_rep_layer = SpanMLP(hidden_size, max_width)
348
+ elif span_mode == 'cat':
349
+ self.span_rep_layer = SpanCAT(hidden_size, max_width)
350
+ elif span_mode == 'conv_conv':
351
+ self.span_rep_layer = SpanConv(
352
+ hidden_size, max_width, span_mode='conv_conv')
353
+ elif span_mode == 'conv_max':
354
+ self.span_rep_layer = SpanConv(
355
+ hidden_size, max_width, span_mode='conv_max')
356
+ elif span_mode == 'conv_mean':
357
+ self.span_rep_layer = SpanConv(
358
+ hidden_size, max_width, span_mode='conv_mean')
359
+ elif span_mode == 'conv_sum':
360
+ self.span_rep_layer = SpanConv(
361
+ hidden_size, max_width, span_mode='conv_sum')
362
+ elif span_mode == 'conv_share':
363
+ self.span_rep_layer = ConvShare(hidden_size, max_width)
364
+ else:
365
+ raise ValueError(f'Unknown span mode {span_mode}')
366
+
367
+ def forward(self, x, *args):
368
+
369
+ return self.span_rep_layer(x, *args)
backup/modules/token_rep.py CHANGED
@@ -1,54 +1,54 @@
1
- from typing import List
2
-
3
- import torch
4
- from flair.data import Sentence
5
- from flair.embeddings import TransformerWordEmbeddings
6
- from torch import nn
7
- from torch.nn.utils.rnn import pad_sequence
8
-
9
-
10
- # flair.cache_root = '/gpfswork/rech/pds/upa43yu/.cache'
11
-
12
-
13
- class TokenRepLayer(nn.Module):
14
- def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
15
- hidden_size: int = 768,
16
- add_tokens=["[SEP]", "[ENT]"]
17
- ):
18
- super().__init__()
19
-
20
- self.bert_layer = TransformerWordEmbeddings(
21
- model_name,
22
- fine_tune=fine_tune,
23
- subtoken_pooling=subtoken_pooling,
24
- allow_long_sentences=True
25
- )
26
-
27
- # add tokens to vocabulary
28
- self.bert_layer.tokenizer.add_tokens(add_tokens)
29
-
30
- # resize token embeddings
31
- self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))
32
-
33
- bert_hidden_size = self.bert_layer.embedding_length
34
-
35
- if hidden_size != bert_hidden_size:
36
- self.projection = nn.Linear(bert_hidden_size, hidden_size)
37
-
38
- def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
39
- token_embeddings = self.compute_word_embedding(tokens)
40
-
41
- if hasattr(self, "projection"):
42
- token_embeddings = self.projection(token_embeddings)
43
-
44
- B = len(lengths)
45
- max_length = lengths.max()
46
- mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to(
47
- token_embeddings.device).long()
48
- return {"embeddings": token_embeddings, "mask": mask}
49
-
50
- def compute_word_embedding(self, tokens):
51
- sentences = [Sentence(i) for i in tokens]
52
- self.bert_layer.embed(sentences)
53
- token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
54
- return token_embeddings
 
1
+ from typing import List
2
+
3
+ import torch
4
+ from flair.data import Sentence
5
+ from flair.embeddings import TransformerWordEmbeddings
6
+ from torch import nn
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+
10
+ # flair.cache_root = '/gpfswork/rech/pds/upa43yu/.cache'
11
+
12
+
13
+ class TokenRepLayer(nn.Module):
14
+ def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
15
+ hidden_size: int = 768,
16
+ add_tokens=["[SEP]", "[ENT]"]
17
+ ):
18
+ super().__init__()
19
+
20
+ self.bert_layer = TransformerWordEmbeddings(
21
+ model_name,
22
+ fine_tune=fine_tune,
23
+ subtoken_pooling=subtoken_pooling,
24
+ allow_long_sentences=True
25
+ )
26
+
27
+ # add tokens to vocabulary
28
+ self.bert_layer.tokenizer.add_tokens(add_tokens)
29
+
30
+ # resize token embeddings
31
+ self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))
32
+
33
+ bert_hidden_size = self.bert_layer.embedding_length
34
+
35
+ if hidden_size != bert_hidden_size:
36
+ self.projection = nn.Linear(bert_hidden_size, hidden_size)
37
+
38
+ def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
39
+ token_embeddings = self.compute_word_embedding(tokens)
40
+
41
+ if hasattr(self, "projection"):
42
+ token_embeddings = self.projection(token_embeddings)
43
+
44
+ B = len(lengths)
45
+ max_length = lengths.max()
46
+ mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to(
47
+ token_embeddings.device).long()
48
+ return {"embeddings": token_embeddings, "mask": mask}
49
+
50
+ def compute_word_embedding(self, tokens):
51
+ sentences = [Sentence(i) for i in tokens]
52
+ self.bert_layer.embed(sentences)
53
+ token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
54
+ return token_embeddings
backup/requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- torch
2
- transformers
3
- huggingface_hub
4
- flair
5
- seqeval
6
  tqdm
 
1
+ torch
2
+ transformers
3
+ huggingface_hub
4
+ flair
5
+ seqeval
6
  tqdm
backup/save_load.py CHANGED
@@ -1,20 +1,20 @@
1
- import torch
2
- from .model import GLiNER
3
-
4
-
5
- def save_model(current_model, path):
6
- config = current_model.config
7
- dict_save = {"model_weights": current_model.state_dict(), "config": config}
8
- torch.save(dict_save, path)
9
-
10
-
11
- def load_model(path, model_name=None, device=None):
12
- dict_load = torch.load(path, map_location=torch.device('cpu'))
13
- config = dict_load["config"]
14
-
15
- if model_name is not None:
16
- config.model_name = model_name
17
-
18
- loaded_model = GLiNER(config)
19
- loaded_model.load_state_dict(dict_load["model_weights"])
20
- return loaded_model.to(device) if device is not None else loaded_model
 
1
+ import torch
2
+ from .model import GLiNER
3
+
4
+
5
+ def save_model(current_model, path):
6
+ config = current_model.config
7
+ dict_save = {"model_weights": current_model.state_dict(), "config": config}
8
+ torch.save(dict_save, path)
9
+
10
+
11
+ def load_model(path, model_name=None, device=None):
12
+ dict_load = torch.load(path, map_location=torch.device('cpu'))
13
+ config = dict_load["config"]
14
+
15
+ if model_name is not None:
16
+ config.model_name = model_name
17
+
18
+ loaded_model = GLiNER(config)
19
+ loaded_model.load_state_dict(dict_load["model_weights"])
20
+ return loaded_model.to(device) if device is not None else loaded_model
backup/train.py CHANGED
@@ -1,132 +1,132 @@
1
- import argparse
2
- import os
3
-
4
- import torch
5
- import yaml
6
- from tqdm import tqdm
7
- from transformers import get_cosine_schedule_with_warmup
8
-
9
- # from model_nested import NerFilteredSemiCRF
10
- from .model import GLiNER
11
- from .modules.run_evaluation import get_for_all_path, sample_train_data
12
- from .save_load import save_model, load_model
13
-
14
- import json
15
-
16
-
17
- # train function
18
- def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
19
- train_batch_size=8, device='cuda'):
20
- model.train()
21
-
22
- # initialize data loaders
23
- train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)
24
-
25
- pbar = tqdm(range(num_steps))
26
-
27
- if warmup_ratio < 1:
28
- num_warmup_steps = int(num_steps * warmup_ratio)
29
- else:
30
- num_warmup_steps = int(warmup_ratio)
31
-
32
- scheduler = get_cosine_schedule_with_warmup(
33
- optimizer,
34
- num_warmup_steps=num_warmup_steps,
35
- num_training_steps=num_steps
36
- )
37
-
38
- iter_train_loader = iter(train_loader)
39
-
40
- for step in pbar:
41
- try:
42
- x = next(iter_train_loader)
43
- except StopIteration:
44
- iter_train_loader = iter(train_loader)
45
- x = next(iter_train_loader)
46
-
47
- for k, v in x.items():
48
- if isinstance(v, torch.Tensor):
49
- x[k] = v.to(device)
50
-
51
- try:
52
- loss = model(x) # Forward pass
53
- except:
54
- continue
55
-
56
- # check if loss is nan
57
- if torch.isnan(loss):
58
- continue
59
-
60
- loss.backward() # Compute gradients
61
- optimizer.step() # Update parameters
62
- scheduler.step() # Update learning rate schedule
63
- optimizer.zero_grad() # Reset gradients
64
-
65
- description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
66
-
67
- if (step + 1) % eval_every == 0:
68
- current_path = os.path.join(log_dir, f'model_{step + 1}')
69
- save_model(model, current_path)
70
- #val_data_dir = "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
71
- #get_for_all_path(model, step, log_dir, val_data_dir) # you can remove this comment if you want to evaluate the model
72
-
73
- model.train()
74
-
75
- pbar.set_description(description)
76
-
77
-
78
- def create_parser():
79
- parser = argparse.ArgumentParser(description="Span-based NER")
80
- parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
81
- parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
82
- return parser
83
-
84
-
85
- def load_config_as_namespace(config_file):
86
- with open(config_file, 'r') as f:
87
- config_dict = yaml.safe_load(f)
88
- return argparse.Namespace(**config_dict)
89
-
90
-
91
- if __name__ == "__main__":
92
- # parse args
93
- parser = create_parser()
94
- args = parser.parse_args()
95
-
96
- # load config
97
- config = load_config_as_namespace(args.config)
98
-
99
- config.log_dir = args.log_dir
100
-
101
- try:
102
- with open(config.train_data, 'r') as f:
103
- data = json.load(f)
104
- except:
105
- data = sample_train_data(config.train_data, 10000)
106
-
107
- if config.prev_path != "none":
108
- model = load_model(config.prev_path)
109
- model.config = config
110
- else:
111
- model = GLiNER(config)
112
-
113
- if torch.cuda.is_available():
114
- model = model.cuda()
115
-
116
- lr_encoder = float(config.lr_encoder)
117
- lr_others = float(config.lr_others)
118
-
119
- optimizer = torch.optim.AdamW([
120
- # encoder
121
- {'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
122
- {'params': model.rnn.parameters(), 'lr': lr_others},
123
- # projection layers
124
- {'params': model.span_rep_layer.parameters(), 'lr': lr_others},
125
- {'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
126
- ])
127
-
128
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
129
-
130
- train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
131
- log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
132
- device=device)
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ import yaml
6
+ from tqdm import tqdm
7
+ from transformers import get_cosine_schedule_with_warmup
8
+
9
+ # from model_nested import NerFilteredSemiCRF
10
+ from .model import GLiNER
11
+ from .modules.run_evaluation import get_for_all_path, sample_train_data
12
+ from .save_load import save_model, load_model
13
+
14
+ import json
15
+
16
+
17
+ # train function
18
+ def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
19
+ train_batch_size=8, device='cuda'):
20
+ model.train()
21
+
22
+ # initialize data loaders
23
+ train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)
24
+
25
+ pbar = tqdm(range(num_steps))
26
+
27
+ if warmup_ratio < 1:
28
+ num_warmup_steps = int(num_steps * warmup_ratio)
29
+ else:
30
+ num_warmup_steps = int(warmup_ratio)
31
+
32
+ scheduler = get_cosine_schedule_with_warmup(
33
+ optimizer,
34
+ num_warmup_steps=num_warmup_steps,
35
+ num_training_steps=num_steps
36
+ )
37
+
38
+ iter_train_loader = iter(train_loader)
39
+
40
+ for step in pbar:
41
+ try:
42
+ x = next(iter_train_loader)
43
+ except StopIteration:
44
+ iter_train_loader = iter(train_loader)
45
+ x = next(iter_train_loader)
46
+
47
+ for k, v in x.items():
48
+ if isinstance(v, torch.Tensor):
49
+ x[k] = v.to(device)
50
+
51
+ try:
52
+ loss = model(x) # Forward pass
53
+ except:
54
+ continue
55
+
56
+ # check if loss is nan
57
+ if torch.isnan(loss):
58
+ continue
59
+
60
+ loss.backward() # Compute gradients
61
+ optimizer.step() # Update parameters
62
+ scheduler.step() # Update learning rate schedule
63
+ optimizer.zero_grad() # Reset gradients
64
+
65
+ description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
66
+
67
+ if (step + 1) % eval_every == 0:
68
+ current_path = os.path.join(log_dir, f'model_{step + 1}')
69
+ save_model(model, current_path)
70
+ #val_data_dir = "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
71
+ #get_for_all_path(model, step, log_dir, val_data_dir) # you can remove this comment if you want to evaluate the model
72
+
73
+ model.train()
74
+
75
+ pbar.set_description(description)
76
+
77
+
78
+ def create_parser():
79
+ parser = argparse.ArgumentParser(description="Span-based NER")
80
+ parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
81
+ parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
82
+ return parser
83
+
84
+
85
+ def load_config_as_namespace(config_file):
86
+ with open(config_file, 'r') as f:
87
+ config_dict = yaml.safe_load(f)
88
+ return argparse.Namespace(**config_dict)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ # parse args
93
+ parser = create_parser()
94
+ args = parser.parse_args()
95
+
96
+ # load config
97
+ config = load_config_as_namespace(args.config)
98
+
99
+ config.log_dir = args.log_dir
100
+
101
+ try:
102
+ with open(config.train_data, 'r') as f:
103
+ data = json.load(f)
104
+ except:
105
+ data = sample_train_data(config.train_data, 10000)
106
+
107
+ if config.prev_path != "none":
108
+ model = load_model(config.prev_path)
109
+ model.config = config
110
+ else:
111
+ model = GLiNER(config)
112
+
113
+ if torch.cuda.is_available():
114
+ model = model.cuda()
115
+
116
+ lr_encoder = float(config.lr_encoder)
117
+ lr_others = float(config.lr_others)
118
+
119
+ optimizer = torch.optim.AdamW([
120
+ # encoder
121
+ {'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
122
+ {'params': model.rnn.parameters(), 'lr': lr_others},
123
+ # projection layers
124
+ {'params': model.span_rep_layer.parameters(), 'lr': lr_others},
125
+ {'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
126
+ ])
127
+
128
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
129
+
130
+ train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
131
+ log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
132
+ device=device)
core/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.59 kB). View file
 
core/__pycache__/gradio_ocr.cpython-310.pyc ADDED
Binary file (1.93 kB). View file
 
core/__pycache__/ner_engine.cpython-310.pyc ADDED
Binary file (2.57 kB). View file
 
core/__pycache__/ocr_engine.cpython-310.pyc ADDED
Binary file (4.04 kB). View file
 
core/__pycache__/vlm_engine.cpython-310.pyc ADDED
Binary file (3.21 kB). View file
 
core/base.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Dict, Any
3
+
4
+ class BaseEngine(ABC):
5
+ @abstractmethod
6
+ def process(self, image_path: str) -> Any:
7
+ pass
8
+
9
+ class BaseOCR(BaseEngine):
10
+ @abstractmethod
11
+ def extract_text(self, image_path: str) -> str:
12
+ pass
13
+
14
+ class BaseVLM(BaseEngine):
15
+ @abstractmethod
16
+ def extract_structured_data(self, image_path: str, prompt: str) -> Dict[str, Any]:
17
+ pass
18
+
19
+ class BaseNER(BaseEngine):
20
+ @abstractmethod
21
+ def extract_entities(self, text: str, labels: List[str]) -> Dict[str, List[str]]:
22
+ pass
core/gradio_ocr.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from gradio_client import Client, handle_file
4
+ from .base import BaseOCR
5
+
6
+ class GradioOCREngine(BaseOCR):
7
+ def __init__(self, space_name="WebAshlarWA/glm-ocr-v1"):
8
+ self.space_name = space_name
9
+ self.client = None
10
+ self._initialize_client()
11
+
12
+ def _initialize_client(self):
13
+ try:
14
+ self.client = Client(self.space_name)
15
+ logging.info(f"Gradio Client initialized for Space: {self.space_name}")
16
+ except Exception as e:
17
+ logging.error(f"Failed to initialize Gradio Client for {self.space_name}: {e}")
18
+
19
+ def extract_text(self, image_path: str) -> str:
20
+ if not self.client:
21
+ logging.error("Gradio Client not initialized.")
22
+ return ""
23
+
24
+ logging.info(f"Gradio OCR: Starting extraction for {os.path.basename(image_path)}")
25
+ try:
26
+ # According to the user snippet, the input is 'image' and output is a string?
27
+ # Or structured data. The snippet used /proses_intelijen
28
+ result = self.client.predict(
29
+ image=handle_file(image_path),
30
+ api_name="/proses_intelijen"
31
+ )
32
+
33
+ if isinstance(result, list) and len(result) > 0:
34
+ # Gradio spaces often return lists of [text, score] or similar
35
+ return str(result[0])
36
+ elif isinstance(result, str):
37
+ return result
38
+ elif isinstance(result, dict):
39
+ # If it's structured, we might need to stringify or handle it elsewhere
40
+ # For OCR we expect a string
41
+ return result.get('text', str(result))
42
+
43
+ logging.info(f"Gradio OCR: Successfully extracted text.")
44
+ return str(result)
45
+ except Exception as e:
46
+ logging.error(f"Gradio OCR extraction failed: {e}")
47
+ return ""
48
+
49
+ def process(self, image_path: str) -> str:
50
+ return self.extract_text(image_path)
core/ner_engine.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Dict
3
+ from .base import BaseNER
4
+
5
+ class NEREngine(BaseNER):
6
+ def __init__(self, model_name="urchade/gliner_mediumv2.1"):
7
+ self.model_name = model_name
8
+ self.model = None
9
+ self._initialize_model()
10
+
11
+ def _initialize_model(self):
12
+ logging.info(f"Initializing NER model: {self.model_name}")
13
+ try:
14
+ from backup.model import GLiNER
15
+ self.model = GLiNER.from_pretrained(self.model_name)
16
+ logging.info(f"NER model '{self.model_name}' loaded successfully.")
17
+ except Exception as e:
18
+ logging.error(f"Failed to load NER model: {e}. NER extraction will be unavailable.")
19
+
20
+ def extract_entities(self, text: str, labels: List[str] = None) -> Dict[str, List[str]]:
21
+ if not text:
22
+ logging.warning("NER: Received empty text for extraction.")
23
+ return {}
24
+
25
+ if not self.model:
26
+ logging.error("NER: Model not initialized. Skipping extraction.")
27
+ return {}
28
+
29
+ if labels is None:
30
+ labels = ["Name", "Designation", "Company", "Contact", "Address", "Email", "Link"]
31
+
32
+ logging.info(f"NER: Extracting entities for {len(text)} characters of text.")
33
+ try:
34
+ entities = self.model.predict_entities(text, labels, threshold=0.3)
35
+ structured_data = {label: [] for label in labels}
36
+ for ent in entities:
37
+ label = ent["label"]
38
+ if label in structured_data:
39
+ structured_data[label].append(ent["text"])
40
+
41
+ non_empty_tags = sum(1 for v in structured_data.values() if v)
42
+ logging.info(f"NER: Extracted entities for {non_empty_tags} labels.")
43
+ return structured_data
44
+ except Exception as e:
45
+ logging.error(f"NER: Extraction pipeline crashed: {e}")
46
+ return {}
47
+
48
+ def process(self, text: str) -> Dict[str, List[str]]:
49
+ return self.extract_entities(text)
core/ocr_engine.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import cv2
3
+ import os
4
+ import numpy as np
5
+ from PIL import Image, ImageEnhance
6
+ from .base import BaseOCR
7
+ from .gradio_ocr import GradioOCREngine
8
+
9
+ class OCREngine(BaseOCR):
10
+ def __init__(self, engine_type='paddle'):
11
+ self.engine_type = engine_type
12
+ self.ocr = None
13
+ self.gradio_fallback = None
14
+ self._initialize_engine()
15
+
16
+ def _initialize_engine(self):
17
+ logging.info(f"Initializing OCR engine: {self.engine_type}")
18
+
19
+ # Pre-emptive Gradio initialization as it's the most reliable fallback
20
+ try:
21
+ self.gradio_fallback = GradioOCREngine()
22
+ except Exception as e:
23
+ logging.error(f"Failed to pre-initialize Gradio fallback: {e}")
24
+
25
+ if self.engine_type == 'paddle':
26
+ try:
27
+ from paddleocr import PaddleOCR
28
+ self.ocr = PaddleOCR(use_angle_cls=False, lang='en', show_log=False)
29
+ logging.info("PaddleOCR engine initialized successfully.")
30
+ except Exception as e:
31
+ logging.warning(f"Failed to initialize PaddleOCR: {e}. Switching to EasyOCR fallback.")
32
+ self.engine_type = 'easyocr'
33
+
34
+ if self.engine_type == 'easyocr':
35
+ try:
36
+ import easyocr
37
+ self.ocr = easyocr.Reader(['en'])
38
+ logging.info("EasyOCR engine initialized successfully.")
39
+ except Exception as e:
40
+ logging.error(f"Failed to initialize EasyOCR: {e}. OCR will be partially unavailable.")
41
+ self.ocr = None
42
+
43
+ def preprocess_image(self, image_path, scale=2):
44
+ try:
45
+ image = cv2.imread(image_path)
46
+ if image is None:
47
+ logging.error(f"Image not found or unreadable: {image_path}")
48
+ return None
49
+
50
+ # Upscale
51
+ height, width = image.shape[:2]
52
+ image = cv2.resize(image, (width * scale, height * scale), interpolation=cv2.INTER_CUBIC)
53
+
54
+ # Denoise
55
+ image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
56
+
57
+ # Sharpen
58
+ kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
59
+ image = cv2.filter2D(image, -1, kernel)
60
+
61
+ # Enhance Contrast
62
+ pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
63
+ enhancer = ImageEnhance.Contrast(pil_img)
64
+ enhanced_image = enhancer.enhance(1.5)
65
+
66
+ logging.debug(f"Preprocessing completed for {image_path}")
67
+ return cv2.cvtColor(np.array(enhanced_image), cv2.COLOR_RGB2BGR)
68
+ except Exception as e:
69
+ logging.error(f"Error during image preprocessing for {image_path}: {e}")
70
+ return None
71
+
72
+ def extract_text(self, image_path: str) -> str:
73
+ logging.info(f"Starting text extraction for: {os.path.basename(image_path)}")
74
+
75
+ # Tiered Extraction Strategy:
76
+ # 1. Primary Engine (Paddle/EasyOCR)
77
+ # 2. Gradio Remote Fallback (Very reliable)
78
+
79
+ extracted_text = ""
80
+
81
+ # 1. Local OCR
82
+ if self.engine_type == 'paddle' and self.ocr:
83
+ try:
84
+ processed_img = self.preprocess_image(image_path)
85
+ if processed_img is not None:
86
+ results = self.ocr.ocr(processed_img)
87
+ if results and results[0]:
88
+ extracted_text = " ".join([line[1][0] for line in results[0]])
89
+ except Exception as e:
90
+ logging.error(f"PaddleOCR crashed: {e}")
91
+
92
+ elif self.engine_type == 'easyocr' and self.ocr:
93
+ try:
94
+ processed_img = self.preprocess_image(image_path)
95
+ if processed_img is not None:
96
+ results = self.ocr.readtext(processed_img)
97
+ extracted_text = " ".join([res[1] for res in results])
98
+ except Exception as e:
99
+ logging.error(f"EasyOCR crashed: {e}")
100
+
101
+ # 2. Gradio Fallback if Local failed
102
+ if not extracted_text and self.gradio_fallback:
103
+ logging.info("Local OCR failed or returned empty. Trying Gradio OCR fallback...")
104
+ extracted_text = self.gradio_fallback.extract_text(image_path)
105
+
106
+ if extracted_text:
107
+ logging.info(f"OCR extracted {len(extracted_text)} characters using {'Gradio' if not extracted_text else self.engine_type}.")
108
+ else:
109
+ logging.error("All OCR methods failed to extract text.")
110
+
111
+ return extracted_text
112
+
113
+ def process(self, image_path: str) -> str:
114
+ return self.extract_text(image_path)
core/vlm_engine.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import json
4
+ import logging
5
+ import requests
6
+ import cv2
7
+ from typing import Dict, Any
8
+ from .base import BaseVLM
9
+
10
class GroqVLMEngine(BaseVLM):
    """Vision-language extraction engine backed by the Groq chat API.

    Sends a business-card image (base64-encoded JPEG) together with an
    extraction prompt and expects the model to answer with a strict JSON
    object.  Every failure mode is logged and surfaced to the caller as
    an empty dict, so the pipeline can fall back to OCR/NER.
    """

    def __init__(self, model="meta-llama/llama-4-scout-17b-16e-instruct"):
        # The API key is read from the environment; when absent the
        # engine is effectively disabled (every call returns {}).
        self.api_key = os.getenv("GROQ_API_KEY")
        self.url = "https://api.groq.com/openai/v1/chat/completions"
        self.model = model
        if not self.api_key:
            logging.warning("GROQ_API_KEY missing from environment. VLM extraction will be skipped.")

    def image_to_base64(self, image_path: str) -> str:
        """Read ``image_path`` with OpenCV and return it as a base64 JPEG string.

        Returns "" when the file is missing or cannot be encoded.
        """
        try:
            img = cv2.imread(image_path)
            if img is None:
                logging.error(f"VLM: Image not found at {image_path}")
                return ""
            _, buffer = cv2.imencode(".jpg", img)
            return base64.b64encode(buffer).decode("utf-8")
        except Exception as e:
            logging.error(f"VLM: Error converting image to base64: {e}")
            return ""

    def extract_structured_data(self, image_path: str, prompt: str) -> Dict[str, Any]:
        """POST the image + prompt to Groq and parse the JSON reply.

        Returns the decoded dict, or {} on any failure: missing API key,
        unreadable image, HTTP error status, network error, or a reply
        that is not the expected JSON shape.
        """
        if not self.api_key:
            return {}

        logging.info(f"VLM: Starting extraction for {os.path.basename(image_path)} using {self.model}")
        base64_image = self.image_to_base64(image_path)
        if not base64_image:
            return {}

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a strict information extraction engine for business cards. Return only valid JSON. Do not include any other text."
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            # Ask the API to enforce a JSON-object reply; low temperature
            # keeps the extraction near-deterministic.
            "response_format": {"type": "json_object"},
            "temperature": 0.1
        }

        try:
            resp = requests.post(self.url, headers=headers, json=payload, timeout=60)
            if resp.status_code != 200:
                logging.error(f"VLM API Error: {resp.status_code} - {resp.text}")
                return {}

            content = resp.json()["choices"][0]["message"]["content"]
            data = json.loads(content)
            logging.info(f"VLM: Successfully extracted structured data from {os.path.basename(image_path)}")
            return data
        except requests.exceptions.Timeout:
            logging.error("VLM: Request timed out.")
            return {}
        except requests.exceptions.RequestException as e:
            # Connection/DNS/SSL failures previously fell through to the
            # generic handler and were logged as "Unexpected error".
            logging.error(f"VLM: Network error: {e}")
            return {}
        except (json.JSONDecodeError, KeyError, IndexError, TypeError) as e:
            # The API answered, but not with the JSON shape we require
            # (invalid JSON body or missing choices/message/content keys).
            logging.error(f"VLM: Malformed response from API: {e}")
            return {}
        except Exception as e:
            logging.error(f"VLM: Unexpected error: {e}")
            return {}

    def process(self, image_path: str) -> Dict[str, Any]:
        """Run the default business-card extraction prompt on one image."""
        prompt = """
        Extract structured text from this business card and return ONLY valid JSON.
        Fields: Name, Designation, Company, Contact, Address, Email, Link.
        Every value must be a JSON array. If not found, use [].
        """
        return self.extract_structured_data(image_path, prompt)
requirements.txt CHANGED
@@ -1,16 +1,18 @@
1
- Flask
2
- huggingface_hub
3
- python-dotenv
4
- easyocr
5
- Pillow
6
- opencv-python
7
- numpy
8
- paddle-bfloat
9
- paddlepaddle
10
- paddleocr
11
- torch
12
- transformers
13
- flair
14
- seqeval
15
- tqdm
16
- gunicorn
 
 
 
1
+ Flask
2
+ python-dotenv
3
+ Pillow
4
+ opencv-python
5
+ numpy
6
+ paddle-bfloat
7
+ paddlepaddle>=2.6.0
8
+ paddleocr>=2.7.0
9
+ gradio_client
10
+ easyocr
11
+ langchain
12
+ langchain-community
13
+ torch
14
+ transformers
15
+ flair
16
+ tqdm
17
+ gunicorn
18
+ requests
static/uploads/IN_Standard-Visiting-Cards_Overview.png ADDED
templates/index.html CHANGED
@@ -1,284 +1,236 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
-
4
- <head>
5
- <meta charset="UTF-8" />
6
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
- <title>AI Data Extractor</title>
8
- <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet" />
9
- <style>
10
- body {
11
- background-color: #1c1c1e;
12
- font-family: "Poppins", sans-serif;
13
- color: #f5f5f7;
14
- margin: 0;
15
- }
16
-
17
- h1 {
18
- color: #e5e5e7;
19
- text-align: center;
20
- margin-bottom: 20px;
21
- }
22
-
23
- .container {
24
- margin-top: 70px;
25
- }
26
-
27
- .file-upload-section {
28
- background-color: #2c2c2e;
29
- padding: 30px;
30
- border-radius: 15px;
31
- box-shadow: 0 8px 16px rgba(0, 0, 0, 0.5);
32
- text-align: center;
33
- }
34
-
35
- .file-upload-section input[type="file"] {
36
- margin: 20px 0;
37
- }
38
-
39
- .file-upload-section input[type="submit"] {
40
- background-color: #ee4410;
41
- color: white;
42
- border: none;
43
- padding: 10px 20px;
44
- border-radius: 5px;
45
- transition: background-color 0.3s ease;
46
- }
47
-
48
- .file-upload-section input[type="submit"]:hover {
49
- background-color: #ee4410;
50
- }
51
-
52
- .file-actions a {
53
- margin: 0 10px;
54
- text-decoration: none;
55
- color: #ee4410;
56
- }
57
-
58
- .file-actions a:hover {
59
- color: #ee4410;
60
- }
61
-
62
- .flash-message {
63
- margin-bottom: 20px;
64
- padding: 15px;
65
- border-radius: 5px;
66
- color: #333;
67
- }
68
-
69
- .alert {
70
- text-align: center;
71
- position: sticky;
72
- top: 0;
73
- right: 15%;
74
- }
75
-
76
- /* Loader styles */
77
- .loader {
78
- border: 8px solid #f3f3f3;
79
- border-top: 8px solid #ee4410;
80
- border-radius: 50%;
81
- width: 60px;
82
- height: 60px;
83
- animation: spin 2s linear infinite;
84
- margin: 20px auto;
85
- display: none;
86
- }
87
-
88
- @keyframes spin {
89
- 0% {
90
- transform: rotate(0deg);
91
- }
92
-
93
- 100% {
94
- transform: rotate(360deg);
95
- }
96
- }
97
-
98
- /* Top bar styles */
99
- .top-bar {
100
- background-color: #333;
101
- position: fixed;
102
- top: 0;
103
- width: 100%;
104
- z-index: 1000;
105
- padding: 10px 20px;
106
- display: flex;
107
- justify-content: space-between;
108
- align-items: center;
109
- }
110
-
111
- .top-bar h2 {
112
- color: white;
113
- }
114
-
115
- /* Navigation tab styles */
116
- .tab {
117
- display: flex;
118
- gap: 10px;
119
- }
120
-
121
- .tab button {
122
- background-color: inherit;
123
- border: none;
124
- outline: none;
125
- cursor: pointer;
126
- padding: 10px 16px;
127
- transition: 0.3s;
128
- font-size: 17px;
129
- color: white;
130
- }
131
-
132
- .tab button:hover {
133
- background-color: #575757;
134
- cursor: pointer;
135
- }
136
-
137
- .tab button.active {
138
- background-color: #ee4410;
139
- }
140
-
141
- /* Tab content styles */
142
- .tabcontent {
143
- display: none;
144
- padding: 20px;
145
- margin-top: 70px;
146
- }
147
- .disabled {
148
- cursor: not-allowed !important;
149
- opacity: 0.6;/* Set cursor to not-allowed */
150
- }
151
-
152
- /* Responsive design */
153
- @media (max-width: 768px) {
154
- .tab {
155
- flex-direction: column;
156
- }
157
- }
158
- </style>
159
- </head>
160
-
161
- <body>
162
- <!-- Locked Top Bar with Tabs -->
163
- <div class="top-bar">
164
- <h2>AI Data Extractor</h2>
165
- <!-- Navigation Tabs -->
166
- <div class="tab">
167
- <button class="tablinks active" onclick="openLink('https://webashlarwa-imagedataextractor2.hf.space/', this, '#ff4d00')" id="defaultOpen">Visiting Card Data Extractor</button>
168
- <button class="tablinks" onclick="openLink('https://webashlarwa-resumeextractor2.hf.space/', this, '#ff4d00')">Resume Data Extractor</button>
169
- </div>
170
- </div>
171
- <div class="container">
172
- <h1>Visiting Card Data Extractor</h1>
173
- <div class="file-upload-section">
174
- <form id="fileUploadForm" action="{{ url_for('upload_file') }}" method="POST" enctype="multipart/form-data">
175
- <input type="file" name="files" multiple class="form-control" required />
176
- <input type="submit" value="Upload your Images" class="btn btn-outline-primary" />
177
- </form>
178
-
179
- {% if session.get('uploaded_files') %}
180
- <p class="mt-4">
181
- Uploaded:
182
- <span class="text-danger">{{ session.get('uploaded_files') }}</span>
183
- </p>
184
- <form action="{{ url_for('remove_file') }}" method="post">
185
- <button type="submit" class="btn btn-outline-danger">
186
- <i class="bi bi-trash"></i> Remove Uploaded File
187
- </button>
188
- </form>
189
- {% endif %}
190
- </div>
191
-
192
- <div class="container">
193
- <!-- Loader -->
194
- <div class="loader" id="loader"></div>
195
- </div>
196
-
197
- {% with messages = get_flashed_messages() %} {% if messages %}
198
- <div class="alert alert-success mt-4" id="flash-message">
199
- {{ messages[0] }}
200
- </div>
201
- {% endif %} {% endwith %}
202
- </div>
203
-
204
- <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.bundle.min.js"></script>
205
- <script>
206
- // Loader functionality
207
- document.getElementById('fileUploadForm').onsubmit = function() {
208
- document.getElementById('loader').style.display = 'block';
209
-
210
- // Disable the tab buttons and apply disabled class
211
- const buttons = document.querySelectorAll('.tab button');
212
- buttons.forEach(button => {
213
- button.setAttribute('disabled', 'true');
214
- button.classList.add('disabled'); // Add disabled class
215
- });
216
-
217
- // Show processing message
218
- const processingMessage = document.createElement('p');
219
- processingMessage.id = 'processing-message';
220
- processingMessage.textContent = 'Processing, please wait...';
221
- processingMessage.style.color = '#e68a10'; // Style as needed
222
- document.querySelector('.file-upload-section').appendChild(processingMessage);
223
- };
224
-
225
- // Flash message auto-hide
226
- setTimeout(function () {
227
- let flashMessage = document.getElementById("flash-message");
228
- if (flashMessage) {
229
- flashMessage.style.transition = "opacity 1s ease";
230
- flashMessage.style.opacity = 0;
231
- setTimeout(() => flashMessage.remove(), 1000);
232
- }
233
-
234
- // After processing is complete (You can adjust this based on your logic)
235
- const processingMessage = document.getElementById('processing-message');
236
- if (processingMessage) {
237
- processingMessage.remove(); // Remove the processing message
238
- }
239
-
240
- // Re-enable tab buttons and remove disabled class
241
- const buttons = document.querySelectorAll('.tab button');
242
- buttons.forEach(button => {
243
- button.removeAttribute('disabled');
244
- button.classList.remove('disabled'); // Remove disabled class
245
- });
246
- }, 3000); // Adjust timing based on your upload duration
247
-
248
- // Function to open links in the same tab
249
- function openLink(url, element) {
250
- window.location.href = url; // Redirects to the specified URL in the same tab
251
-
252
- // Remove "active" class from all buttons
253
- const buttons = document.querySelectorAll('.tab button');
254
- buttons.forEach(button => button.classList.remove('active'));
255
-
256
- // Add "active" class to the clicked button
257
- element.classList.add('active');
258
- }
259
- //Refreshing the cookie
260
- function setCookie(name, value, days) {
261
- let expires = "";
262
- if (days) {
263
- const date = new Date();
264
- date.setTime(date.getTime() + (days * 24 * 60 * 60 * 1000));
265
- expires = "; expires=" + date.toUTCString();
266
- }
267
- document.cookie = name + "=" + (value || "") + expires + "; path=/";
268
- }
269
-
270
- function deleteCookie(name) {
271
- document.cookie = name + '=; Max-Age=0; path=/;'; // Delete the cookie
272
- }
273
-
274
- // Set the cookie (you can comment this out after testing)
275
- setCookie('myCookie', 'myValue', 1); // Sets a cookie for demonstration
276
-
277
- // Automatically delete the cookie when the page is loaded or refreshed
278
- window.onload = function() {
279
- deleteCookie('myCookie'); // Replace 'myCookie' with your cookie name
280
- }
281
- </script>
282
- </body>
283
-
284
- </html>
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>AI Data Extractor - Visiting Card</title>
8
+ <link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600&display=swap" rel="stylesheet">
9
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet" />
10
+ <style>
11
+ :root {
12
+ --primary: #ee4410;
13
+ --secondary: #ff9f0a;
14
+ --bg-dark: #0a0a0c;
15
+ --card-bg: rgba(30, 30, 35, 0.7);
16
+ --text-glow: 0 0 10px rgba(238, 68, 16, 0.5);
17
+ }
18
+
19
+ body {
20
+ background-color: var(--bg-dark);
21
+ font-family: 'Outfit', sans-serif;
22
+ color: #f5f5f7;
23
+ overflow-x: hidden;
24
+ background: radial-gradient(circle at 50% 50%, #1a1a1f 0%, #0a0a0c 100%);
25
+ min-height: 100vh;
26
+ }
27
+
28
+ .glass-card {
29
+ background: var(--card-bg);
30
+ backdrop-filter: blur(12px);
31
+ border: 1px solid rgba(255, 255, 255, 0.1);
32
+ border-radius: 20px;
33
+ box-shadow: 0 15px 35px rgba(0, 0, 0, 0.4);
34
+ padding: 2.5rem;
35
+ margin-top: 2rem;
36
+ }
37
+
38
+ .premium-title {
39
+ background: linear-gradient(135deg, #fff 0%, #aaa 100%);
40
+ -webkit-background-clip: text;
41
+ background-clip: text;
42
+ -webkit-text-fill-color: transparent;
43
+ font-weight: 600;
44
+ letter-spacing: -1px;
45
+ text-shadow: var(--text-glow);
46
+ margin-bottom: 2rem;
47
+ text-align: center;
48
+ }
49
+
50
+ .top-bar {
51
+ background: rgba(20, 20, 25, 0.8);
52
+ backdrop-filter: blur(10px);
53
+ padding: 1rem 2rem;
54
+ border-bottom: 1px solid rgba(255, 255, 255, 0.05);
55
+ position: sticky;
56
+ top: 0;
57
+ z-index: 1000;
58
+ }
59
+
60
+ .tab-btn {
61
+ background: transparent;
62
+ border: 1px solid rgba(255, 255, 255, 0.1);
63
+ color: #8e8e93;
64
+ padding: 0.6rem 1.2rem;
65
+ border-radius: 30px;
66
+ margin-right: 10px;
67
+ transition: all 0.3s ease;
68
+ text-decoration: none;
69
+ font-size: 0.9rem;
70
+ }
71
+
72
+ .tab-btn:hover, .tab-btn.active {
73
+ border-color: var(--primary);
74
+ color: #fff;
75
+ background: rgba(238, 68, 16, 0.1);
76
+ }
77
+
78
+ .tab-btn.active {
79
+ box-shadow: 0 0 15px rgba(238, 68, 16, 0.2);
80
+ }
81
+
82
+ .upload-area {
83
+ border: 2px dashed rgba(255, 255, 255, 0.1);
84
+ border-radius: 15px;
85
+ padding: 3rem;
86
+ text-align: center;
87
+ transition: all 0.3s ease;
88
+ cursor: pointer;
89
+ margin-bottom: 2rem;
90
+ }
91
+
92
+ .upload-area:hover {
93
+ border-color: var(--primary);
94
+ background: rgba(238, 68, 16, 0.05);
95
+ }
96
+
97
+ .btn-premium {
98
+ background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%);
99
+ border: none;
100
+ color: white;
101
+ padding: 0.8rem 2rem;
102
+ border-radius: 30px;
103
+ font-weight: 600;
104
+ box-shadow: 0 5px 15px rgba(238, 68, 16, 0.3);
105
+ transition: all 0.3s ease;
106
+ width: 100%;
107
+ }
108
+
109
+ .btn-premium:hover:not(:disabled) {
110
+ transform: scale(1.02);
111
+ box-shadow: 0 8px 25px rgba(238, 68, 16, 0.5);
112
+ }
113
+
114
+ .btn-premium:disabled {
115
+ opacity: 0.6;
116
+ cursor: not-allowed;
117
+ }
118
+
119
+ .loader {
120
+ display: none;
121
+ width: 40px;
122
+ height: 40px;
123
+ border: 3px solid rgba(255,255,255,0.1);
124
+ border-top: 3px solid var(--primary);
125
+ border-radius: 50%;
126
+ animation: spin 1s linear infinite;
127
+ margin: 2rem auto;
128
+ }
129
+
130
+ @keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
131
+
132
+ .alert-premium {
133
+ background: rgba(40, 40, 45, 0.9);
134
+ border: 1px solid var(--primary);
135
+ color: #fff;
136
+ border-radius: 12px;
137
+ padding: 1rem;
138
+ margin-top: 1.5rem;
139
+ text-align: center;
140
+ }
141
+
142
+ .file-list {
143
+ margin-top: 1.5rem;
144
+ padding: 1rem;
145
+ background: rgba(0, 0, 0, 0.2);
146
+ border-radius: 10px;
147
+ font-size: 0.9rem;
148
+ }
149
+ </style>
150
+ </head>
151
+
152
+ <body>
153
+ <div class="top-bar d-flex justify-content-between align-items-center">
154
+ <h4 class="mb-0 text-white" style="font-weight: 600; letter-spacing: -0.5px;">AI EXtractor</h4>
155
+ <div class="d-none d-md-flex">
156
+ <a href="#" class="tab-btn active">Visiting Card</a>
157
+ <a href="https://webashalarforml-resumeextractor2.hf.space/" class="tab-btn">Resume Detail</a>
158
+ </div>
159
+ </div>
160
+
161
+ <div class="container py-5">
162
+ <div class="row justify-content-center">
163
+ <div class="col-lg-6">
164
+ <div class="glass-card">
165
+ <h1 class="premium-title">Card Scanner <small style="display: block; font-size: 0.4em; letter-spacing: 2px; color: var(--primary); margin-top: 5px;">POWERED BY GROQ VLM</small></h1>
166
+
167
+ <form id="uploadForm" action="{{ url_for('upload_file') }}" method="POST" enctype="multipart/form-data">
168
+ <div class="upload-area" onclick="document.getElementById('fileInput').click()">
169
+ <svg xmlns="http://www.w3.org/2000/svg" width="48" height="48" fill="currentColor" class="bi bi-cloud-upload text-muted mb-3" viewBox="0 0 16 16">
170
+ <path fill-rule="evenodd" d="M4.406 3.342A5.53 5.53 0 0 1 8 2c2.69 0 4.923 2 5.166 4.579C14.758 6.804 16 8.137 16 9.773 16 11.569 14.502 13 12.687 13H3.781C1.708 13 0 11.366 0 9.318c0-1.763 1.266-3.223 2.942-3.593.143-.863.698-1.723 1.464-2.383zm.653.757c-.757.653-1.153 1.44-1.153 2.056v.448l-.445.049C2.064 6.805 1 7.952 1 9.318 1 10.785 2.23 12 3.781 12h8.906C13.98 12 15 10.988 15 9.773c0-1.216-1.02-2.228-2.313-2.228h-.5v-.5C12.188 4.825 10.328 3 8 3a4.53 4.53 0 0 0-2.941 1.1z"/>
171
+ <path fill-rule="evenodd" d="M7.646 5.146a.5.5 0 0 1 .708 0l2 2a.5.5 0 0 1-.708.708L8.5 6.707V10.5a.5.5 0 0 1-1 0V6.707L6.354 7.854a.5.5 0 1 1-.708-.708l2-2z"/>
172
+ </svg>
173
+ <p class="mb-0 text-muted">Click or Drag & Drop Business Cards</p>
174
+ <input type="file" name="files" id="fileInput" multiple style="display: none;" required onchange="updateFileList(this)" />
175
+ </div>
176
+
177
+ <div id="fileList" class="file-list" style="display: none;"></div>
178
+
179
+ <button type="submit" id="submitBtn" class="btn-premium mt-3">Start Extraction</button>
180
+ </form>
181
+
182
+ <div class="loader" id="loader"></div>
183
+ <p id="loadingMsg" class="text-center text-muted small mt-2" style="display: none;">Analyzing images with AI engine...</p>
184
+
185
+ {% if session.get('uploaded_files') %}
186
+ <div class="mt-4 pt-3 border-top border-secondary">
187
+ <div class="d-flex justify-content-between align-items-center">
188
+ <span class="small text-muted">Ready to process: {{ session.get('uploaded_files')|length }} files</span>
189
+ <a href="{{ url_for('reset_upload') }}" class="text-danger small text-decoration-none">Clear All</a>
190
+ </div>
191
+ </div>
192
+ {% endif %}
193
+
194
+ {% with messages = get_flashed_messages() %}
195
+ {% if messages %}
196
+ <div class="alert-premium" id="flashMessage">
197
+ {{ messages[0] }}
198
+ </div>
199
+ {% endif %}
200
+ {% endwith %}
201
+ </div>
202
+ </div>
203
+ </div>
204
+ </div>
205
+
206
+ <script>
207
+ function updateFileList(input) {
208
+ const list = document.getElementById('fileList');
209
+ if (input.files.length > 0) {
210
+ list.style.display = 'block';
211
+ list.innerHTML = '<strong>Selected:</strong><br>' +
212
+ Array.from(input.files).map(f => f.name).join('<br>');
213
+ } else {
214
+ list.style.display = 'none';
215
+ }
216
+ }
217
+
218
+ document.getElementById('uploadForm').onsubmit = function() {
219
+ document.getElementById('loader').style.display = 'block';
220
+ document.getElementById('loadingMsg').style.display = 'block';
221
+ document.getElementById('submitBtn').disabled = true;
222
+ document.getElementById('submitBtn').innerText = 'Processing...';
223
+ };
224
+
225
+ // Flash message auto-hide
226
+ setTimeout(() => {
227
+ const flash = document.getElementById('flashMessage');
228
+ if (flash) {
229
+ flash.style.transition = 'opacity 1s ease';
230
+ flash.style.opacity = '0';
231
+ setTimeout(() => flash.remove(), 1000);
232
+ }
233
+ }, 4000);
234
+ </script>
235
+ </body>
236
+ </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
templates/result.html CHANGED
@@ -1,248 +1,326 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
-
4
- <head>
5
- <meta charset="UTF-8" />
6
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
- <title>Processed Results</title>
8
- <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet" />
9
- <style>
10
- body {
11
- background-color: #1c1c1e;
12
- font-family: "Poppins", sans-serif;
13
- color: #f5f5f7;
14
- }
15
-
16
- h1 {
17
- color: #e5e5e7;
18
- text-align: center;
19
- }
20
-
21
- .cont {
22
- background-color: #2c2c2e;
23
- padding: 30px;
24
- border-radius: 15px;
25
- box-shadow: 0 8px 16px rgba(0, 0, 0, 0.5);
26
- transition: 1s ease;
27
- }
28
-
29
- .section-title {
30
- color: #ee4410;
31
- font-size: 1.5rem;
32
- font-weight: bold;
33
- margin-top: 20px;
34
- border-bottom: 2px solid #ee4410;
35
- padding-bottom: 10px;
36
- }
37
-
38
- .card {
39
- background-color: #3a3a3c;
40
- color: #f5f5f7;
41
- border-radius: 10px;
42
- margin-bottom: 15px;
43
- padding: 15px;
44
- box-shadow: 0 4px 10px rgba(0, 0, 0, 0.3);
45
- transition: background-color 0.3s ease;
46
- }
47
-
48
- .card:hover {
49
- background-color: #3a3a3c98;
50
- }
51
-
52
- .card-title {
53
- color: #ee4410;
54
- font-size: 1.2rem;
55
- font-weight: bold;
56
- }
57
-
58
- .card-text {
59
- color: #d1d1d6;
60
- font-size: 1rem;
61
- }
62
-
63
- ul {
64
- list-style-type: none;
65
- padding-left: 0;
66
- }
67
-
68
- li::before {
69
- content: "• ";
70
- color: #ee4410;
71
- }
72
-
73
- .btn-reset {
74
- background-color: #ff9f0a;
75
- color: white;
76
- border: none;
77
- padding: 10px 20px;
78
- border-radius: 5px;
79
- transition: background-color 0.3s ease;
80
- margin-bottom: 20px;
81
- }
82
-
83
- .btn-reset:hover {
84
- background-color: #e03a2f;
85
- }
86
-
87
- .alert {
88
- text-align: center;
89
- position: absolute;
90
- top: 0;
91
- right: 15%;
92
- }
93
-
94
- .image-container img {
95
- max-width: 100%;
96
- border-radius: 10px;
97
- }
98
- </style>
99
- </head>
100
-
101
- <body>
102
- <div class="container">
103
- {% with messages = get_flashed_messages() %} {% if messages %}
104
- <div class="alert alert-success mt-4 " id="flash-message">
105
- {{ messages[0] }}
106
- </div>
107
- {% endif %} {% endwith %}
108
- </div>
109
-
110
- <div class="container cont mt-5">
111
- <div class="d-flex align-items-center justify-content-between">
112
- <h1>Extracted Details From Image</h1>
113
- <!-- Reset Button -->
114
- <div class="text-center mt-4">
115
- <a href="{{ url_for('reset_upload') }}" class="btn btn-reset">Reset & Upload New File</a>
116
- </div>
117
- </div>
118
-
119
- {% if data %}
120
- <!-- Personal Information Section -->
121
- <section>
122
- <h3 class="section-title">Extracted Information</h3>
123
- <div class="row">
124
- <!-- Image Container on the Left -->
125
- <div class="col-md-6 image-container">
126
- <div class="card">
127
- <div class="card-body">
128
- {% if data.extracted_text.items() %}
129
- <h5 class="card-title">Extracted Image:</h5>
130
- <ul>
131
- {% for filename, text in data.extracted_text.items() %}
132
- <!--<li>{{ filename }}:</li>-->
133
- <img src="{{ Img[filename] }}" alt="Processed Image" class="img-fluid" />
134
- {% endfor %}
135
- </ul>
136
- {% endif %}
137
- </div>
138
- </div>
139
- </div>
140
-
141
- <!-- Extracted Text on the Right -->
142
- <div class="col-md-6">
143
- <div class="card">
144
- <div class="card-body">
145
- {% if data.name and data.name is iterable and data.name is not string %}
146
- <h5 class="card-title">Name:</h5>
147
- <ul>
148
- {% for value in data.name %}
149
- {% if value|lower != 'not found' %}
150
- <li>{{ value }}</li>
151
- {% endif %}
152
- {% endfor %}
153
- </ul>
154
- {% endif %}
155
-
156
- {% if data.Designation and data.Designation is iterable and data.Designation is not string %}
157
- <h5 class="card-title">Designation:</h5>
158
- <ul>
159
- {% for value in data.Designation %}
160
- {% if value|lower != 'not found' %}
161
- <li>{{ value }}</li>
162
- {% endif %}
163
- {% endfor %}
164
- </ul>
165
- {% endif %}
166
-
167
- {% if data.contact_number and data.contact_number is iterable and data.contact_number is not string %}
168
- <h5 class="card-title">Contact number:</h5>
169
- <ul>
170
- {% for value in data.contact_number %}
171
- {% if value|lower != 'not found' %}
172
- <li>{{ value }}</li>
173
- {% endif %}
174
- {% endfor %}
175
- </ul>
176
- {% endif %}
177
-
178
- {% if data.email and data.email is iterable and data.email is not string %}
179
- <h5 class="card-title">Email:</h5>
180
- <ul>
181
- {% for value in data.email %}
182
- {% if value|lower != 'not found' %}
183
- <li>{{ value }}</li>
184
- {% endif %}
185
- {% endfor %}
186
- </ul>
187
- {% endif %}
188
-
189
- {% if data.Location and data.Location is iterable and data.Location is not string %}
190
- <h5 class="card-title">Location:</h5>
191
- <ul>
192
- {% for value in data.Location %}
193
- {% if value|lower != 'not found' %}
194
- <li>{{ value }}</li>
195
- {% endif %}
196
- {% endfor %}
197
- </ul>
198
- {% endif %}
199
-
200
- {% if data.Link and data.Link is iterable and data.Link is not string %}
201
- <h5 class="card-title">Link:</h5>
202
- <ul>
203
- {% for value in data.Link %}
204
- {% if value|lower != 'not found' %}
205
- <li>{{ value }}</li>
206
- {% endif %}
207
- {% endfor %}
208
- </ul>
209
- {% endif %}
210
-
211
- {% if data.Company and data.Company is iterable and data.Company is not string %}
212
- <h5 class="card-title">Organisation:</h5>
213
- <ul>
214
- {% for value in data.Company %}
215
- {% if value|lower != 'not found' %}
216
- <li>{{ value }}</li>
217
- {% endif %}
218
- {% endfor %}
219
- </ul>
220
- {% endif %}
221
- </div>
222
- </div>
223
- </div>
224
- </div>
225
- </section>
226
-
227
- {% else %}
228
- <p>No data available. Please process a file.</p>
229
- {% endif %}
230
- </div>
231
-
232
- <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.bundle.min.js"></script>
233
- <script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
234
- <script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.5.4/dist/umd/popper.min.js"></script>
235
- <script src="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"></script>
236
- <script>
237
- setTimeout(function () {
238
- let flashMessage = document.getElementById("flash-message");
239
- if (flashMessage) {
240
- flashMessage.style.transition = "opacity 1s ease";
241
- flashMessage.style.opacity = 0;
242
- setTimeout(() => flashMessage.remove(), 1000);
243
- }
244
- }, 3000);
245
- </script>
246
- </body>
247
-
248
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>Processed Results</title>
8
+ <link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600&display=swap" rel="stylesheet">
9
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet" />
10
+ <style>
11
+ :root {
12
+ --primary: #ee4410;
13
+ --secondary: #ff9f0a;
14
+ --bg-dark: #0a0a0c;
15
+ --card-bg: rgba(30, 30, 35, 0.7);
16
+ --text-glow: 0 0 10px rgba(238, 68, 16, 0.5);
17
+ }
18
+
19
+ body {
20
+ background-color: var(--bg-dark);
21
+ font-family: 'Outfit', sans-serif;
22
+ color: #f5f5f7;
23
+ overflow-x: hidden;
24
+ background: radial-gradient(circle at 50% 50%, #1a1a1f 0%, #0a0a0c 100%);
25
+ min-height: 100vh;
26
+ }
27
+
28
+ .glass-card {
29
+ background: var(--card-bg);
30
+ backdrop-filter: blur(12px);
31
+ border: 1px solid rgba(255, 255, 255, 0.1);
32
+ border-radius: 20px;
33
+ box-shadow: 0 15px 35px rgba(0, 0, 0, 0.4);
34
+ transition: all 0.4s cubic-bezier(0.175, 0.885, 0.32, 1.275);
35
+ padding: 2rem;
36
+ margin-bottom: 2rem;
37
+ }
38
+
39
+ .glass-card:hover {
40
+ transform: translateY(-5px);
41
+ border-color: rgba(238, 68, 16, 0.3);
42
+ box-shadow: 0 20px 45px rgba(238, 68, 16, 0.1);
43
+ }
44
+
45
+ .premium-title {
46
+ background: linear-gradient(135deg, #fff 0%, #aaa 100%);
47
+ -webkit-background-clip: text;
48
+ background-clip: text;
49
+ -webkit-text-fill-color: transparent;
50
+ font-weight: 600;
51
+ letter-spacing: -1px;
52
+ text-shadow: var(--text-glow);
53
+ margin-bottom: 1.5rem;
54
+ }
55
+
56
+ .section-title {
57
+ color: var(--primary);
58
+ font-size: 1.1rem;
59
+ text-transform: uppercase;
60
+ letter-spacing: 2px;
61
+ margin-bottom: 1.5rem;
62
+ display: flex;
63
+ align-items: center;
64
+ }
65
+
66
+ .section-title::after {
67
+ content: '';
68
+ flex: 1;
69
+ height: 1px;
70
+ background: linear-gradient(90deg, var(--primary), transparent);
71
+ margin-left: 1rem;
72
+ }
73
+
74
+ .data-item {
75
+ margin-bottom: 1.5rem;
76
+ border-left: 3px solid var(--primary);
77
+ padding-left: 1rem;
78
+ animation: fadeIn 0.6s ease-out forwards;
79
+ opacity: 0;
80
+ }
81
+
82
+ @keyframes fadeIn {
83
+ from { opacity: 0; transform: translateX(-10px); }
84
+ to { opacity: 1; transform: translateX(0); }
85
+ }
86
+
87
+ .data-label {
88
+ color: #8e8e93;
89
+ font-size: 0.85rem;
90
+ margin-bottom: 0.2rem;
91
+ }
92
+
93
+ .data-value {
94
+ font-size: 1.1rem;
95
+ color: #fff;
96
+ list-style: none;
97
+ padding: 0;
98
+ }
99
+
100
+ .data-value li {
101
+ margin-bottom: 0.4rem;
102
+ }
103
+
104
+ .result-img-container {
105
+ border-radius: 15px;
106
+ overflow: hidden;
107
+ border: 1px solid rgba(255, 255, 255, 0.1);
108
+ position: relative;
109
+ }
110
+
111
+ .result-img-container img {
112
+ width: 100%;
113
+ height: auto;
114
+ display: block;
115
+ transition: transform 0.5s ease;
116
+ }
117
+
118
+ .result-img-container:hover img {
119
+ transform: scale(1.05);
120
+ }
121
+
122
+ .btn-premium {
123
+ background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%);
124
+ border: none;
125
+ color: white;
126
+ padding: 0.8rem 2rem;
127
+ border-radius: 30px;
128
+ font-weight: 600;
129
+ box-shadow: 0 5px 15px rgba(238, 68, 16, 0.3);
130
+ text-decoration: none;
131
+ transition: all 0.3s ease;
132
+ display: inline-block;
133
+ }
134
+
135
+ .btn-premium:hover {
136
+ transform: scale(1.05);
137
+ box-shadow: 0 8px 25px rgba(238, 68, 16, 0.5);
138
+ color: #fff;
139
+ }
140
+
141
+ .alert-premium {
142
+ background: rgba(40, 40, 45, 0.9);
143
+ border: 1px solid var(--primary);
144
+ color: #fff;
145
+ border-radius: 12px;
146
+ padding: 1rem;
147
+ animation: slideDown 0.5s cubic-bezier(0.19, 1, 0.22, 1);
148
+ }
149
+
150
+ @keyframes slideDown {
151
+ from { transform: translateY(-50px); opacity: 0; }
152
+ to { transform: translateY(0); opacity: 1; }
153
+ }
154
+
155
+ .debug-panel {
156
+ margin-top: 4rem;
157
+ padding: 2rem;
158
+ background: rgba(0, 0, 0, 0.3);
159
+ border-top: 1px solid rgba(255, 255, 255, 0.05);
160
+ font-family: monospace;
161
+ font-size: 0.8rem;
162
+ color: #666;
163
+ }
164
+ </style>
165
166
+ </head>
167
+
168
+ <body>
169
+ <div class="container py-5">
170
+ {% with messages = get_flashed_messages() %}
171
+ {% if messages %}
172
+ <div class="alert-premium mb-4" id="flash-message">
173
+ {{ messages[0] }}
174
+ </div>
175
+ {% endif %}
176
+ {% endwith %}
177
+
178
+ <div class="d-flex align-items-center justify-content-between mb-5">
179
+ <h1 class="premium-title mb-0">Extraction Analysis <small style="font-size: 0.4em; color: var(--primary); letter-spacing: 2px; font-weight: 400; display: block; text-align: left;">v2.1 GOLD</small></h1>
180
+ <a href="{{ url_for('reset_upload') }}" class="btn-premium" style="width: auto;">Process New Image</a>
181
+ </div>
182
+
183
+ {% if data %}
184
+ <div class="row g-4">
185
+ <!-- Source Image Column -->
186
+ <div class="col-lg-5">
187
+ <div class="glass-card">
188
+ <h3 class="section-title">Source Image</h3>
189
+ {% if Img %}
190
+ {% for filename, result_path in Img.items() %}
191
+ <div class="result-img-container mb-3">
192
+ <img src="{{ result_path }}" alt="Analyzed Document" />
193
+ </div>
194
+ <p class="text-muted small">File: {{ filename | basename }}</p>
195
+ {% endfor %}
196
+ {% else %}
197
+ <div class="p-4 text-center text-muted">No image path available</div>
198
+ {% endif %}
199
+ </div>
200
+ </div>
201
+
202
+ <!-- Extracted Details Column -->
203
+ <div class="col-lg-7">
204
+ <div class="glass-card">
205
+ <h3 class="section-title">Verified Details</h3>
206
+
207
+ <div class="row">
208
+ <!-- Left Data Sub-col -->
209
+ <div class="col-md-6">
210
+ {% if data.name %}
211
+ <div class="data-item" style="animation-delay: 0.1s">
212
+ <div class="data-label">Full Name</div>
213
+ <ul class="data-value">
214
+ {% for val in data.name %}<li>{{ val }}</li>{% endfor %}
215
+ </ul>
216
+ </div>
217
+ {% endif %}
218
+
219
+ {% if data.Designation %}
220
+ <div class="data-item" style="animation-delay: 0.2s">
221
+ <div class="data-label">Designation</div>
222
+ <ul class="data-value">
223
+ {% for val in data.Designation %}<li>{{ val }}</li>{% endfor %}
224
+ </ul>
225
+ </div>
226
+ {% endif %}
227
+
228
+ {% if data.Company %}
229
+ <div class="data-item" style="animation-delay: 0.3s">
230
+ <div class="data-label">Organization</div>
231
+ <ul class="data-value">
232
+ {% for val in data.Company %}<li>{{ val }}</li>{% endfor %}
233
+ </ul>
234
+ </div>
235
+ {% endif %}
236
+ </div>
237
+
238
+ <!-- Right Data Sub-col -->
239
+ <div class="col-md-6">
240
+ {% if data.contact_number %}
241
+ <div class="data-item" style="animation-delay: 0.4s">
242
+ <div class="data-label">Phone Numbers</div>
243
+ <ul class="data-value">
244
+ {% for val in data.contact_number %}<li>{{ val }}</li>{% endfor %}
245
+ </ul>
246
+ </div>
247
+ {% endif %}
248
+
249
+ {% if data.email %}
250
+ <div class="data-item" style="animation-delay: 0.5s">
251
+ <div class="data-label">Email Addresses</div>
252
+ <ul class="data-value">
253
+ {% for val in data.email %}<li>{{ val }}</li>{% endfor %}
254
+ </ul>
255
+ </div>
256
+ {% endif %}
257
+
258
+ {% if data.Location %}
259
+ <div class="data-item" style="animation-delay: 0.6s">
260
+ <div class="data-label">Address</div>
261
+ <ul class="data-value">
262
+ {% for val in data.Location %}<li>{{ val }}</li>{% endfor %}
263
+ </ul>
264
+ </div>
265
+ {% endif %}
266
+
267
+ {% if data.Link %}
268
+ <div class="data-item" style="animation-delay: 0.7s">
269
+ <div class="data-label">Social/Web Links</div>
270
+ <ul class="data-value">
271
+ {% for val in data.Link %}<li>{{ val }}</li>{% endfor %}
272
+ </ul>
273
+ </div>
274
+ {% endif %}
275
+ </div>
276
+ </div>
277
+
278
+
279
+ {% if data.status_message %}
280
+ <div class="mt-4 pt-3 border-top border-secondary text-end">
281
+ <span class="badge rounded-pill bg-dark text-muted" style="font-size: 0.65rem; border: 1px solid rgba(255,255,255,0.05)">{{ data.status_message }}</span>
282
+ </div>
283
+ {% endif %}
284
+ </div>
285
+ </div>
286
+ </div>
287
+
288
+ <!-- Debug / Raw Output Panel -->
289
+ <div class="debug-panel rounded">
290
+ <h5 class="mb-3" style="color: #444; font-size: 0.9rem">SYSTEM_LOG :: RAW_EXTRACTION_BUFFER</h5>
291
+ {% for filename, raw in data.extracted_text.items() %}
292
+ <div class="mb-4">
293
+ <div class="mb-1 text-primary">>> {{ filename | basename }}</div>
294
+ <div class="p-3 bg-black rounded" style="color: #0f0; opacity: 0.8; font-size: 0.85rem">
295
+ {{ raw }}
296
+ </div>
297
+ </div>
298
+ {% endfor %}
299
+ </div>
300
+
301
+ {% else %}
302
+ <div class="text-center glass-card py-5">
303
+ <h2 class="premium-title">Waiting for Data...</h2>
304
+ <p class="text-muted">No analysis results found in session.</p>
305
+ <a href="{{ url_for('index') }}" class="btn-premium mt-3">Back to Upload</a>
306
+ </div>
307
+ {% endif %}
308
+ </div>
309
+
310
+ <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.bundle.min.js"></script>
311
+ <script>
312
+ // Auto-remove flash messages
313
+ setTimeout(function () {
314
+ let flashMessage = document.getElementById("flash-message");
315
+ if (flashMessage) {
316
+ flashMessage.style.transition = "all 0.8s ease";
317
+ flashMessage.style.opacity = 0;
318
+ flashMessage.style.transform = "translateY(-20px)";
319
+ setTimeout(() => flashMessage.remove(), 800);
320
+ }
321
+ }, 4000);
322
+ </script>
323
+ </body>
324
+ </html>
325
+
326
+ </html>
utility/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.77 kB). View file
 
utility/__pycache__/utils.cpython-312.pyc ADDED
Binary file (27 kB). View file
 
utility/__pycache__/utils.cpython-313.pyc ADDED
Binary file (27.2 kB). View file
 
utility/utils.py CHANGED
@@ -1,700 +1,132 @@
1
- # libraries
2
  import os
3
- from huggingface_hub import InferenceClient
4
- from dotenv import load_dotenv
5
  import json
6
  import re
7
- #import easyocr
8
- from PIL import Image, ImageEnhance, ImageDraw
9
- import cv2
10
- import numpy as np
11
- from paddleocr import PaddleOCR
12
  import logging
13
- from datetime import datetime
14
-
15
- # Configure logging
16
- logging.basicConfig(
17
- level=logging.INFO,
18
- handlers=[
19
- logging.StreamHandler() # Remove FileHandler and log only to the console
20
- ]
21
- )
22
-
23
- # Set the PaddleOCR home directory to a writable location
24
- import os
25
-
26
- os.environ['PADDLEOCR_HOME'] = '/tmp/.paddleocr'
27
-
28
- RESULT_FOLDER = 'static/results/'
29
- JSON_FOLDER = 'static/json/'
30
-
31
- if not os.path.exists('/tmp/.paddleocr'):
32
- os.makedirs(RESULT_FOLDER, exist_ok=True)
33
-
34
- # Check if PaddleOCR home directory is writable
35
- if not os.path.exists('/tmp/.paddleocr'):
36
- os.makedirs('/tmp/.paddleocr', exist_ok=True)
37
- logging.info("Created PaddleOCR home directory.")
38
- else:
39
- logging.info("PaddleOCR home directory exists.")
40
-
41
- # Load environment variables from .env file
42
- load_dotenv()
43
-
44
- # Authenticate with Hugging Face
45
- HFT = os.getenv('HF_TOKEN')
46
-
47
- # Initialize the InferenceClient
48
- client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=HFT)
49
-
50
- def load_image(image_path):
51
- ext = os.path.splitext(image_path)[1].lower()
52
- if ext in ['.png', '.jpg', '.jpeg', '.webp', '.tiff']:
53
- image = cv2.imread(image_path)
54
- if image is None:
55
- raise ValueError(f"Failed to load image from {image_path}. The file may be corrupted or unreadable.")
56
- return image
57
- else:
58
- raise ValueError(f"Unsupported image format: {ext}")
59
-
60
- # Function for upscaling image using OpenCV's INTER_CUBIC
61
- def upscale_image(image, scale=2):
62
- height, width = image.shape[:2]
63
- upscaled_image = cv2.resize(image, (width * scale, height * scale), interpolation=cv2.INTER_CUBIC)
64
- return upscaled_image
65
-
66
- # Function to denoise the image (reduce noise)
67
- def reduce_noise(image):
68
- return cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
69
-
70
- # Function to sharpen the image
71
- def sharpen_image(image):
72
- kernel = np.array([[0, -1, 0],
73
- [-1, 5, -1],
74
- [0, -1, 0]])
75
- sharpened_image = cv2.filter2D(image, -1, kernel)
76
- return sharpened_image
77
-
78
- # Function to increase contrast and enhance details without changing color
79
- def enhance_image(image):
80
- pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
81
- enhancer = ImageEnhance.Contrast(pil_img)
82
- enhanced_image = enhancer.enhance(1.5)
83
- enhanced_image_bgr = cv2.cvtColor(np.array(enhanced_image), cv2.COLOR_RGB2BGR)
84
- return enhanced_image_bgr
85
-
86
- # Complete function to process image
87
- def process_image(image_path, scale=2):
88
- # Load the image
89
- image = load_image(image_path)
90
-
91
- # Upscale the image
92
- upscaled_image = upscale_image(image, scale)
93
-
94
- # Reduce noise
95
- denoised_image = reduce_noise(upscaled_image)
96
-
97
- # Sharpen the image
98
- sharpened_image = sharpen_image(denoised_image)
99
-
100
- # Enhance the image contrast and details without changing color
101
- final_image = enhance_image(sharpened_image)
102
-
103
- return final_image
104
-
105
- # Function for OCR with PaddleOCR, returning both text and bounding boxes
106
- def ocr_with_paddle(img):
107
- final_text = ''
108
- boxes = []
109
-
110
- # Initialize PaddleOCR
111
- # In /app/utility/utils.py
112
- ocr = PaddleOCR(
113
- use_angle_cls=True,
114
- lang='en',
115
- enable_mkldnn=False, # <--- Add this line to disable the failing optimization
116
- use_gpu=False # Ensure this is False if you are on a CPU-only container
117
- )
118
- # ocr = PaddleOCR(
119
- # lang='en',
120
- # use_angle_cls=True,
121
- # det_model_dir=os.path.join(os.environ['PADDLEOCR_HOME'], 'whl/det'),
122
- # rec_model_dir=os.path.join(os.environ['PADDLEOCR_HOME'], 'whl/rec/en/en_PP-OCRv4_rec_infer'),
123
- # cls_model_dir=os.path.join(os.environ['PADDLEOCR_HOME'], 'whl/cls/ch_ppocr_mobile_v2.0_cls_infer')
124
- # )
125
- # ocr = PaddleOCR(
126
- # use_angle_cls=True,
127
- # lang='en',
128
- # det_model_dir='/app/paddleocr_models/whl/det/ch_ppocr_mobile_v2.0_det_infer',
129
- # rec_model_dir='/app/paddleocr_models/whl/rec/ch_ppocr_mobile_v2.0_rec_infer',
130
- # cls_model_dir='/app/paddleocr_models/whl/cls/ch_ppocr_mobile_v2.0_cls_infer'
131
- # )
132
-
133
-
134
- # Check if img is a file path or an image array
135
- if isinstance(img, str):
136
- img = cv2.imread(img)
137
-
138
- # Perform OCR
139
- result = ocr.ocr(img)
140
-
141
- # Iterate through the OCR result
142
- for line in result[0]:
143
- # Check how many values are returned (2 or 3) and unpack accordingly
144
- if len(line) == 3:
145
- box, (text, confidence), _ = line # When 3 values are returned
146
- elif len(line) == 2:
147
- box, (text, confidence) = line # When only 2 values are returned
148
-
149
- # Store the recognized text and bounding boxes
150
- final_text += ' ' + text # Extract the text from the tuple
151
- boxes.append(box)
152
-
153
- # Draw the bounding box
154
- points = [(int(point[0]), int(point[1])) for point in box]
155
- cv2.polylines(img, [np.array(points)], isClosed=True, color=(0, 255, 0), thickness=2)
156
-
157
- # Store the image with bounding boxes in a variable
158
- img_with_boxes = img
159
-
160
- return final_text, img_with_boxes
161
-
162
- def extract_text_from_images(image_paths):
163
- all_extracted_texts = {}
164
- all_extracted_imgs = {}
165
- for image_path in image_paths:
166
- try:
167
- # Enhance the image before OCR
168
- enhanced_image = process_image(image_path, scale=2)
169
-
170
- # Perform OCR on the enhanced image and get boxes
171
- result, img_with_boxes = ocr_with_paddle(enhanced_image)
172
-
173
- # Draw bounding boxes on the processed image
174
- img_result = Image.fromarray(enhanced_image)
175
- #img_with_boxes = draw_boxes(img_result, boxes)
176
-
177
- # genrating unique id to save the images
178
- # Get the current date and time
179
- current_time = datetime.now()
180
-
181
- # Format it as a string to create a unique ID
182
- unique_id = current_time.strftime("%Y%m%d%H%M%S%f")
183
-
184
- #print(unique_id)
185
-
186
- # Save the image with boxes
187
- result_image_path = os.path.join(RESULT_FOLDER, f'result_{unique_id}_{os.path.basename(image_path)}')
188
- #img_with_boxes.save(result_image_path)
189
- cv2.imwrite(result_image_path, img_with_boxes)
190
-
191
- # Store the text and image result paths
192
- all_extracted_texts[image_path] = result
193
- all_extracted_imgs[image_path] = result_image_path
194
- except ValueError as ve:
195
- print(f"Error processing image {image_path}: {ve}")
196
- continue # Continue to the next image if there's an error
197
-
198
- # Convert to JSON-compatible structure
199
- all_extracted_imgs_json = {str(k): str(v) for k, v in all_extracted_imgs.items()}
200
- return all_extracted_texts, all_extracted_imgs_json
201
-
202
- # Function to call the Gemma model and process the output as Json
203
- # def Data_Extractor(data, client=client):
204
- # text = f'''Act as a Text extractor for the following text given in text: {data}
205
- # extract text in the following output JSON string:
206
- # {{
207
- # "Name": ["Identify and Extract All the person's name from the text."],
208
- # "Designation": ["Extract All the designation or job title mentioned in the text."],
209
- # "Company": ["Extract All the company or organization name if mentioned."],
210
- # "Contact": ["Extract All phone number, including country codes if present."],
211
- # "Address": ["Extract All the full postal address or location mentioned in the text."],
212
- # "Email": ["Identify and Extract All valid email addresses mentioned in the text else 'Not found'."],
213
- # "Link": ["Identify and Extract any website URLs or social media links present in the text."]
214
- # }}
215
- # Output:
216
- # '''
217
-
218
- # # Call the API for inference
219
- # response = client.text_generation(text, max_new_tokens=1000)#, temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.2)
220
-
221
- # print("parse in text ---:",response)
222
-
223
- # # Convert the response text to JSON
224
- # try:
225
- # json_data = json.loads(response)
226
- # print("Json_data-------------->",json_data)
227
- # return json_data
228
- # except json.JSONDecodeError as e:
229
- # return {"error": f"Error decoding JSON: {e}"}
230
- def Data_Extractor(data):
231
- url = "https://api.groq.com/openai/v1/chat/completions"
232
-
233
- headers = {
234
- "Content-Type": "application/json",
235
- "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
236
  }
237
 
238
- prompt = f"""
239
- You are a strict JSON generator.
240
-
241
- Extract structured data from the following text.
242
-
243
- Return ONLY valid JSON. No explanation. No markdown.
244
-
245
- Schema:
246
- {{
247
- "Name": [],
248
- "Designation": [],
249
- "Company": [],
250
- "Contact": [],
251
- "Address": [],
252
- "Email": [],
253
- "Link": []
254
- }}
255
-
256
- Rules:
257
- - Always return all keys
258
- - If nothing found → return empty list []
259
- - Do NOT return "Not found"
260
- - Ensure valid JSON format
261
-
262
- Text:
263
- {data}
264
- """
265
-
266
- payload = {
267
- "model": "llama-3.3-70b-versatile",
268
- "messages": [
269
- {"role": "user", "content": prompt}
270
- ],
271
- "temperature": 0.2, # 🔥 IMPORTANT: lower = more structured
272
- "max_tokens": 1024,
273
- "top_p": 1,
274
- "stream": False
 
 
 
 
 
275
  }
276
 
277
- response = requests.post(url, headers=headers, json=payload)
278
-
279
- if response.status_code != 200:
280
- return {"error": response.text}
281
-
282
- result = response.json()
283
-
284
- # Extract model output
285
- content = result["choices"][0]["message"]["content"]
286
-
287
- print("RAW LLM OUTPUT:\n", content)
288
-
289
- # 🔧 Clean response (important)
290
- content = content.strip()
291
-
292
- # Remove markdown if model adds ```json
293
- if content.startswith("```"):
294
- content = content.split("```")[1]
295
-
296
- try:
297
- json_data = json.loads(content)
298
- return json_data
299
- except json.JSONDecodeError as e:
300
- print("JSON ERROR:", e)
301
- return {"error": "Invalid JSON from model", "raw": content}
302
-
303
- # For have text compatible to the llm
304
- def json_to_llm_str(textJson):
305
- str=''
306
- for file,item in textJson.items():
307
- str+=item + ' '
308
- return str
309
-
310
- # Define the RE for extracting the contact details like number, mail , portfolio, website etc
311
- def extract_contact_details(text):
312
- # Regex patterns
313
- # Phone numbers with at least 5 digits in any segment
314
- combined_phone_regex = re.compile(r'''
315
- (?:
316
- #(?:(?:\+91[-.\s]?)?\d{5}[-.\s]?\d{5})|(?:\+?\d{1,3})?[-.\s()]?\d{5,}[-.\s()]?\d{5,}[-.\s()]?\d{1,9} | /^[\.-)( ]*([0-9]{3})[\.-)( ]*([0-9]{3})[\.-)( ]*([0-9]{4})$/ |
317
- \+1\s\(\d{3}\)\s\d{3}-\d{4} | # USA/Canada Intl +1 (XXX) XXX-XXXX
318
- \(\d{3}\)\s\d{3}-\d{4} | # USA/Canada STD (XXX) XXX-XXXX
319
- \(\d{3}\)\s\d{3}\s\d{4} | # USA/Canada (XXX) XXX XXXX
320
- \(\d{3}\)\s\d{3}\s\d{3} | # USA/Canada (XXX) XXX XXX
321
- \+1\d{10} | # +1 XXXXXXXXXX
322
- \d{10} | # XXXXXXXXXX
323
- \+44\s\d{4}\s\d{6} | # UK Intl +44 XXXX XXXXXX
324
- \+44\s\d{3}\s\d{3}\s\d{4} | # UK Intl +44 XXX XXX XXXX
325
- 0\d{4}\s\d{6} | # UK STD 0XXXX XXXXXX
326
- 0\d{3}\s\d{3}\s\d{4} | # UK STD 0XXX XXX XXXX
327
- \+44\d{10} | # +44 XXXXXXXXXX
328
- 0\d{10} | # 0XXXXXXXXXX
329
- \+61\s\d\s\d{4}\s\d{4} | # Australia Intl +61 X XXXX XXXX
330
- 0\d\s\d{4}\s\d{4} | # Australia STD 0X XXXX XXXX
331
- \+61\d{9} | # +61 XXXXXXXXX
332
- 0\d{9} | # 0XXXXXXXXX
333
- \+91\s\d{5}-\d{5} | # India Intl +91 XXXXX-XXXXX
334
- \+91\s\d{4}-\d{6} | # India Intl +91 XXXX-XXXXXX
335
- \+91\s\d{10} | # India Intl +91 XXXXXXXXXX
336
- \+91\s\d{3}\s\d{3}\s\d{4} | # India Intl +91 XXX XXX XXXX
337
- \+91\s\d{3}-\d{3}-\d{4} | # India Intl +91 XXX-XXX-XXXX
338
- \+91\s\d{2}\s\d{4}\s\d{4} | # India Intl +91 XX XXXX XXXX
339
- \+91\s\d{2}-\d{4}-\d{4} | # India Intl +91 XX-XXXX-XXXX
340
- \+91\s\d{5}\s\d{5} | # India Intl +91 XXXXX XXXXX
341
- \d{5}\s\d{5} | # India XXXXX XXXXX
342
- \d{5}-\d{5} | # India XXXXX-XXXXX
343
- 0\d{2}-\d{7} | # India STD 0XX-XXXXXXX
344
- \+91\d{10} | # +91 XXXXXXXXXX
345
- \d{10} | # XXXXXXXXXX # Here is the regex to handle all possible combination of the contact
346
- \d{6}-\d{4} | # XXXXXX-XXXX
347
- \d{4}-\d{6} | # XXXX-XXXXXX
348
- \d{3}\s\d{3}\s\d{4} | # XXX XXX XXXX
349
- \d{3}-\d{3}-\d{4} | # XXX-XXX-XXXX
350
- \d{4}\s\d{3}\s\d{3} | # XXXX XXX XXX
351
- \d{4}-\d{3}-\d{3} | # XXXX-XXX-XXX #-----
352
- \+49\s\d{4}\s\d{8} | # Germany Intl +49 XXXX XXXXXXXX
353
- \+49\s\d{3}\s\d{7} | # Germany Intl +49 XXX XXXXXXX
354
- 0\d{3}\s\d{8} | # Germany STD 0XXX XXXXXXXX
355
- \+49\d{12} | # +49 XXXXXXXXXXXX
356
- \+49\d{10} | # +49 XXXXXXXXXX
357
- 0\d{11} | # 0XXXXXXXXXXX
358
- \+86\s\d{3}\s\d{4}\s\d{4} | # China Intl +86 XXX XXXX XXXX
359
- 0\d{3}\s\d{4}\s\d{4} | # China STD 0XXX XXXX XXXX
360
- \+86\d{11} | # +86 XXXXXXXXXXX
361
- \+81\s\d\s\d{4}\s\d{4} | # Japan Intl +81 X XXXX XXXX
362
- \+81\s\d{2}\s\d{4}\s\d{4} | # Japan Intl +81 XX XXXX XXXX
363
- 0\d\s\d{4}\s\d{4} | # Japan STD 0X XXXX XXXX
364
- \+81\d{10} | # +81 XXXXXXXXXX
365
- \+81\d{9} | # +81 XXXXXXXXX
366
- 0\d{9} | # 0XXXXXXXXX
367
- \+55\s\d{2}\s\d{5}-\d{4} | # Brazil Intl +55 XX XXXXX-XXXX
368
- \+55\s\d{2}\s\d{4}-\d{4} | # Brazil Intl +55 XX XXXX-XXXX
369
- 0\d{2}\s\d{4}\s\d{4} | # Brazil STD 0XX XXXX XXXX
370
- \+55\d{11} | # +55 XXXXXXXXXXX
371
- \+55\d{10} | # +55 XXXXXXXXXX
372
- 0\d{10} | # 0XXXXXXXXXX
373
- \+33\s\d\s\d{2}\s\d{2}\s\d{2}\s\d{2} | # France Intl +33 X XX XX XX XX
374
- 0\d\s\d{2}\s\d{2}\s\d{2}\s\d{2} | # France STD 0X XX XX XX XX
375
- \+33\d{9} | # +33 XXXXXXXXX
376
- 0\d{9} | # 0XXXXXXXXX
377
- \+7\s\d{3}\s\d{3}-\d{2}-\d{2} | # Russia Intl +7 XXX XXX-XX-XX
378
- 8\s\d{3}\s\d{3}-\d{2}-\d{2} | # Russia STD 8 XXX XXX-XX-XX
379
- \+7\d{10} | # +7 XXXXXXXXXX
380
- 8\d{10} | # 8 XXXXXXXXXX
381
- \+27\s\d{2}\s\d{3}\s\d{4} | # South Africa Intl +27 XX XXX XXXX
382
- 0\d{2}\s\d{3}\s\d{4} | # South Africa STD 0XX XXX XXXX
383
- \+27\d{9} | # +27 XXXXXXXXX
384
- 0\d{9} | # 0XXXXXXXXX
385
- \+52\s\d{3}\s\d{3}\s\d{4} | # Mexico Intl +52 XXX XXX XXXX
386
- \+52\s\d{2}\s\d{4}\s\d{4} | # Mexico Intl +52 XX XXXX XXXX
387
- 01\s\d{3}\s\d{4} | # Mexico STD 01 XXX XXXX
388
- \+52\d{10} | # +52 XXXXXXXXXX
389
- 01\d{7} | # 01 XXXXXXX
390
- \+234\s\d{3}\s\d{3}\s\d{4} | # Nigeria Intl +234 XXX XXX XXXX
391
- 0\d{3}\s\d{3}\s\d{4} | # Nigeria STD 0XXX XXX XXXX
392
- \+234\d{10} | # +234 XXXXXXXXXX
393
- 0\d{10} | # 0XXXXXXXXXX
394
- \+971\s\d\s\d{3}\s\d{4} | # UAE Intl +971 X XXX XXXX
395
- 0\d\s\d{3}\s\d{4} | # UAE STD 0X XXX XXXX
396
- \+971\d{8} | # +971 XXXXXXXX
397
- 0\d{8} | # 0XXXXXXXX
398
- \+54\s9\s\d{3}\s\d{3}\s\d{4} | # Argentina Intl +54 9 XXX XXX XXXX
399
- \+54\s\d{1}\s\d{4}\s\d{4} | # Argentina Intl +54 X XXXX XXXX
400
- 0\d{3}\s\d{4} | # Argentina STD 0XXX XXXX
401
- \+54\d{10} | # +54 9 XXXXXXXXXX
402
- \+54\d{9} | # +54 XXXXXXXXX
403
- 0\d{7} | # 0XXXXXXX
404
- \+966\s\d\s\d{3}\s\d{4} | # Saudi Intl +966 X XXX XXXX
405
- 0\d\s\d{3}\s\d{4} | # Saudi STD 0X XXX XXXX
406
- \+966\d{8} | # +966 XXXXXXXX
407
- 0\d{8} | # 0XXXXXXXX
408
- \+1\d{10} | # +1 XXXXXXXXXX
409
- \+1\s\d{3}\s\d{3}\s\d{4} | # +1 XXX XXX XXXX
410
- \d{5}\s\d{5} | # XXXXX XXXXX
411
- \d{10} | # XXXXXXXXXX
412
- \+44\d{10} | # +44 XXXXXXXXXX
413
- 0\d{10} | # 0XXXXXXXXXX
414
- \+61\d{9} | # +61 XXXXXXXXX
415
- 0\d{9} | # 0XXXXXXXXX
416
- \+91\d{10} | # +91 XXXXXXXXXX
417
- \+49\d{12} | # +49 XXXXXXXXXXXX
418
- \+49\d{10} | # +49 XXXXXXXXXX
419
- 0\d{11} | # 0XXXXXXXXXXX
420
- \+86\d{11} | # +86 XXXXXXXXXXX
421
- \+81\d{10} | # +81 XXXXXXXXXX
422
- \+81\d{9} | # +81 XXXXXXXXX
423
- 0\d{9} | # 0XXXXXXXXX
424
- \+55\d{11} | # +55 XXXXXXXXXXX
425
- \+55\d{10} | # +55 XXXXXXXXXX
426
- 0\d{10} | # 0XXXXXXXXXX
427
- \+33\d{9} | # +33 XXXXXXXXX
428
- 0\d{9} | # 0XXXXXXXXX
429
- \+7\d{10} | # +7 XXXXXXXXXX
430
- 8\d{10} | # 8 XXXXXXXXXX
431
- \+27\d{9} | # +27 XXXXXXXXX
432
- 0\d{9} | # 0XXXXXXXXX (South Africa STD)
433
- \+52\d{10} | # +52 XXXXXXXXXX
434
- 01\d{7} | # 01 XXXXXXX
435
- \+234\d{10} | # +234 XXXXXXXXXX
436
- 0\d{10} | # 0XXXXXXXXXX
437
- \+971\d{8} | # +971 XXXXXXXX
438
- 0\d{8} | # 0XXXXXXXX
439
- \+54\s9\s\d{10} | # +54 9 XXXXXXXXXX
440
- \+54\d{9} | # +54 XXXXXXXXX
441
- 0\d{7} | # 0XXXXXXX
442
- \+966\d{8} | # +966 XXXXXXXX
443
- 0\d{8} # 0XXXXXXXX
444
- \+\d{3}-\d{3}-\d{4}
445
- )
446
-
447
- ''',re.VERBOSE)
448
-
449
- # Email regex
450
  email_regex = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
451
-
452
- # URL and links regex, updated to avoid conflicts with email domains
453
- link_regex = re.compile(r'\b(?:https?:\/\/)?(?:www\.)[a-zA-Z0-9-]+\.(?:com|co\.in|co|io|org|net|edu|gov|mil|int|uk|us|in|de|au|app|tech|xyz|info|biz|fr|dev)\b')
454
-
455
- # Find all matches in the text
456
- phone_numbers = [num for num in combined_phone_regex.findall(text) if len(num) >= 5]
457
-
458
- emails = email_regex.findall(text)
459
-
460
- links_RE = [link for link in link_regex.findall(text) if len(link)>=11]
461
-
462
- # Remove profile links that might conflict with emails
463
- links_RE = [link for link in links_RE if not any(email in link for email in emails)]
464
 
465
  return {
466
- "phone_numbers": phone_numbers,
467
- "emails": emails,
468
- "links_RE": links_RE
469
- }
470
-
471
- # preprocessing the data
472
- def process_extracted_text(extracted_text):
473
- # Load JSON data
474
- data = json.dumps(extracted_text, indent=4)
475
- data = json.loads(data)
476
-
477
- # Create a single dictionary to hold combined results
478
- combined_results = {
479
- "phone_numbers": [],
480
- "emails": [],
481
- "links_RE": []
482
  }
483
-
484
- # Process each text entry
485
- for filename, text in data.items():
486
- contact_details = extract_contact_details(text)
487
- # Extend combined results with the details from this file
488
- combined_results["phone_numbers"].extend(contact_details["phone_numbers"])
489
- combined_results["emails"].extend(contact_details["emails"])
490
- combined_results["links_RE"].extend(contact_details["links_RE"])
491
-
492
- # Convert the combined results to JSON
493
- #combined_results_json = json.dumps(combined_results, indent=4)
494
- combined_results_json = combined_results
495
-
496
- # Print the final JSON results
497
- print("Combined contact details in JSON format:")
498
- print(combined_results_json)
499
-
500
- return combined_results_json
501
-
502
- # Function to remove duplicates (case-insensitive) from each list in the dictionary
503
- def remove_duplicates_case_insensitive(data_dict):
504
- for key, value_list in data_dict.items():
505
- seen = set()
506
- unique_list = []
507
-
508
- for item in value_list:
509
- if item.lower() not in seen:
510
- unique_list.append(item) # Add original item (preserving its case)
511
- seen.add(item.lower()) # Track lowercase version
512
-
513
- # Update the dictionary with unique values
514
- data_dict[key] = unique_list
515
- return data_dict
516
-
517
- # # Process the model output for parsed result
518
- # def process_resume_data(LLMdata,cont_data,extracted_text):
519
-
520
- # # # Removing duplicate emails
521
- # # unique_emails = []
522
- # # for email in cont_data['emails']:
523
- # # if not any(email.lower() == existing_email.lower() for existing_email in LLMdata['Email']):
524
- # # unique_emails.append(email)
525
-
526
- # # # Removing duplicate links (case insensitive)
527
- # # unique_links = []
528
- # # for link in cont_data['links_RE']:
529
- # # if not any(link.lower() == existing_link.lower() for existing_link in LLMdata['Link']):
530
- # # unique_links.append(link)
531
-
532
- # # # Removing duplicate phone numbers
533
- # # normalized_contact = [num[-10:] for num in LLMdata['Contact']]
534
- # # unique_numbers = []
535
- # # for num in cont_data['phone_numbers']:
536
- # # if num[-10:] not in normalized_contact:
537
- # # unique_numbers.append(num)
538
-
539
- # # # Add unique emails, links, and phone numbers to the original LLMdata
540
- # # LLMdata['Email'] += unique_emails
541
- # # LLMdata['Link'] += unique_links
542
- # # LLMdata['Contact'] += unique_numbers
543
- # # Ensure keys exist (CRITICAL FIX)
544
- # LLMdata['Email'] = LLMdata.get('Email', []) or []
545
- # LLMdata['Link'] = LLMdata.get('Link', []) or []
546
- # LLMdata['Contact'] = LLMdata.get('Contact', []) or []
547
-
548
- # # Removing duplicate emails
549
- # unique_emails = []
550
- # for email in cont_data.get('emails', []):
551
- # if not any(email.lower() == str(existing_email).lower() for existing_email in LLMdata['Email']):
552
- # unique_emails.append(email)
553
-
554
- # # Removing duplicate links
555
- # unique_links = []
556
- # for link in cont_data.get('links_RE', []):
557
- # if not any(link.lower() == str(existing_link).lower() for existing_link in LLMdata['Link']):
558
- # unique_links.append(link)
559
-
560
- # # Normalize existing contacts safely
561
- # normalized_contact = [
562
- # str(num)[-10:] for num in LLMdata['Contact'] if num
563
- # ]
564
-
565
- # # Removing duplicate phone numbers
566
- # unique_numbers = []
567
- # for num in cont_data.get('phone_numbers', []):
568
- # if str(num)[-10:] not in normalized_contact:
569
- # unique_numbers.append(num)
570
-
571
- # # Merge safely
572
- # LLMdata['Email'].extend(unique_emails)
573
- # LLMdata['Link'].extend(unique_links)
574
- # LLMdata['Contact'].extend(unique_numbers)
575
-
576
-
577
- # # Apply the function to the data
578
- # LLMdata=remove_duplicates_case_insensitive(LLMdata)
579
-
580
- # # Initialize the processed data dictionary
581
- # processed_data = {
582
- # "name": [],
583
- # "contact_number": [],
584
- # "Designation":[],
585
- # "email": [],
586
- # "Location": [],
587
- # "Link": [],
588
- # "Company":[],
589
- # "extracted_text": extracted_text
590
- # }
591
- # #LLM
592
-
593
- # processed_data['name'].extend(LLMdata.get('Name', None))
594
- # #processed_data['contact_number'].extend(LLMdata.get('Contact', []))
595
- # processed_data['Designation'].extend(LLMdata.get('Designation', []))
596
- # #processed_data['email'].extend(LLMdata.get("Email", []))
597
- # processed_data['Location'].extend(LLMdata.get('Address', []))
598
- # #processed_data['Link'].extend(LLMdata.get('Link', []))
599
- # processed_data['Company'].extend(LLMdata.get('Company', []))
600
-
601
- # #Contact
602
- # #processed_data['email'].extend(cont_data.get("emails", []))
603
- # #processed_data['contact_number'].extend(cont_data.get("phone_numbers", []))
604
- # #processed_data['Link'].extend(cont_data.get("links_RE", []))
605
-
606
- # #New_merge_data
607
- # processed_data['email'].extend(LLMdata['Email'])
608
- # processed_data['contact_number'].extend(LLMdata['Contact'])
609
- # processed_data['Link'].extend(LLMdata['Link'])
610
-
611
- # #to remove not found fields
612
- # # List of keys to check for 'Not found'
613
- # keys_to_check = ["name", "contact_number", "Designation", "email", "Location", "Link", "Company"]
614
-
615
- # # Replace 'Not found' with an empty list for each key
616
- # for key in keys_to_check:
617
- # if processed_data[key] == ['Not found'] or processed_data[key] == ['not found']:
618
- # processed_data[key] = []
619
-
620
- # return processed_data
621
- def process_resume_data(LLMdata, cont_data, extracted_text):
622
-
623
- # -------------------------------
624
- # ✅ STEP 1: Normalize LLM Schema
625
- # -------------------------------
626
- expected_keys = ["Name", "Designation", "Company", "Contact", "Address", "Email", "Link"]
627
-
628
- for key in expected_keys:
629
- if key not in LLMdata or LLMdata[key] is None:
630
- LLMdata[key] = []
631
- elif not isinstance(LLMdata[key], list):
632
- LLMdata[key] = [LLMdata[key]]
633
-
634
- # -------------------------------
635
- # ✅ STEP 2: Normalize cont_data
636
- # -------------------------------
637
- cont_data = cont_data or {}
638
- cont_data.setdefault("emails", [])
639
- cont_data.setdefault("phone_numbers", [])
640
- cont_data.setdefault("links_RE", [])
641
-
642
- # -------------------------------
643
- # ✅ STEP 3: Normalize existing contacts
644
- # -------------------------------
645
- normalized_llm_numbers = {
646
- str(num)[-10:] for num in LLMdata["Contact"] if num
647
- }
648
-
649
- # -------------------------------
650
- # ✅ STEP 4: Merge Emails
651
- # -------------------------------
652
- for email in cont_data["emails"]:
653
- if not any(email.lower() == str(e).lower() for e in LLMdata["Email"]):
654
- LLMdata["Email"].append(email)
655
-
656
- # -------------------------------
657
- # ✅ STEP 5: Merge Links
658
- # -------------------------------
659
- for link in cont_data["links_RE"]:
660
- if not any(link.lower() == str(l).lower() for l in LLMdata["Link"]):
661
- LLMdata["Link"].append(link)
662
-
663
- # -------------------------------
664
- # ✅ STEP 6: Merge Phone Numbers
665
- # -------------------------------
666
- for num in cont_data["phone_numbers"]:
667
- norm = str(num)[-10:]
668
- if norm not in normalized_llm_numbers:
669
- LLMdata["Contact"].append(num)
670
- normalized_llm_numbers.add(norm)
671
-
672
- # -------------------------------
673
- # ✅ STEP 7: Remove duplicates (case-insensitive)
674
- # -------------------------------
675
- LLMdata = remove_duplicates_case_insensitive(LLMdata)
676
-
677
- # -------------------------------
678
- # ✅ STEP 8: Build final structure
679
- # -------------------------------
680
- processed_data = {
681
- "name": LLMdata["Name"],
682
- "contact_number": LLMdata["Contact"],
683
- "Designation": LLMdata["Designation"],
684
- "email": LLMdata["Email"],
685
- "Location": LLMdata["Address"],
686
- "Link": LLMdata["Link"],
687
- "Company": LLMdata["Company"],
688
- "extracted_text": extracted_text
689
- }
690
-
691
- # -------------------------------
692
- # ✅ STEP 9: Clean "Not found"
693
- # -------------------------------
694
- for key in ["name", "contact_number", "Designation", "email", "Location", "Link", "Company"]:
695
- processed_data[key] = [
696
- v for v in processed_data[key]
697
- if str(v).lower() != "not found"
698
- ]
699
-
700
- return processed_data
 
 
1
  import os
 
 
2
  import json
3
  import re
 
 
 
 
 
4
  import logging
5
+ from typing import List, Dict, Any
6
+ # Ensure langchain is available for paddlex/paddleocr
7
+ try:
8
+ import langchain
9
+ import langchain_community
10
+ except ImportError:
11
+ logging.warning("LangChain modules not found. PaddleOCR might fail.")
12
+
13
+ from core.ocr_engine import OCREngine
14
+ from core.vlm_engine import GroqVLMEngine
15
+ from core.ner_engine import NEREngine
16
+
17
# Global engine instances, created lazily by the get_* accessors below so
# heavy models (OCR / VLM / NER) are not loaded at module import time.
_ocr = None
_vlm = None
_ner = None
21
+
22
def get_ocr():
    """Return the shared OCREngine, constructing it on first call.

    Returns:
        The module-level OCREngine singleton.
    """
    global _ocr
    # Identity check: a falsy-but-valid engine instance must not be rebuilt,
    # which the previous truthiness test (`if not _ocr`) could do.
    if _ocr is None:
        _ocr = OCREngine()
    return _ocr
27
+
28
def get_vlm():
    """Return the shared GroqVLMEngine, constructing it on first call.

    Returns:
        The module-level GroqVLMEngine singleton.
    """
    global _vlm
    # Identity check: a falsy-but-valid engine instance must not be rebuilt,
    # which the previous truthiness test (`if not _vlm`) could do.
    if _vlm is None:
        _vlm = GroqVLMEngine()
    return _vlm
33
+
34
def get_ner():
    """Return the shared NEREngine, constructing it on first call.

    Returns:
        The module-level NEREngine singleton.
    """
    global _ner
    # Identity check: a falsy-but-valid engine instance must not be rebuilt,
    # which the previous truthiness test (`if not _ner`) could do.
    if _ner is None:
        _ner = NEREngine()
    return _ner
39
+
40
def process_image_pipeline(image_paths: List[str]) -> Dict[str, Any]:
    """Extract contact-card fields from each image path.

    Per image: try the Groq VLM first; if it yields nothing, fall back to
    OCR text extraction followed by NER. Results from all images are merged
    into one dict and then de-duplicated via cleanup_results().

    Args:
        image_paths: Filesystem paths of the images to process.

    Returns:
        Dict with list-valued fields ("name", "contact_number",
        "Designation", "email", "Location", "Link", "Company"), an
        "extracted_text" map of path -> raw text (or VLM JSON dump), and a
        "status_message" naming which engine path ended up being used.
    """
    logging.info(f"Pipeline: Starting processing for {len(image_paths)} images.")
    vlm = get_vlm()
    ocr = get_ocr()
    ner = get_ner()

    # Accumulator; merge_structured_data() appends into the list fields.
    final_results = {
        "name": [],
        "contact_number": [],
        "Designation": [],
        "email": [],
        "Location": [],
        "Link": [],
        "Company": [],
        "extracted_text": {},
        "status_message": "Primary: Groq VLM"
    }

    all_raw_text = {}  # path -> OCR text, or JSON dump of the VLM result

    for path in image_paths:
        img_name = os.path.basename(path)
        # 1. Primary: VLM (structured extraction straight from the image)
        logging.info(f"Pipeline: Attempting VLM extraction for {img_name}")
        vlm_data = vlm.process(path)
        if vlm_data:
            merge_structured_data(final_results, vlm_data)
            all_raw_text[path] = json.dumps(vlm_data)
            logging.info(f"Pipeline: VLM success for {img_name}")
        else:
            # 2. Fallback: OCR + NER
            logging.warning(f"Pipeline: VLM failed or skipped for {img_name}. Falling back to OCR+NER.")
            raw_text = ocr.extract_text(path)
            all_raw_text[path] = raw_text
            if raw_text:
                logging.info(f"Pipeline: OCR success for {img_name}, attempting NER.")
                ner_data = ner.extract_entities(raw_text)
                if ner_data:
                    merge_structured_data(final_results, ner_data)
                    logging.info(f"Pipeline: NER success for {img_name}")
                else:
                    logging.warning(f"Pipeline: NER failed to extract entities for {img_name}")
                # Marked as fallback whenever OCR produced text,
                # even if NER then found no entities.
                final_results["status_message"] = "Fallback: OCR+NER"
            else:
                logging.error(f"Pipeline: Both VLM and OCR failed for {img_name}")

    final_results["extracted_text"] = all_raw_text
    cleaned = cleanup_results(final_results)
    logging.info(f"Pipeline: Completed. Extracted data for {sum(1 for v in cleaned.values() if isinstance(v, list) and v)} fields.")
    return cleaned
90
+
91
def merge_structured_data(main_data: Dict, new_data: Dict):
    """Fold freshly extracted fields into the accumulated result dict.

    Engine-side keys (e.g. "Name", "Address") are translated to the
    canonical result keys; unrecognized keys fall back to their lowercased
    form. List values are extended into the matching bucket, truthy
    scalars are appended, and anything that does not resolve to an
    existing list in *main_data* is silently dropped.
    """
    key_map = {
        "Name": "name",
        "Contact": "contact_number",
        "Designation": "Designation",
        "Email": "email",
        "Address": "Location",
        "Link": "Link",
        "Company": "Company"
    }

    for raw_key, value in new_data.items():
        # capitalize() normalizes case before the canonical lookup.
        target = key_map.get(raw_key.capitalize(), raw_key.lower())
        if target not in main_data:
            continue  # unknown field: ignore
        bucket = main_data[target]
        if isinstance(value, list):
            bucket.extend(value)
        elif value:
            bucket.append(value)
109
+
110
def cleanup_results(results: Dict) -> Dict:
    """Normalize list-valued fields in place and return the same dict.

    Each list field is stringified, stripped, de-duplicated
    case-insensitively (first occurrence wins), and purged of empty or
    placeholder entries ("not found", "none", "null", "[]"). Non-list
    values (raw-text map, status message) pass through untouched.
    """
    junk = {"", "not found", "none", "null", "[]"}
    for field, value in results.items():
        if not isinstance(value, list):
            continue
        kept = []
        lowered_seen = set()
        for entry in value:
            text = str(entry).strip()
            folded = text.lower()
            if folded in lowered_seen or folded in junk:
                continue
            kept.append(text)
            lowered_seen.add(folded)
        results[field] = kept
    return results
123
+
124
def extract_contact_details(text: str) -> Dict[str, List[str]]:
    """Regex-based fallback extraction of emails and phone numbers.

    Args:
        text: Raw text (e.g. OCR output) to scan.

    Returns:
        Dict with "emails" and "phone_numbers" lists of matched substrings.
    """
    # Regex fallback for extra safety
    email_regex = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
    # BUG FIX: the optional country-code prefix must be a NON-capturing
    # group ("(?:...)"). With a capturing group, re.findall() returns only
    # that group's captures (mostly empty strings / "+91 " fragments)
    # instead of the full phone-number matches.
    phone_regex = re.compile(r'(?:\+?\d{1,3}[-.\s()]?)?\(?\d{3,5}\)?[-.\s()]?\d{3,5}[-.\s()]?\d{3,5}')

    return {
        "emails": email_regex.findall(text),
        "phone_numbers": phone_regex.findall(text)
    }