VladGeekPro committed on
Commit
63a687d
·
1 Parent(s): 9fee5c7

MergedUserSearchWithSupplierAlgorithm

Browse files
app.py CHANGED
@@ -16,7 +16,6 @@ from typing import Any, Optional
16
 
17
  import torch
18
  from flask import Flask, jsonify, request
19
- from sentence_transformers import SentenceTransformer
20
 
21
  # Импорт экстракторов
22
  from extractors import (
@@ -31,7 +30,6 @@ from extractors import (
31
  HF_TOKEN = os.getenv("HF_TOKEN")
32
 
33
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
- _MODEL: Optional[SentenceTransformer] = None
35
  _WHISPER_MODEL: Optional[Any] = None
36
  _WHISPER_PROCESSOR: Optional[Any] = None
37
 
@@ -139,16 +137,6 @@ TEST_PHRASES = [
139
  ]
140
 
141
 
142
- def get_embedding_model() -> SentenceTransformer:
143
- """Возвращает модель эмбеддингов (ленивая загрузка)."""
144
- global _MODEL
145
-
146
- if _MODEL is None:
147
- _MODEL = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B", device=DEVICE)
148
-
149
- return _MODEL
150
-
151
-
152
  def get_whisper_pipeline() -> Any:
153
  """Возвращает Whisper pipeline (ленивая загрузка)."""
154
  global _WHISPER_MODEL, _WHISPER_PROCESSOR
@@ -160,7 +148,7 @@ def get_whisper_pipeline() -> Any:
160
 
161
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
162
  model_id,
163
- torch_dtype=torch.float32,
164
  low_cpu_mem_usage=True,
165
  use_safetensors=True,
166
  )
@@ -204,60 +192,49 @@ class ExpenseTextExtractor:
204
  self.date_extractor = ExpenseDateExtractor()
205
  self.supplier_extractor = ExpenseSupplierExtractor(suppliers=suppliers)
206
  self.amount_extractor = ExpenseAmountExtractor(suppliers=suppliers)
207
- self.user_extractor = ExpenseUserExtractor(
208
- users=users,
209
- suppliers=suppliers,
210
- model=get_embedding_model()
211
- )
212
 
213
  def extract(
214
- self,
215
- text: str,
216
- reference_date: str | date | None = None,
217
- debug_supplier: bool = False
218
  ) -> dict[str, Any]:
219
- """
220
- Извлекает все данные из текста.
221
-
222
- Args:
223
- text: Текст для анализа
224
- reference_date: Базовая дата
225
- debug_supplier: Включить отладку поставщиков
226
-
227
- Returns:
228
- Словарь со всеми извлечёнными данными
229
- """
230
- timings = {}
231
-
232
  t0 = time.time()
233
- date_info = self.date_extractor.extract(text, reference_date=reference_date)
234
  timings["date_extractor"] = round(time.time() - t0, 3)
235
-
236
  t0 = time.time()
237
  supplier_info = self.supplier_extractor.extract(
238
  text,
239
  date_phrase=date_info.get("matched_date_phrase"),
240
- debug=debug_supplier,
241
  )
242
  timings["supplier_extractor"] = round(time.time() - t0, 3)
243
-
244
  t0 = time.time()
245
  user_info = self.user_extractor.extract(
246
  text,
247
  supplier_phrase=supplier_info.get("matched_supplier_phrase"),
248
  date_phrase=date_info.get("matched_date_phrase"),
 
249
  )
250
  timings["user_extractor"] = round(time.time() - t0, 3)
251
-
252
  t0 = time.time()
253
  amount_info = self.amount_extractor.extract(
254
  text,
255
  matched_date_phrase=date_info["matched_date_phrase"],
256
  matched_supplier_phrase=supplier_info["matched_supplier_phrase"],
 
257
  )
258
  timings["amount_extractor"] = round(time.time() - t0, 3)
259
-
260
- print(f"[TIMINGS] {timings}")
 
261
 
262
  result = {
263
  "text": text,
@@ -267,8 +244,16 @@ class ExpenseTextExtractor:
267
  "date": date_info["date"],
268
  "date_iso": date_info["date_iso"],
269
  }
270
- if debug_supplier and "supplier_debug" in supplier_info:
271
- result["supplier_debug"] = supplier_info["supplier_debug"]
 
 
 
 
 
 
 
 
272
  return result
273
 
274
 
@@ -345,10 +330,10 @@ def transcribe_audio_text(audio_path: str, suppliers: list[str] | None = None, u
345
  raise RuntimeError("Speech-to-text backend is unavailable.")
346
 
347
 
348
- def process_voice_request(audio_path: str, mode: str, payload: dict[str, Any]) -> dict[str, Any]:
349
  """Обрабатывает голосовой запрос."""
350
  total_start = time.time()
351
-
352
  context = payload.get("context", {}) if isinstance(payload, dict) else {}
353
  supplier_names = extract_names(context.get("suppliers"))
354
  user_names = extract_names(context.get("users"))
@@ -382,12 +367,12 @@ def process_voice_request(audio_path: str, mode: str, payload: dict[str, Any]) -
382
  pipeline_init_time = round(time.time() - t0, 3)
383
  print(f"[TIMINGS] pipeline_init: {pipeline_init_time}s")
384
 
385
- extracted = extractor.extract(transcript, reference_date=date.today().isoformat())
386
-
387
  total_time = round(time.time() - total_start, 3)
388
  print(f"[TIMINGS] TOTAL: {total_time}s (whisper: {whisper_time}s)")
389
 
390
- return {
391
  "status": "ok",
392
  "text": transcript,
393
  "notes": polish_notes_text(extracted.get("text") or transcript),
@@ -396,6 +381,9 @@ def process_voice_request(audio_path: str, mode: str, payload: dict[str, Any]) -
396
  "date": extracted.get("date_iso") or extracted.get("date"),
397
  "sum": extracted.get("amount"),
398
  }
 
 
 
399
 
400
 
401
  def require_auth():
@@ -455,7 +443,7 @@ def health():
455
  @app.get("/test-data")
456
  def test_data():
457
  """Тестирует извлечение данных из текста без использования Whisper."""
458
- debug_supplier = (request.args.get("debug") or "").strip().lower() in {"1", "true", "yes"}
459
  extractor = build_default_pipeline(suppliers=TEST_SUPPLIERS, users=TEST_USERS)
460
 
461
  started = time.time()
@@ -466,9 +454,9 @@ def test_data():
466
  extracted = extractor.extract(
467
  phrase,
468
  reference_date=date.today().isoformat(),
469
- debug_supplier=debug_supplier,
470
  )
471
- results.append({
472
  "text": phrase,
473
  "user": extracted.get("user"),
474
  "supplier": extracted.get("supplier"),
@@ -476,8 +464,10 @@ def test_data():
476
  "date": extracted.get("date"),
477
  "date_iso": extracted.get("date_iso"),
478
  "processing_time": round(time.time() - item_started, 3),
479
- **({"supplier_debug": extracted.get("supplier_debug")} if debug_supplier and extracted.get("supplier_debug") else {}),
480
- })
 
 
481
 
482
  return jsonify({
483
  "status": "ok",
@@ -500,6 +490,7 @@ def process_audio():
500
 
501
  audio = request.files.get("audio")
502
  mode = (request.form.get("mode") or "expense").strip()
 
503
  context = parse_context(request.form.get("context"))
504
 
505
  if audio is None:
@@ -513,7 +504,7 @@ def process_audio():
513
  temp_path = temp_file.name
514
  audio.save(temp_file)
515
 
516
- result = process_voice_request(audio_path=temp_path, mode=mode, payload={"context": context})
517
  return jsonify(result)
518
  except Exception as exception:
519
  return jsonify({"status": "error", "message": str(exception)}), 422
 
16
 
17
  import torch
18
  from flask import Flask, jsonify, request
 
19
 
20
  # Импорт экстракторов
21
  from extractors import (
 
30
  HF_TOKEN = os.getenv("HF_TOKEN")
31
 
32
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
33
  _WHISPER_MODEL: Optional[Any] = None
34
  _WHISPER_PROCESSOR: Optional[Any] = None
35
 
 
137
  ]
138
 
139
 
 
 
 
 
 
 
 
 
 
 
140
  def get_whisper_pipeline() -> Any:
141
  """Возвращает Whisper pipeline (ленивая загрузка)."""
142
  global _WHISPER_MODEL, _WHISPER_PROCESSOR
 
148
 
149
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
150
  model_id,
151
+ dtype=torch.float32,
152
  low_cpu_mem_usage=True,
153
  use_safetensors=True,
154
  )
 
192
  self.date_extractor = ExpenseDateExtractor()
193
  self.supplier_extractor = ExpenseSupplierExtractor(suppliers=suppliers)
194
  self.amount_extractor = ExpenseAmountExtractor(suppliers=suppliers)
195
+ self.user_extractor = ExpenseUserExtractor(users=users, suppliers=suppliers)
 
 
 
 
196
 
197
  def extract(
198
+ self,
199
+ text: str,
200
+ reference_date: str | date | None = None,
201
+ debug: bool = False,
202
  ) -> dict[str, Any]:
203
+ """Извлекает все данные из текста."""
204
+ timings: dict[str, float] = {}
205
+
 
 
 
 
 
 
 
 
 
 
206
  t0 = time.time()
207
+ date_info = self.date_extractor.extract(text, reference_date=reference_date, debug=debug)
208
  timings["date_extractor"] = round(time.time() - t0, 3)
209
+
210
  t0 = time.time()
211
  supplier_info = self.supplier_extractor.extract(
212
  text,
213
  date_phrase=date_info.get("matched_date_phrase"),
214
+ debug=debug,
215
  )
216
  timings["supplier_extractor"] = round(time.time() - t0, 3)
217
+
218
  t0 = time.time()
219
  user_info = self.user_extractor.extract(
220
  text,
221
  supplier_phrase=supplier_info.get("matched_supplier_phrase"),
222
  date_phrase=date_info.get("matched_date_phrase"),
223
+ debug=debug,
224
  )
225
  timings["user_extractor"] = round(time.time() - t0, 3)
226
+
227
  t0 = time.time()
228
  amount_info = self.amount_extractor.extract(
229
  text,
230
  matched_date_phrase=date_info["matched_date_phrase"],
231
  matched_supplier_phrase=supplier_info["matched_supplier_phrase"],
232
+ debug=debug,
233
  )
234
  timings["amount_extractor"] = round(time.time() - t0, 3)
235
+
236
+ if debug:
237
+ print(f"[TIMINGS] {timings}")
238
 
239
  result = {
240
  "text": text,
 
244
  "date": date_info["date"],
245
  "date_iso": date_info["date_iso"],
246
  }
247
+
248
+ if debug:
249
+ result["debug"] = {
250
+ "timings": timings,
251
+ "date": date_info.get("date_debug"),
252
+ "supplier": supplier_info.get("supplier_debug"),
253
+ "user": user_info.get("user_debug"),
254
+ "amount": amount_info.get("amount_debug"),
255
+ }
256
+
257
  return result
258
 
259
 
 
330
  raise RuntimeError("Speech-to-text backend is unavailable.")
331
 
332
 
333
+ def process_voice_request(audio_path: str, mode: str, payload: dict[str, Any], debug: bool = False) -> dict[str, Any]:
334
  """Обрабатывает голосовой запрос."""
335
  total_start = time.time()
336
+
337
  context = payload.get("context", {}) if isinstance(payload, dict) else {}
338
  supplier_names = extract_names(context.get("suppliers"))
339
  user_names = extract_names(context.get("users"))
 
367
  pipeline_init_time = round(time.time() - t0, 3)
368
  print(f"[TIMINGS] pipeline_init: {pipeline_init_time}s")
369
 
370
+ extracted = extractor.extract(transcript, reference_date=date.today().isoformat(), debug=debug)
371
+
372
  total_time = round(time.time() - total_start, 3)
373
  print(f"[TIMINGS] TOTAL: {total_time}s (whisper: {whisper_time}s)")
374
 
375
+ payload = {
376
  "status": "ok",
377
  "text": transcript,
378
  "notes": polish_notes_text(extracted.get("text") or transcript),
 
381
  "date": extracted.get("date_iso") or extracted.get("date"),
382
  "sum": extracted.get("amount"),
383
  }
384
+ if debug and extracted.get("debug"):
385
+ payload["debug"] = extracted.get("debug")
386
+ return payload
387
 
388
 
389
  def require_auth():
 
443
  @app.get("/test-data")
444
  def test_data():
445
  """Тестирует извлечение данных из текста без использования Whisper."""
446
+ debug = (request.args.get("debug") or "").strip().lower() in {"1", "true", "yes"}
447
  extractor = build_default_pipeline(suppliers=TEST_SUPPLIERS, users=TEST_USERS)
448
 
449
  started = time.time()
 
454
  extracted = extractor.extract(
455
  phrase,
456
  reference_date=date.today().isoformat(),
457
+ debug=debug,
458
  )
459
+ row = {
460
  "text": phrase,
461
  "user": extracted.get("user"),
462
  "supplier": extracted.get("supplier"),
 
464
  "date": extracted.get("date"),
465
  "date_iso": extracted.get("date_iso"),
466
  "processing_time": round(time.time() - item_started, 3),
467
+ }
468
+ if debug and extracted.get("debug"):
469
+ row["debug"] = extracted.get("debug")
470
+ results.append(row)
471
 
472
  return jsonify({
473
  "status": "ok",
 
490
 
491
  audio = request.files.get("audio")
492
  mode = (request.form.get("mode") or "expense").strip()
493
+ debug = ((request.form.get("debug") or request.args.get("debug") or "").strip().lower() in {"1", "true", "yes"})
494
  context = parse_context(request.form.get("context"))
495
 
496
  if audio is None:
 
504
  temp_path = temp_file.name
505
  audio.save(temp_file)
506
 
507
+ result = process_voice_request(audio_path=temp_path, mode=mode, payload={"context": context}, debug=debug)
508
  return jsonify(result)
509
  except Exception as exception:
510
  return jsonify({"status": "error", "message": str(exception)}), 422
extractors/amount_extractor.py CHANGED
@@ -42,21 +42,47 @@ class ExpenseAmountExtractor:
42
  text: str,
43
  matched_date_phrase: Optional[str] = None,
44
  matched_supplier_phrase: Optional[str] = None,
 
45
  ) -> dict[str, Any]:
46
  date_span = self.phrase_span(text, matched_date_phrase)
47
  supplier_span = self.phrase_span(text, matched_supplier_phrase)
 
48
 
49
  for match in AMOUNT_PATTERN.finditer(text):
50
  span = match.span()
 
 
 
51
 
52
- if self.overlaps(span, date_span):
53
- continue
54
- if self.overlaps(span, supplier_span):
 
 
 
 
 
 
55
  continue
56
 
57
- amount_text = match.group(0)
58
  amount = self.to_float(amount_text)
59
  if amount is not None:
60
- return {"amount": amount, "amount_text": amount_text}
 
 
 
 
 
 
 
 
61
 
62
- return {"amount": None, "amount_text": None}
 
 
 
 
 
 
 
 
 
42
  text: str,
43
  matched_date_phrase: Optional[str] = None,
44
  matched_supplier_phrase: Optional[str] = None,
45
+ debug: bool = False,
46
  ) -> dict[str, Any]:
47
  date_span = self.phrase_span(text, matched_date_phrase)
48
  supplier_span = self.phrase_span(text, matched_supplier_phrase)
49
+ candidates: list[dict[str, Any]] = []
50
 
51
  for match in AMOUNT_PATTERN.finditer(text):
52
  span = match.span()
53
+ overlaps_date = self.overlaps(span, date_span)
54
+ overlaps_supplier = self.overlaps(span, supplier_span)
55
+ amount_text = match.group(0)
56
 
57
+ if debug:
58
+ candidates.append({
59
+ "value": amount_text,
60
+ "span": [span[0], span[1]],
61
+ "overlaps_date": overlaps_date,
62
+ "overlaps_supplier": overlaps_supplier,
63
+ })
64
+
65
+ if overlaps_date or overlaps_supplier:
66
  continue
67
 
 
68
  amount = self.to_float(amount_text)
69
  if amount is not None:
70
+ payload = {"amount": amount, "amount_text": amount_text}
71
+ if debug:
72
+ payload["amount_debug"] = {
73
+ "date_span": list(date_span) if date_span else None,
74
+ "supplier_span": list(supplier_span) if supplier_span else None,
75
+ "candidates": candidates,
76
+ "selected": amount_text,
77
+ }
78
+ return payload
79
 
80
+ payload = {"amount": None, "amount_text": None}
81
+ if debug:
82
+ payload["amount_debug"] = {
83
+ "date_span": list(date_span) if date_span else None,
84
+ "supplier_span": list(supplier_span) if supplier_span else None,
85
+ "candidates": candidates,
86
+ "selected": None,
87
+ }
88
+ return payload
extractors/date_extractor.py CHANGED
@@ -492,26 +492,37 @@ class ExpenseDateExtractor:
492
  def __init__(self) -> None:
493
  self.parser = UniversalDateParser()
494
 
495
- def extract(self, text: str, reference_date: str | date | None = None) -> dict[str, Any]:
496
  """
497
  Извлекает дату из текста.
498
-
499
  Args:
500
  text: Текст для анализа
501
  reference_date: Базовая дата (по умолчанию сегодня)
502
-
 
503
  Returns:
504
  Словарь с date, date_iso, matched_date_phrase
505
  """
506
  ref_date = self.to_date(reference_date or date.today().isoformat())
507
  parsed = self.parser.parse(text=text, reference_date=ref_date)
508
 
509
- return {
510
  "date": datetime.strptime(parsed.date_iso, "%Y-%m-%d").strftime("%d.%m.%Y") if parsed else None,
511
  "date_iso": parsed.date_iso if parsed else None,
512
  "matched_date_phrase": parsed.matched_expression if parsed else None,
513
  }
514
 
 
 
 
 
 
 
 
 
 
 
515
  @staticmethod
516
  def to_date(value: str | date) -> date:
517
  """Преобразует строку или date в date."""
 
492
  def __init__(self) -> None:
493
  self.parser = UniversalDateParser()
494
 
495
+ def extract(self, text: str, reference_date: str | date | None = None, debug: bool = False) -> dict[str, Any]:
496
  """
497
  Извлекает дату из текста.
498
+
499
  Args:
500
  text: Текст для анализа
501
  reference_date: Базовая дата (по умолчанию сегодня)
502
+ debug: Включить отладочную информацию
503
+
504
  Returns:
505
  Словарь с date, date_iso, matched_date_phrase
506
  """
507
  ref_date = self.to_date(reference_date or date.today().isoformat())
508
  parsed = self.parser.parse(text=text, reference_date=ref_date)
509
 
510
+ payload = {
511
  "date": datetime.strptime(parsed.date_iso, "%Y-%m-%d").strftime("%d.%m.%Y") if parsed else None,
512
  "date_iso": parsed.date_iso if parsed else None,
513
  "matched_date_phrase": parsed.matched_expression if parsed else None,
514
  }
515
 
516
+ if debug:
517
+ payload["date_debug"] = {
518
+ "reference_date": ref_date.isoformat(),
519
+ "input_text": text,
520
+ "matched_date_phrase": payload["matched_date_phrase"],
521
+ "date_iso": payload["date_iso"],
522
+ }
523
+
524
+ return payload
525
+
526
  @staticmethod
527
  def to_date(value: str | date) -> date:
528
  """Преобразует строку или date в date."""
extractors/supplier_extractor.py CHANGED
@@ -143,7 +143,7 @@ class ExpenseSupplierExtractor:
143
  self.lexical_token_cache: dict[str, float] = {}
144
  self.phrase_support_cache: dict[str, float] = {}
145
  self.noise_terms = {
146
- "за", "на", "из", "для", "под", "над", "при", "без", "и", "или",
147
  "купил", "купила", "купили", "покупка", "заказал", "заказала", "заказали",
148
  "оплатил", "оплатила", "оплатили", "заплатил", "заплатила", "заплатили",
149
  "был", "была", "было", "были", "утром", "днем", "днём", "вечером", "ночью",
@@ -290,15 +290,22 @@ class ExpenseSupplierExtractor:
290
  best = {"supplier": self.suppliers[i], "score": local, "phrase": phrase, "variant": local_variant}
291
  return best
292
 
293
- def extract(self, text: str, date_phrase: str | None = None, debug: bool = False) -> dict[str, Any]:
 
 
 
 
 
 
294
  """
295
  Извлекает поставщика из текста.
296
-
297
  Args:
298
  text: Текст для анализа
299
  date_phrase: Фраза даты для исключения
 
300
  debug: Включить отладочную информацию
301
-
302
  Returns:
303
  Словарь с supplier, supplier_score, matched_supplier_phrase
304
  """
@@ -306,6 +313,10 @@ class ExpenseSupplierExtractor:
306
  excluded_tokens: set[str] = set()
307
  if date_phrase:
308
  excluded_tokens.update(normalize_text(date_phrase).split())
 
 
 
 
309
  excluded_tokens.update(self.noise_terms)
310
 
311
  raw_tokens = normalize_text(text).split()
@@ -396,6 +407,7 @@ class ExpenseSupplierExtractor:
396
  payload["supplier_debug"] = {
397
  "tokens": tokens,
398
  "phrases_count": len(phrases),
 
399
  "top_candidates": top_candidates,
400
  }
401
 
 
143
  self.lexical_token_cache: dict[str, float] = {}
144
  self.phrase_support_cache: dict[str, float] = {}
145
  self.noise_terms = {
146
+ "для", "под", "над", "при", "без", "или",
147
  "купил", "купила", "купили", "покупка", "заказал", "заказала", "заказали",
148
  "оплатил", "оплатила", "оплатили", "заплатил", "заплатила", "заплатили",
149
  "был", "была", "было", "были", "утром", "днем", "днём", "вечером", "ночью",
 
290
  best = {"supplier": self.suppliers[i], "score": local, "phrase": phrase, "variant": local_variant}
291
  return best
292
 
293
+ def extract(
294
+ self,
295
+ text: str,
296
+ date_phrase: str | None = None,
297
+ excluded_phrases: list[str] | None = None,
298
+ debug: bool = False,
299
+ ) -> dict[str, Any]:
300
  """
301
  Извлекает поставщика из текста.
302
+
303
  Args:
304
  text: Текст для анализа
305
  date_phrase: Фраза даты для исключения
306
+ excluded_phrases: Дополнительные фразы для исключения
307
  debug: Включить отладочную информацию
308
+
309
  Returns:
310
  Словарь с supplier, supplier_score, matched_supplier_phrase
311
  """
 
313
  excluded_tokens: set[str] = set()
314
  if date_phrase:
315
  excluded_tokens.update(normalize_text(date_phrase).split())
316
+ if excluded_phrases:
317
+ for phrase in excluded_phrases:
318
+ if phrase:
319
+ excluded_tokens.update(normalize_text(phrase).split())
320
  excluded_tokens.update(self.noise_terms)
321
 
322
  raw_tokens = normalize_text(text).split()
 
407
  payload["supplier_debug"] = {
408
  "tokens": tokens,
409
  "phrases_count": len(phrases),
410
+ "excluded_tokens": sorted(excluded_tokens)[:80],
411
  "top_candidates": top_candidates,
412
  }
413
 
extractors/user_extractor.py CHANGED
@@ -1,152 +1,72 @@
1
- """
2
- Экстрактор пользователей из текста.
3
-
4
- Использует семантические эмбеддинги для поиска пользователей.
5
- """
6
 
7
  from __future__ import annotations
8
 
9
  import re
10
- import unicodedata
11
  from typing import Any
12
 
13
- import torch
14
- from pymorphy3 import MorphAnalyzer
15
- from sentence_transformers import SentenceTransformer
16
-
17
-
18
- MORPH = MorphAnalyzer()
19
-
20
-
21
- def normalize_text(text: str) -> str:
22
- """Нормализует текст: lowercase, удаление диакритики и пунктуации."""
23
- text = unicodedata.normalize("NFKD", text.lower())
24
- text = "".join(ch for ch in text if not unicodedata.combining(ch))
25
- return re.sub(r"[^\w\s]", "", text).strip()
26
-
27
-
28
- def tokenize_text(text: str) -> list[str]:
29
- """Токенизирует текст."""
30
- return normalize_text(text).split()
31
-
32
-
33
- def lemmatize_word(word: str) -> str:
34
- """Возвращает лемму слова."""
35
- return MORPH.parse(word)[0].normal_form if re.fullmatch(r"[а-я]+", word) else word
36
-
37
-
38
- def lemmatize_text(text: str) -> list[str]:
39
- """Лемматизирует текст."""
40
- return [lemmatize_word(word) for word in tokenize_text(text)]
41
 
42
 
43
  class ExpenseUserExtractor:
44
- """
45
- Экстрактор пользователей из текста.
46
-
47
- Сначала использует точное/текстовое совпадение имени, а затем
48
- один батч эмбеддингов для оставшихся кандидатов.
49
- """
50
 
51
  def __init__(
52
  self,
53
  users: list[str],
54
  suppliers: list[str],
55
- model: SentenceTransformer,
56
- threshold: float = 0.6
57
  ) -> None:
58
  self.users = users
59
- self.model = model
60
  self.threshold = threshold
61
  self.supplier_terms = {normalize_text(supplier) for supplier in suppliers}
62
- self.user_terms = [normalize_text(user) for user in users]
63
- self.user_embeddings = model.encode(
64
- [f"passage: {user}" for user in self.user_terms],
65
- convert_to_tensor=True,
66
- normalize_embeddings=True,
67
- show_progress_bar=False,
68
- )
69
-
70
- @staticmethod
71
- def _contains_whole_phrase(text: str, phrase: str) -> bool:
72
- if not phrase:
73
- return False
74
- return re.search(rf"(?<!\w){re.escape(phrase)}(?!\w)", text) is not None
75
-
76
- def _extract_candidates(self, text: str, excluded_tokens: set[str]) -> list[str]:
77
- candidates: list[str] = []
78
- seen: set[str] = set()
79
-
80
- for word in lemmatize_text(text):
81
- if len(word) < 3:
82
- continue
83
- if word in excluded_tokens or word in self.supplier_terms or word in seen:
84
- continue
85
- seen.add(word)
86
- candidates.append(word)
87
-
88
- return candidates
89
 
90
  def extract(
91
  self,
92
  text: str,
93
  supplier_phrase: str | None = None,
94
- date_phrase: str | None = None
 
95
  ) -> dict[str, Any]:
96
- excluded_tokens: set[str] = set()
97
- if supplier_phrase:
98
- excluded_tokens.update(normalize_text(supplier_phrase).split())
99
- if date_phrase:
100
- excluded_tokens.update(normalize_text(date_phrase).split())
101
-
102
  normalized_text = normalize_text(text)
103
 
104
- for user_term, original_user in zip(self.user_terms, self.users):
105
- if user_term and self._contains_whole_phrase(normalized_text, user_term):
106
- return {
107
- "user": original_user,
108
- "user_score": 1.0,
109
- "matched_user_phrase": user_term,
110
- }
111
-
112
  if re.search(r"(?<!\S)я(?!\S)", normalized_text, re.IGNORECASE):
113
- return {
114
  "user": "Я",
115
  "user_score": 1.0,
116
  "matched_user_phrase": "я",
117
  }
 
 
 
 
 
 
118
 
119
- candidates = self._extract_candidates(text, excluded_tokens)
120
- if not candidates:
121
- return {
122
- "user": None,
123
- "user_score": None,
124
- "matched_user_phrase": None,
125
- }
126
-
127
- query_embeddings = self.model.encode(
128
- [f"query: {word}" for word in candidates],
129
- convert_to_tensor=True,
130
- normalize_embeddings=True,
131
- show_progress_bar=False,
132
- batch_size=max(1, min(32, len(candidates))),
133
  )
134
 
135
- similarity_matrix = torch.matmul(query_embeddings, self.user_embeddings.T)
136
- flat_index = int(torch.argmax(similarity_matrix))
137
- candidate_index = flat_index // len(self.users)
138
- user_index = flat_index % len(self.users)
139
- score = similarity_matrix[candidate_index, user_index].item()
 
140
 
141
- if score >= self.threshold:
142
- return {
143
- "user": self.users[user_index],
144
- "user_score": round(score, 4),
145
- "matched_user_phrase": candidates[candidate_index],
 
 
146
  }
147
 
148
- return {
149
- "user": None,
150
- "user_score": None,
151
- "matched_user_phrase": None,
152
- }
 
1
+ """Экстрактор пользователей на той же логике, что и поиск поставщика."""
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
  import re
 
6
  from typing import Any
7
 
8
+ from extractors.supplier_extractor import ExpenseSupplierExtractor, normalize_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
class ExpenseUserExtractor:
    """Finds the user with the same fuzzy matcher that is used for suppliers.

    NOTE(review): ``model`` is accepted for backward compatibility with the
    previous embedding-based implementation but is not used by this class.
    """

    def __init__(
        self,
        users: list[str],
        suppliers: list[str],
        model: Any = None,
        threshold: float = 0.5,
    ) -> None:
        self.users = users
        self.threshold = threshold
        # Normalized supplier names; NOTE(review): not read by extract() here —
        # kept for compatibility with callers that may inspect it.
        self.supplier_terms = {normalize_text(supplier) for supplier in suppliers}
        # Reuse the supplier matching algorithm over the list of user names.
        self.user_matcher = ExpenseSupplierExtractor(suppliers=users)

    def extract(
        self,
        text: str,
        supplier_phrase: str | None = None,
        date_phrase: str | None = None,
        debug: bool = False,
    ) -> dict[str, Any]:
        """Extract the user mentioned in ``text``.

        Args:
            text: Text to analyse.
            supplier_phrase: Matched supplier phrase to exclude from matching.
            date_phrase: Matched date phrase to exclude from matching.
            debug: Attach debug information to the result.

        Returns:
            Dict with ``user``, ``user_score``, ``matched_user_phrase``
            (plus ``user_debug`` when ``debug`` is set).
        """
        normalized = normalize_text(text)

        # A standalone pronoun "я" ("I") short-circuits the fuzzy search.
        if re.search(r"(?<!\S)я(?!\S)", normalized, re.IGNORECASE):
            result: dict[str, Any] = {
                "user": "Я",
                "user_score": 1.0,
                "matched_user_phrase": "я",
            }
            if debug:
                result["user_debug"] = {
                    "mode": "direct-pronoun",
                    "normalized_text": normalized,
                }
            return result

        match = self.user_matcher.extract(
            text=text,
            date_phrase=date_phrase,
            excluded_phrases=[supplier_phrase] if supplier_phrase else None,
            debug=debug,
        )

        score = match.get("supplier_score")
        accepted = score is not None and score >= self.threshold
        result = {
            "user": match.get("supplier") if accepted else None,
            "user_score": score if accepted else None,
            "matched_user_phrase": match.get("matched_supplier_phrase") if accepted else None,
        }

        if debug:
            result["user_debug"] = {
                "mode": "supplier-matcher",
                "threshold": self.threshold,
                "excluded_supplier_phrase": supplier_phrase,
                "normalized_text": normalized,
                "matcher_debug": match.get("supplier_debug"),
            }

        return result