VladGeekPro committed on
Commit
9fa4ecb
·
1 Parent(s): 619ed81

OptimizedSumAndUserSearch

Browse files
app.py CHANGED
@@ -173,7 +173,7 @@ def get_whisper_pipeline() -> Any:
173
  model=model,
174
  tokenizer=_WHISPER_PROCESSOR.tokenizer,
175
  feature_extractor=_WHISPER_PROCESSOR.feature_extractor,
176
- torch_dtype=torch.float32,
177
  device="cpu",
178
  )
179
 
 
173
  model=model,
174
  tokenizer=_WHISPER_PROCESSOR.tokenizer,
175
  feature_extractor=_WHISPER_PROCESSOR.feature_extractor,
176
+ dtype=torch.float32,
177
  device="cpu",
178
  )
179
 
extractors/amount_extractor.py CHANGED
@@ -1,60 +1,29 @@
1
- """
2
- Экстрактор сумм из текста.
3
-
4
- Использует GLiNER для извлечения денежных сумм.
5
- """
6
 
7
  from __future__ import annotations
8
 
9
  import re
10
  from typing import Any, Optional
11
 
12
- from gliner import GLiNER
13
-
14
-
15
- # Глобальная модель для извлечения сумм
16
- _AMOUNT_MODEL: Optional[GLiNER] = None
17
-
18
 
19
- def get_amount_model() -> Optional[GLiNER]:
20
- """Возвращает модель для извлечения сумм (ленивая загрузка)."""
21
- global _AMOUNT_MODEL
22
-
23
- if _AMOUNT_MODEL is None:
24
- _AMOUNT_MODEL = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
25
-
26
- return _AMOUNT_MODEL
27
 
28
 
29
  class ExpenseAmountExtractor:
30
- """
31
- Экстрактор денежных сумм из текста.
32
-
33
- Использует GLiNER для поиска упоминаний денег.
34
- """
35
-
36
  def __init__(self, suppliers: list[str] | None = None) -> None:
37
- """
38
- Args:
39
- suppliers: Список поставщиков (не используется, для совместимости)
40
- """
41
- self.model = get_amount_model()
42
 
43
  @staticmethod
44
  def to_float(value: str) -> Optional[float]:
45
- """Преобразует строку в число."""
46
- cleaned = value.replace(" ", "").replace("\u00A0", "")
47
- match = re.search(r"\d+(?:[,]\d{1,2})?", cleaned)
48
- if not match:
49
- return None
50
  try:
51
- return float(match.group(0).replace(",", "."))
52
  except ValueError:
53
  return None
54
 
55
  @staticmethod
56
  def phrase_span(text: str, phrase: Optional[str]) -> Optional[tuple[int, int]]:
57
- """Возвращает позицию фразы в тексте."""
58
  if not phrase:
59
  return None
60
  idx = text.lower().find(phrase.lower())
@@ -64,61 +33,30 @@ class ExpenseAmountExtractor:
64
 
65
  @staticmethod
66
  def overlaps(span1: tuple[int, int], span2: Optional[tuple[int, int]]) -> bool:
67
- """Проверяет пересечение двух диапазонов."""
68
  if span2 is None:
69
  return False
70
  return span1[0] < span2[1] and span2[0] < span1[1]
71
 
72
- @staticmethod
73
- def expand_amount_text(text: str, start: int, end: int) -> tuple[str, tuple[int, int]]:
74
- """Расширяет текст суммы (для дробных чисел)."""
75
- suffix = re.match(r",\d{1,2}", text[end:])
76
- if suffix:
77
- new_end = end + len(suffix.group(0))
78
- return text[start:new_end].strip(), (start, new_end)
79
-
80
- prefix = re.search(r"(\d{1,3}(?:\s*\d{3})*),", text[:start])
81
- if prefix:
82
- new_start = prefix.start(1)
83
- return text[new_start:end].strip(), (new_start, end)
84
-
85
- return text[start:end].strip(), (start, end)
86
-
87
  def extract(
88
  self,
89
  text: str,
90
  matched_date_phrase: Optional[str] = None,
91
  matched_supplier_phrase: Optional[str] = None,
92
  ) -> dict[str, Any]:
93
- """
94
- Извлекает сумму из текста.
95
-
96
- Args:
97
- text: Текст для анализа
98
- matched_date_phrase: Фраза даты для исключения
99
- matched_supplier_phrase: Фраза поставщика для исключения
100
-
101
- Returns:
102
- Словарь с amount и amount_text
103
- """
104
- if self.model is None:
105
- return {"amount": None, "amount_text": None}
106
-
107
  date_span = self.phrase_span(text, matched_date_phrase)
108
  supplier_span = self.phrase_span(text, matched_supplier_phrase)
109
- entities = self.model.predict_entities(text, ["money"], threshold=0.3)
110
 
