Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,7 @@ import re
|
|
| 11 |
from typing import List, Dict, Tuple
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
-
# Set device
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 17 |
|
|
@@ -31,33 +31,41 @@ ner_pipeline = pipeline(
|
|
| 31 |
aggregation_strategy="simple"
|
| 32 |
)
|
| 33 |
|
| 34 |
-
# Load
|
| 35 |
-
print("Loading
|
| 36 |
-
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Load stock symbols from CSV
|
| 63 |
def load_stock_symbols(csv_path="symbols.csv"):
|
|
@@ -99,65 +107,101 @@ MARKET_KEYWORDS = {
|
|
| 99 |
'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت'
|
| 100 |
}
|
| 101 |
|
| 102 |
-
def
|
| 103 |
"""
|
| 104 |
-
Use
|
| 105 |
Returns confidence score (0-1)
|
| 106 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
try:
|
| 108 |
-
# Create a
|
| 109 |
-
prompt = f"""
|
| 110 |
-
You are a Persian financial text analyzer. Determine if the word "{potential_symbol}" in the following Persian text is used as a stock market symbol or as a regular word.
|
| 111 |
|
| 112 |
-
Context
|
| 113 |
-
- The word "{potential_symbol}" could be a stock symbol for "{symbol_info['company']}" (industry: {symbol_info['bazaar_group']})
|
| 114 |
-
- Stock symbols usually appear with financial terms like: سهام، بورس، معامله، قیمت، خرید، فروش
|
| 115 |
|
| 116 |
-
Text
|
| 117 |
-
"{text}"
|
| 118 |
|
| 119 |
-
Answer
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
)
|
| 137 |
|
| 138 |
-
#
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
return 0.9 # High confidence it's a stock symbol
|
| 144 |
-
elif "WORD" in answer:
|
| 145 |
-
return 0.1 # Low confidence it's a stock symbol
|
| 146 |
else:
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
return 0.7
|
| 150 |
-
else:
|
| 151 |
-
return 0.3
|
| 152 |
-
|
| 153 |
except Exception as e:
|
| 154 |
-
print(f"
|
| 155 |
-
return 0.5
|
| 156 |
|
| 157 |
def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
|
| 158 |
"""
|
| 159 |
Check if a potential symbol is actually used as a stock symbol in context
|
| 160 |
-
Using both heuristics and Gemma model
|
| 161 |
"""
|
| 162 |
# Get surrounding context
|
| 163 |
symbol_pos = text.find(potential_symbol)
|
|
@@ -181,11 +225,14 @@ def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Di
|
|
| 181 |
elif market_keyword_count == 0 and len(words_in_context) > 10:
|
| 182 |
return False, 0.05
|
| 183 |
|
| 184 |
-
# Use
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
-
# Combine scores
|
| 188 |
-
final_score = (heuristic_score * 0.
|
| 189 |
|
| 190 |
# Decision threshold
|
| 191 |
is_stock = final_score > 0.5
|
|
@@ -195,7 +242,7 @@ def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Di
|
|
| 195 |
def find_stock_symbols_in_text(text: str) -> List[Dict]:
|
| 196 |
"""Find and validate stock symbols in text"""
|
| 197 |
found_symbols = []
|
| 198 |
-
processed_positions = set()
|
| 199 |
|
| 200 |
# Pattern to match Persian/Arabic words
|
| 201 |
pattern = r'\b[\u0600-\u06FF]+\b'
|
|
@@ -206,7 +253,7 @@ def find_stock_symbols_in_text(text: str) -> List[Dict]:
|
|
| 206 |
if word in SYMBOL_NAMES and match.start() not in processed_positions:
|
| 207 |
symbol_info = STOCK_SYMBOLS[word]
|
| 208 |
|
| 209 |
-
# Check context
|
| 210 |
is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
|
| 211 |
|
| 212 |
if is_stock:
|
|
@@ -255,13 +302,10 @@ label_names = {
|
|
| 255 |
}
|
| 256 |
|
| 257 |
def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
|
| 258 |
-
"""Merge entities, removing overlaps
|
| 259 |
all_entities = []
|
| 260 |
-
|
| 261 |
-
# Add stock entities first
|
| 262 |
all_entities.extend(stock_entities)
|
| 263 |
|
| 264 |
-
# Add NER entities that don't overlap
|
| 265 |
for ner_ent in entities:
|
| 266 |
overlap = False
|
| 267 |
for stock_ent in stock_entities:
|
|
@@ -312,7 +356,7 @@ def perform_ner(text):
|
|
| 312 |
# Perform standard NER
|
| 313 |
entities = ner_pipeline(text)
|
| 314 |
|
| 315 |
-
# Find stock symbols
|
| 316 |
stock_entities = find_stock_symbols_in_text(text)
|
| 317 |
|
| 318 |
# Merge entities
|
|
@@ -379,13 +423,13 @@ with gr.Blocks(
|
|
| 379 |
.rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
|
| 380 |
"""
|
| 381 |
) as demo:
|
| 382 |
-
gr.Markdown("""
|
| 383 |
# 🏦 شناسایی هوشمند موجودیتها و نمادهای بورس ایران
|
| 384 |
## Persian Named Entity Recognition with Stock Symbol Detection
|
| 385 |
-
###
|
| 386 |
|
| 387 |
<div class="rtl-text">
|
| 388 |
-
این برنامه
|
| 389 |
</div>
|
| 390 |
""")
|
| 391 |
|
|
@@ -447,12 +491,9 @@ with gr.Blocks(
|
|
| 447 |
| 🔷 آبی آسمانی | **درصدها** | ۲۰ درصد، ۵٪ |
|
| 448 |
| 💚 سبز روشن | **نمادهای بورسی** | فولاد، وبملت، شپنا |
|
| 449 |
|
| 450 |
-
## ویژگی خاص: تشخیص هوشمند
|
| 451 |
|
| 452 |
-
|
| 453 |
-
- درک عمیق از زبان فارسی دارد
|
| 454 |
-
- متن را به صورت کامل تحلیل میکند
|
| 455 |
-
- بین نماد بورسی و کلمه عادی تمایز قائل میشود
|
| 456 |
|
| 457 |
**مثال:**
|
| 458 |
- «سهام **فولاد** در بورس معامله شد» ← فولاد = نماد بورسی ✅
|
|
@@ -468,9 +509,7 @@ with gr.Blocks(
|
|
| 468 |
## مدلهای استفاده شده:
|
| 469 |
|
| 470 |
- **ParsBERT NER**: شناسایی موجودیتهای عمومی
|
| 471 |
-
- **
|
| 472 |
-
|
| 473 |
-
⚠️ **توجه**: مدل Gemma به دلیل حجم بالا (9 میلیارد پارامتر) ممکن است کمی کندتر باشد
|
| 474 |
|
| 475 |
</div>
|
| 476 |
""")
|
|
@@ -498,11 +537,9 @@ if __name__ == "__main__":
|
|
| 498 |
print("Starting Persian NER + Stock Symbol Detection System...")
|
| 499 |
print(f"Using device: {device}")
|
| 500 |
print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
|
| 501 |
-
print("Models:")
|
| 502 |
print(" - NER: HooshvareLab/bert-base-parsbert-ner-uncased")
|
| 503 |
-
print(" - Context
|
| 504 |
-
print("\nNote: Gemma-2-9B is a large model. First run may take time to download.")
|
| 505 |
-
print("For better performance, consider using GPU if available.")
|
| 506 |
demo.launch(
|
| 507 |
share=False,
|
| 508 |
debug=True
|
|
|
|
| 11 |
from typing import List, Dict, Tuple
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
+
# Set device
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 17 |
|
|
|
|
| 31 |
aggregation_strategy="simple"
|
| 32 |
)
|
| 33 |
|
| 34 |
+
# Context-understanding model.
# First choice is a small causal LM; if anything in that load path raises,
# a multilingual zero-shot NLI classifier is used instead. `use_llm_model`
# records which of the two paths is active for the rest of the module.
print("Loading context understanding model...")
# Default is Microsoft Phi-2 (2.7B parameters); Mistral-7B-Instruct is an
# alternative when more resources are available.
context_model_name = "microsoft/phi-2"

try:
    context_tokenizer = AutoTokenizer.from_pretrained(
        context_model_name,
        trust_remote_code=True,
    )
    context_model = AutoModelForCausalLM.from_pretrained(
        context_model_name,
        trust_remote_code=True,
        torch_dtype=dtype,
        # Let the loader place the weights only when a GPU is present.
        device_map=("auto" if torch.cuda.is_available() else None),
    )
    if device == "cpu":
        context_model = context_model.to(device)

    # generate() needs a pad token; reuse EOS when the tokenizer defines none.
    if context_tokenizer.pad_token is None:
        context_tokenizer.pad_token = context_tokenizer.eos_token

    use_llm_model = True
    print(f"Successfully loaded {context_model_name}")

except Exception as load_error:
    print(f"Could not load Phi-2 model: {load_error}")
    print("Falling back to zero-shot classification model...")

    # Fallback: mDeBERTa NLI model driven through the zero-shot pipeline.
    context_model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
    classifier = pipeline(
        "zero-shot-classification",
        model=context_model_name,
        # The pipeline takes a GPU index, or -1 for CPU.
        device=(-1 if device != "cuda" else 0),
    )
    use_llm_model = False
|
| 69 |
|
| 70 |
# Load stock symbols from CSV
|
| 71 |
def load_stock_symbols(csv_path="symbols.csv"):
|
|
|
|
| 107 |
'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت'
|
| 108 |
}
|
| 109 |
|
| 110 |
+
def _score_from_llm_answer(answer: str) -> float:
    """Map the model's free-text answer onto a stock-symbol confidence.

    The keyword that appears first decides the outcome, so an answer such as
    "WORD, not a STOCK symbol" is scored as a regular word rather than being
    caught by a plain substring test for "STOCK".

    Returns:
        0.9 if "STOCK" comes first, 0.1 if "WORD" comes first, and 0.5 when
        neither keyword is present.
    """
    normalized = answer.strip().upper()
    stock_at = normalized.find("STOCK")
    word_at = normalized.find("WORD")

    if stock_at == -1 and word_at == -1:
        return 0.5  # Ambiguous response: stay neutral.
    if word_at == -1 or (stock_at != -1 and stock_at < word_at):
        return 0.9
    return 0.1


def use_phi_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
    """Ask the causal LM whether a word is used as a stock symbol.

    When the module-level flag ``use_llm_model`` is false, the call is
    delegated to ``use_zero_shot_classification``.

    Args:
        text: Persian text (or context window) containing the word.
        potential_symbol: The word that may be a stock symbol.
        symbol_info: Symbol metadata; the ``'company'`` key is read.

    Returns:
        Confidence in [0, 1] that the word is a stock symbol: 0.9 for a
        "STOCK" answer, 0.1 for a "WORD" answer, and 0.5 when the answer is
        ambiguous or inference fails.
    """
    if not use_llm_model:
        # No causal LM available: use zero-shot classification instead.
        return use_zero_shot_classification(text, potential_symbol, symbol_info)

    try:
        prompt = f"""Analyze this Persian text and determine if "{potential_symbol}" is used as a stock market symbol.

Context: "{potential_symbol}" could be a stock symbol for {symbol_info['company']} company.

Text: {text}

Answer with only "STOCK" if it's a stock symbol, or "WORD" if it's a regular word:
Answer: """

        inputs = context_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        # Send the tensors to wherever the model's weights actually live; with
        # device_map="auto" that is not necessarily the global `device` string.
        model_device = next(context_model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
        prompt_len = inputs['input_ids'].shape[1]

        with torch.no_grad():
            # Greedy decoding. `temperature` is deliberately not passed: it has
            # no effect when do_sample=False and only triggers a warning.
            outputs = context_model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=context_tokenizer.eos_token_id,
            )

        # Keep only the newly generated tokens (drop the echoed prompt).
        response = context_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return _score_from_llm_answer(response)

    except Exception as e:
        # Best effort: any inference failure yields a neutral score.
        print(f"Phi-2 inference error: {e}")
        return 0.5
|
| 159 |
|
| 160 |
+
def use_zero_shot_classification(text: str, potential_symbol: str, symbol_info: Dict) -> float:
    """Fallback: estimate via zero-shot classification whether a word is a stock symbol.

    A window of up to 100 characters on each side of the first occurrence of
    ``potential_symbol`` is classified against two "stock symbol" labels and
    two "regular word" labels (one Persian and one English each).

    Args:
        text: Text containing the word.
        potential_symbol: The word that may be a stock symbol.
        symbol_info: Symbol metadata; the ``'company'`` key is read.

    Returns:
        The total probability assigned to the stock-symbol labels, in [0, 1].
        Returns 0.5 when the word is not found in ``text`` or when
        classification fails.
    """
    try:
        # Get context around the first occurrence of the symbol.
        symbol_pos = text.find(potential_symbol)
        if symbol_pos == -1:
            return 0.5

        start = max(0, symbol_pos - 100)
        end = min(len(text), symbol_pos + len(potential_symbol) + 100)
        context_text = text[start:end]

        stock_fa = f"نماد بورسی {symbol_info['company']}"
        regular_fa = f"کلمه عادی {potential_symbol}"
        stock_en = "stock market symbol"
        regular_en = "regular word"
        candidate_labels = [stock_fa, regular_fa, stock_en, regular_en]
        stock_labels = {stock_fa, stock_en}

        result = classifier(
            context_text,
            candidate_labels=candidate_labels,
            multi_label=False
        )

        # With multi_label=False the four scores are normalised to sum to 1,
        # so the top score alone is not the probability of its class. Add up
        # the mass of both stock labels to get the stock-symbol probability.
        stock_mass = sum(
            score
            for label, score in zip(result['labels'], result['scores'])
            if label in stock_labels
        )
        # Clamp to guard against floating-point drift just outside [0, 1].
        return min(1.0, max(0.0, float(stock_mass)))

    except Exception as e:
        # Best effort: any classifier failure yields a neutral score.
        print(f"Classification error: {e}")
        return 0.5
|
| 201 |
|
| 202 |
def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
|
| 203 |
"""
|
| 204 |
Check if a potential symbol is actually used as a stock symbol in context
|
|
|
|
| 205 |
"""
|
| 206 |
# Get surrounding context
|
| 207 |
symbol_pos = text.find(potential_symbol)
|
|
|
|
| 225 |
elif market_keyword_count == 0 and len(words_in_context) > 10:
|
| 226 |
return False, 0.05
|
| 227 |
|
| 228 |
+
# Use AI model for disambiguation
|
| 229 |
+
if use_llm_model:
|
| 230 |
+
ai_score = use_phi_for_disambiguation(context_window, potential_symbol, symbol_info)
|
| 231 |
+
else:
|
| 232 |
+
ai_score = use_zero_shot_classification(context_window, potential_symbol, symbol_info)
|
| 233 |
|
| 234 |
+
# Combine scores
|
| 235 |
+
final_score = (heuristic_score * 0.3 + ai_score * 0.7)
|
| 236 |
|
| 237 |
# Decision threshold
|
| 238 |
is_stock = final_score > 0.5
|
|
|
|
| 242 |
def find_stock_symbols_in_text(text: str) -> List[Dict]:
|
| 243 |
"""Find and validate stock symbols in text"""
|
| 244 |
found_symbols = []
|
| 245 |
+
processed_positions = set()
|
| 246 |
|
| 247 |
# Pattern to match Persian/Arabic words
|
| 248 |
pattern = r'\b[\u0600-\u06FF]+\b'
|
|
|
|
| 253 |
if word in SYMBOL_NAMES and match.start() not in processed_positions:
|
| 254 |
symbol_info = STOCK_SYMBOLS[word]
|
| 255 |
|
| 256 |
+
# Check context
|
| 257 |
is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
|
| 258 |
|
| 259 |
if is_stock:
|
|
|
|
| 302 |
}
|
| 303 |
|
| 304 |
def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
|
| 305 |
+
"""Merge entities, removing overlaps"""
|
| 306 |
all_entities = []
|
|
|
|
|
|
|
| 307 |
all_entities.extend(stock_entities)
|
| 308 |
|
|
|
|
| 309 |
for ner_ent in entities:
|
| 310 |
overlap = False
|
| 311 |
for stock_ent in stock_entities:
|
|
|
|
| 356 |
# Perform standard NER
|
| 357 |
entities = ner_pipeline(text)
|
| 358 |
|
| 359 |
+
# Find stock symbols
|
| 360 |
stock_entities = find_stock_symbols_in_text(text)
|
| 361 |
|
| 362 |
# Merge entities
|
|
|
|
| 423 |
.rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
|
| 424 |
"""
|
| 425 |
) as demo:
|
| 426 |
+
gr.Markdown(f"""
|
| 427 |
# 🏦 شناسایی هوشمند موجودیتها و نمادهای بورس ایران
|
| 428 |
## Persian Named Entity Recognition with Stock Symbol Detection
|
| 429 |
+
### Using {context_model_name.split('/')[-1]} for Context Understanding
|
| 430 |
|
| 431 |
<div class="rtl-text">
|
| 432 |
+
این برنامه متنهای فارسی را تحلیل کرده و موجودیتهای مختلف را شناسایی میکند.
|
| 433 |
</div>
|
| 434 |
""")
|
| 435 |
|
|
|
|
| 491 |
| 🔷 آبی آسمانی | **درصدها** | ۲۰ درصد، ۵٪ |
|
| 492 |
| 💚 سبز روشن | **نمادهای بورسی** | فولاد، وبملت، شپنا |
|
| 493 |
|
| 494 |
+
## ویژگی خاص: تشخیص هوشمند نمادهای بورسی
|
| 495 |
|
| 496 |
+
برنامه با استفاده از **هوش مصنوعی** تشخیص میدهد که آیا یک کلمه نماد بورسی است یا خیر.
|
|
|
|
|
|
|
|
|
|
| 497 |
|
| 498 |
**مثال:**
|
| 499 |
- «سهام **فولاد** در بورس معامله شد» ← فولاد = نماد بورسی ✅
|
|
|
|
| 509 |
## مدلهای استفاده شده:
|
| 510 |
|
| 511 |
- **ParsBERT NER**: شناسایی موجودیتهای عمومی
|
| 512 |
+
- **Microsoft Phi-2 / mDeBERTa**: تحلیل هوشمند متن برای تشخیص نمادهای بورسی
|
|
|
|
|
|
|
| 513 |
|
| 514 |
</div>
|
| 515 |
""")
|
|
|
|
| 537 |
print("Starting Persian NER + Stock Symbol Detection System...")
|
| 538 |
print(f"Using device: {device}")
|
| 539 |
print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
|
| 540 |
+
print("Models loaded:")
|
| 541 |
print(" - NER: HooshvareLab/bert-base-parsbert-ner-uncased")
|
| 542 |
+
print(f" - Context: {context_model_name}")
|
|
|
|
|
|
|
| 543 |
demo.launch(
|
| 544 |
share=False,
|
| 545 |
debug=True
|