111
- for ent in sorted(entities, key=lambda item: float(item.get("score", 0.0)), reverse=True):
112
- raw_span = (int(ent.get("start", 0)), int(ent.get("end", 0)))
113
- amount_text, span = self.expand_amount_text(text, raw_span[0], raw_span[1])
114
- amount = self.to_float(amount_text)
115
- overlaps_date = self.overlaps(span, date_span)
116
- overlaps_supplier = self.overlaps(span, supplier_span)
117
 
118
- if amount is None:
119
  continue
120
- if overlaps_date or overlaps_supplier:
121
  continue
122
- return {"amount": amount, "amount_text": amount_text}
 
 
 
 
123
 
124
  return {"amount": None, "amount_text": None}
 
1
+ """Простой regex-экстрактор суммы из текста."""
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
  import re
6
  from typing import Any, Optional
7
 
 
 
 
 
 
 
8
 
9
+ AMOUNT_PATTERN = re.compile(r"\d+(?:,\d{1,2})?", re.IGNORECASE)
 
 
 
 
 
 
 
10
 
11
 
12
  class ExpenseAmountExtractor:
13
+ """Извлекает сумму как целое число или число с запятой."""
14
+
 
 
 
 
15
  def __init__(self, suppliers: list[str] | None = None) -> None:
16
+ self.suppliers = suppliers or []
 
 
 
 
17
 
18
  @staticmethod
19
  def to_float(value: str) -> Optional[float]:
 
 
 
 
 
20
  try:
21
+ return float(value.replace(",", "."))
22
  except ValueError:
23
  return None
24
 
25
  @staticmethod
26
  def phrase_span(text: str, phrase: Optional[str]) -> Optional[tuple[int, int]]:
 
27
  if not phrase:
28
  return None
29
  idx = text.lower().find(phrase.lower())
 
33
 
34
  @staticmethod
35
  def overlaps(span1: tuple[int, int], span2: Optional[tuple[int, int]]) -> bool:
 
36
  if span2 is None:
37
  return False
38
  return span1[0] < span2[1] and span2[0] < span1[1]
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def extract(
41
  self,
42
  text: str,
43
  matched_date_phrase: Optional[str] = None,
44
  matched_supplier_phrase: Optional[str] = None,
45
  ) -> dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  date_span = self.phrase_span(text, matched_date_phrase)
47
  supplier_span = self.phrase_span(text, matched_supplier_phrase)
 
48
 
49
+ for match in AMOUNT_PATTERN.finditer(text):
50
+ span = match.span()
 
 
 
 
51
 
52
+ if self.overlaps(span, date_span):
53
  continue
54
+ if self.overlaps(span, supplier_span):
55
  continue
56
+
57
+ amount_text = match.group(0)
58
+ amount = self.to_float(amount_text)
59
+ if amount is not None:
60
+ return {"amount": amount, "amount_text": amount_text}
61
 
62
  return {"amount": None, "amount_text": None}
extractors/user_extractor.py CHANGED
@@ -43,98 +43,109 @@ def lemmatize_text(text: str) -> list[str]:
43
  class ExpenseUserExtractor:
44
  """
45
  Экстрактор пользователей из текста.
46
-
47
- Использует семантические эмбеддинги для сопоставления слов из текста
48
- с известными пользователями.
49
  """
50
-
51
  def __init__(
52
- self,
53
- users: list[str],
54
- suppliers: list[str],
55
- model: SentenceTransformer,
56
  threshold: float = 0.6
57
  ) -> None:
58
- """
59
- Args:
60
- users: Список известных пользователей
61
- suppliers: Список поставщиков (для исключения)
62
- model: Модель для создания эмбеддингов
63
- threshold: Порог схожести
64
- """
65
  self.users = users
66
  self.model = model
67
  self.threshold = threshold
68
  self.supplier_terms = {normalize_text(supplier) for supplier in suppliers}
69
  self.user_terms = [normalize_text(user) for user in users]
 
70
  self.user_embeddings = model.encode(
71
  [f"passage: {user}" for user in self.user_terms],
72
  convert_to_tensor=True,
73
  normalize_embeddings=True,
 
74
  )
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  def extract(
77
- self,
78
- text: str,
79
- supplier_phrase: str | None = None,
80
  date_phrase: str | None = None
81
  ) -> dict[str, Any]:
82
- """
83
- Извлекает пользователя из текста.
84
-
85
- Args:
86
- text: Текст для анализа
87
- supplier_phrase: Фраза поставщика для исключения
88
- date_phrase: Фраза даты для исключения
89
-
90
- Returns:
91
- Словарь с user, user_score, matched_user_phrase
92
- """
93
  excluded_tokens: set[str] = set()
94
  if supplier_phrase:
95
  excluded_tokens.update(normalize_text(supplier_phrase).split())
96
  if date_phrase:
97
  excluded_tokens.update(normalize_text(date_phrase).split())
98
 
99
- best_user = None
100
- best_score = -1.0
101
- best_phrase = None
102
-
103
- for word in lemmatize_text(text):
104
- if len(word) < 3:
105
- continue
106
- if word in excluded_tokens or word in self.supplier_terms:
107
- continue
108
 
109
- query_emb = self.model.encode(
110
- f"query: {word}",
111
- convert_to_tensor=True,
112
- normalize_embeddings=True,
113
- )
114
- similarities = torch.cosine_similarity(query_emb.unsqueeze(0), self.user_embeddings, dim=1)
115
- idx = int(torch.argmax(similarities))
116
- score = similarities[idx].item()
117
-
118
- if score > best_score:
119
- best_score = score
120
- best_user = self.users[idx]
121
- best_phrase = word
122
-
123
- if best_score >= self.threshold:
124
- return {
125
- "user": best_user,
126
- "user_score": round(best_score, 4),
127
- "matched_user_phrase": best_phrase,
128
- }
129
 
130
- # Проверка на местоимение "я"
131
- if re.search(r"(?<!\S)я(?!\S)", normalize_text(text), re.IGNORECASE):
132
  return {
133
  "user": "Я",
134
  "user_score": 1.0,
135
  "matched_user_phrase": "я",
136
  }
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  return {
139
  "user": None,
140
  "user_score": None,
 
43
  class ExpenseUserExtractor:
44
  """
45
  Экстрактор пользователей из текста.
46
+
47
+ Сначала использует точное/текстовое совпадение имени, а затем
48
+ один батч эмбеддингов для оставшихся кандидатов.
49
  """
50
+
51
  def __init__(
52
+ self,
53
+ users: list[str],
54
+ suppliers: list[str],
55
+ model: SentenceTransformer,
56
  threshold: float = 0.6
57
  ) -> None:
 
 
 
 
 
 
 
58
  self.users = users
59
  self.model = model
60
  self.threshold = threshold
61
  self.supplier_terms = {normalize_text(supplier) for supplier in suppliers}
62
  self.user_terms = [normalize_text(user) for user in users]
63
+ self.user_lookup = dict(zip(self.user_terms, self.users))
64
  self.user_embeddings = model.encode(
65
  [f"passage: {user}" for user in self.user_terms],
66
  convert_to_tensor=True,
67
  normalize_embeddings=True,
68
+ show_progress_bar=False,
69
  )
70
 
71
+ @staticmethod
72
+ def _contains_whole_phrase(text: str, phrase: str) -> bool:
73
+ if not phrase:
74
+ return False
75
+ return re.search(rf"(?<!\w){re.escape(phrase)}(?!\w)", text) is not None
76
+
77
+ def _extract_candidates(self, text: str, excluded_tokens: set[str]) -> list[str]:
78
+ candidates: list[str] = []
79
+ seen: set[str] = set()
80
+
81
+ for word in lemmatize_text(text):
82
+ if len(word) < 3:
83
+ continue
84
+ if word in excluded_tokens or word in self.supplier_terms or word in seen:
85
+ continue
86
+ seen.add(word)
87
+ candidates.append(word)
88
+
89
+ return candidates
90
+
91
  def extract(
92
+ self,
93
+ text: str,
94
+ supplier_phrase: str | None = None,
95
  date_phrase: str | None = None
96
  ) -> dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
97
  excluded_tokens: set[str] = set()
98
  if supplier_phrase:
99
  excluded_tokens.update(normalize_text(supplier_phrase).split())
100
  if date_phrase:
101
  excluded_tokens.update(normalize_text(date_phrase).split())
102
 
103
+ normalized_text = normalize_text(text)
 
 
 
 
 
 
 
 
104
 
105
+ for user_term, original_user in zip(self.user_terms, self.users):
106
+ if user_term and self._contains_whole_phrase(normalized_text, user_term):
107
+ return {
108
+ "user": original_user,
109
+ "user_score": 1.0,
110
+ "matched_user_phrase": user_term,
111
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
+ if re.search(r"(?<!\S)я(?!\S)", normalized_text, re.IGNORECASE):
 
114
  return {
115
  "user": "Я",
116
  "user_score": 1.0,
117
  "matched_user_phrase": "я",
118
  }
119
 
120
+ candidates = self._extract_candidates(text, excluded_tokens)
121
+ if not candidates:
122
+ return {
123
+ "user": None,
124
+ "user_score": None,
125
+ "matched_user_phrase": None,
126
+ }
127
+
128
+ query_embeddings = self.model.encode(
129
+ [f"query: {word}" for word in candidates],
130
+ convert_to_tensor=True,
131
+ normalize_embeddings=True,
132
+ show_progress_bar=False,
133
+ batch_size=max(1, min(32, len(candidates))),
134
+ )
135
+
136
+ similarity_matrix = torch.matmul(query_embeddings, self.user_embeddings.T)
137
+ flat_index = int(torch.argmax(similarity_matrix))
138
+ candidate_index = flat_index // len(self.users)
139
+ user_index = flat_index % len(self.users)
140
+ score = similarity_matrix[candidate_index, user_index].item()
141
+
142
+ if score >= self.threshold:
143
+ return {
144
+ "user": self.users[user_index],
145
+ "user_score": round(score, 4),
146
+ "matched_user_phrase": candidates[candidate_index],
147
+ }
148
+
149
  return {
150
  "user": None,
151
  "user_score": None,