Spaces:
Sleeping
Sleeping
import html
import re
from collections import Counter
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)
# Choose the compute target once: half precision on GPU, full precision on CPU.
_has_cuda = torch.cuda.is_available()
if _has_cuda:
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# ParsBERT token-classification model used for general Persian NER.
print("Loading NER model...")
ner_model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_model.to(device)

# Pipeline device index: 0 selects the first GPU, -1 selects the CPU.
# aggregation_strategy="simple" groups sub-word tokens into whole entity spans.
_ner_device_index = 0 if device == "cuda" else -1
ner_pipeline = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    device=_ner_device_index,
    aggregation_strategy="simple",
)
# Load Gemma model for stock symbol detection
print("Loading Gemma-2-9b-it model for context understanding...")
gemma_model_name = "google/gemma-2-9b-it"
gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model_name)

# On GPU, accelerate places the weights itself via device_map="auto".
_gemma_uses_device_map = torch.cuda.is_available()
gemma_model = AutoModelForCausalLM.from_pretrained(
    gemma_model_name,
    torch_dtype=dtype,
    device_map="auto" if _gemma_uses_device_map else None,
    low_cpu_mem_usage=True,
)
if device == "cpu":
    gemma_model = gemma_model.to(device)

# Create text generation pipeline.
# A model dispatched with device_map must NOT also receive a `device` argument:
# transformers refuses to move an accelerate-dispatched model and raises a
# ValueError. device=None (the pipeline default) leaves placement to the
# device map; on CPU the original device=-1 is kept.
gemma_pipeline = pipeline(
    "text-generation",
    model=gemma_model,
    tokenizer=gemma_tokenizer,
    device=None if _gemma_uses_device_map else -1,
    max_new_tokens=50,
    # Greedy decoding for deterministic labels. No temperature is set: it is
    # ignored when do_sample=False and only triggers a generation warning.
    do_sample=False,
    pad_token_id=gemma_tokenizer.eos_token_id,
)
# Load stock symbols from CSV
def load_stock_symbols(csv_path="symbols.csv"):
    """Load Iranian stock market symbols from a CSV file.

    The CSV must have ``symbol``, ``company_name``, ``bazaar`` and
    ``bazaar_group`` columns. Symbols are whitespace-stripped and rows with a
    blank/missing symbol are skipped. Missing detail cells become empty strings
    instead of pandas NaN, which would otherwise render as "nan".

    Args:
        csv_path: Path to the CSV file.

    Returns:
        Dict mapping each symbol to ``{'company', 'bazaar', 'bazaar_group'}``.
        If reading fails for any reason (missing file, missing column, ...),
        the error is printed and a small built-in demo table is returned.
    """
    try:
        # dtype=str keeps symbols such as numeric-looking codes as text.
        df = pd.read_csv(csv_path, encoding='utf-8', dtype=str).fillna('')
        symbols_dict = {}
        for symbol, company, bazaar, group in zip(
            df['symbol'], df['company_name'], df['bazaar'], df['bazaar_group']
        ):
            symbol = symbol.strip()
            if not symbol:
                # A blank symbol can never be matched in text; skip the row.
                continue
            symbols_dict[symbol] = {
                'company': company.strip(),
                'bazaar': bazaar.strip(),
                'bazaar_group': group.strip(),
            }
        return symbols_dict
    except Exception as e:
        print(f"Error loading symbols CSV: {e}")
        # Provide default symbols for demo
        return {
            'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'},
            'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'},
            'فارس': {'company': 'پتروشیمی فارس', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'محصولات شیمیایی'},
            'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآوردههای نفتی'},
            'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'},
            'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'},
            'وتوسکا': {'company': 'سرمایه گذاری توسعه توکا', 'bazaar': 'بورس', 'bazaar_group': 'سرمایه گذاریها'},
            'پی پاد': {'company': 'پرداخت الکترونیک پاسارگاد', 'bazaar': 'بورس', 'bazaar_group': 'رایانه و فعالیت های وابسته'},
        }
# Load symbols
STOCK_SYMBOLS = load_stock_symbols()
# The symbol names alone, as a set for fast membership tests.
SYMBOL_NAMES = set(STOCK_SYMBOLS)
# Market context keywords: Persian finance terms counted around a candidate
# symbol as evidence that it is used as a stock symbol.
MARKET_KEYWORDS = set(
    [
        # instruments and venues
        'سهام', 'سهم', 'بورس', 'فرابورس', 'نماد', 'شاخص', 'پرتفوی',
        # trading activity
        'معامله', 'معاملات', 'خرید', 'فروش', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان',
        # price, value and returns
        'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایهگذاری', 'بازدهی', 'سود', 'زیان', 'رشد', 'افت',
        # currency, magnitudes and percentages
        'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد',
    ]
)
def use_gemma_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
    """
    Ask Gemma-2-9b-it whether ``potential_symbol`` is used as a stock symbol in ``text``.

    Args:
        text: Passage to judge (callers pass a window around the candidate).
        potential_symbol: Candidate symbol as it appears in the text.
        symbol_info: Symbol metadata; ``company`` and ``bazaar_group`` are read.

    Returns:
        A confidence in [0, 1]:
        - 0.9 if the first label the model states is "STOCK".
        - 0.1 if it is "WORD".
        - 0.7 / 0.3 for an answer with no label that does / does not mention
          نماد, سهام or بورس.
        - 0.5 if anything raises (including model errors).
    """
    try:
        # Gemma chat-turn format. The string deliberately does NOT start with
        # "<bos>": the text-generation pipeline tokenizes it with the
        # tokenizer's default special tokens, which already prepend BOS, and a
        # doubled BOS degrades Gemma's output.
        prompt = f"""<start_of_turn>user
You are a Persian financial text analyzer. Determine if the word "{potential_symbol}" in the following Persian text is used as a stock market symbol or as a regular word.
Context information:
- The word "{potential_symbol}" could be a stock symbol for "{symbol_info['company']}" (industry: {symbol_info['bazaar_group']})
- Stock symbols usually appear with financial terms like: سهام، بورس، معامله، قیمت، خرید، فروش
Text to analyze:
"{text}"
Answer ONLY with one of these:
1. "STOCK" if it's used as a stock market symbol
2. "WORD" if it's used as a regular word
Reasoning: Consider the surrounding context. If the text discusses trading, prices, or stock market activities, it's likely a stock symbol. If it discusses the general meaning (like فولاد meaning steel in manufacturing context), it's a regular word.
Answer:<end_of_turn>
<start_of_turn>model
"""
        # Greedy decoding; a temperature would be ignored with do_sample=False.
        response = gemma_pipeline(
            prompt,
            max_new_tokens=20,
            do_sample=False,
            return_full_text=False,
        )
        answer = response[0]['generated_text'].strip().upper() if response else ""
        # Use whichever label appears FIRST, so e.g. "WORD (NOT STOCK)" is read
        # as WORD rather than matching the later "STOCK".
        label = re.search(r"STOCK|WORD", answer)
        if label:
            return 0.9 if label.group() == "STOCK" else 0.1
        # No explicit label: fall back to Persian market terms in the answer.
        if any(keyword in answer.lower() for keyword in ['نماد', 'سهام', 'بورس']):
            return 0.7
        return 0.3
    except Exception as e:
        print(f"Gemma inference error: {e}")
        return 0.5  # Neutral confidence on error
# Punctuation (Latin and Persian) that may be glued to a token; stripped before
# comparing tokens with MARKET_KEYWORDS.
_TOKEN_PUNCTUATION = ".,;:!?()[]{}\"'«»،؛؟…"
# Drop zero-width non-joiners and unify Arabic yeh/kaf with the Persian forms so
# keyword matching does not depend on how the text was typed.
_TOKEN_CHAR_MAP = str.maketrans({"\u200c": None, "\u064a": "\u06cc", "\u0643": "\u06a9"})


def _normalize_token(token: str) -> str:
    """Strip surrounding punctuation, remove ZWNJ and unify yeh/kaf in ``token``."""
    return token.strip(_TOKEN_PUNCTUATION).translate(_TOKEN_CHAR_MAP)


def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict,
                               position: Optional[int] = None) -> Tuple[bool, float]:
    """
    Decide whether one occurrence of ``potential_symbol`` in ``text`` is a stock symbol.

    A window of up to 100 characters on each side of the occurrence is scored
    by counting MARKET_KEYWORDS in it, after normalizing punctuation, ZWNJ and
    yeh/kaf on both sides:
    - 5 or more keywords: ``(True, 0.95)``.
    - No keywords in a window of more than 10 words: ``(False, 0.05)``.
    - Otherwise Gemma is consulted and the score is ``0.2 * heuristic + 0.8 * gemma``,
      accepted when above 0.5.

    Args:
        text: Full input text.
        potential_symbol: Symbol string to examine.
        symbol_info: Symbol metadata forwarded to the Gemma prompt.
        position: Start offset of the occurrence to judge. When omitted, the
            first occurrence found by ``str.find`` is used.

    Returns:
        ``(is_stock, score)``. ``(False, 0.0)`` if the symbol is not found
        (at ``position`` when given, otherwise anywhere in ``text``).
    """
    if position is None:
        symbol_pos = text.find(potential_symbol)
    elif position >= 0 and text.startswith(potential_symbol, position):
        symbol_pos = position
    else:
        symbol_pos = -1
    if symbol_pos == -1:
        return False, 0.0

    start_context = max(0, symbol_pos - 100)
    end_context = min(len(text), symbol_pos + len(potential_symbol) + 100)
    context_window = text[start_context:end_context]

    # Count market keywords in the window.
    keywords = {_normalize_token(keyword) for keyword in MARKET_KEYWORDS}
    words_in_context = context_window.split()
    market_keyword_count = sum(1 for word in words_in_context if _normalize_token(word) in keywords)
    heuristic_score = min(market_keyword_count * 0.2, 1.0)

    # Strong heuristic signals short-circuit the (expensive) Gemma call.
    if market_keyword_count >= 5:
        return True, 0.95
    if market_keyword_count == 0 and len(words_in_context) > 10:
        return False, 0.05

    gemma_score = use_gemma_for_disambiguation(context_window, potential_symbol, symbol_info)
    # Gemma gets most of the weight as it reads the whole sentence.
    final_score = heuristic_score * 0.2 + gemma_score * 0.8
    return final_score > 0.5, final_score


def find_stock_symbols_in_text(text: str) -> List[Dict]:
    """
    Find every occurrence of a known symbol in ``text`` that is used as a stock symbol.

    Symbols are matched as whole words (not preceded or followed by a word
    character), longest symbol first, so multi-word symbols such as 'پی پاد'
    are found too. Each occurrence is judged on its own surrounding context.

    Returns:
        One dict per accepted occurrence, in text order, with keys ``word``,
        ``start``, ``end``, ``entity_group`` ('STOCK'), ``score``, ``company``,
        ``bazaar`` and ``bazaar_group``.
    """
    if not text:
        return []
    # Longest first so a longer symbol wins over a shorter one at the same
    # offset; the secondary key keeps the pattern deterministic.
    names = sorted(
        (name for name in SYMBOL_NAMES if isinstance(name, str) and name),
        key=lambda name: (-len(name), name),
    )
    if not names:
        return []
    pattern = re.compile(r"(?<!\w)(?:" + "|".join(map(re.escape, names)) + r")(?!\w)")

    found_symbols = []
    for match in pattern.finditer(text):
        word = match.group()
        symbol_info = STOCK_SYMBOLS[word]
        # Judge THIS occurrence, not the first occurrence of the same symbol.
        is_stock, confidence = check_stock_symbol_context(
            text, word, symbol_info, position=match.start()
        )
        if is_stock:
            found_symbols.append({
                'word': word,
                'start': match.start(),
                'end': match.end(),
                'entity_group': 'STOCK',
                'score': confidence,
                'company': symbol_info['company'],
                'bazaar': symbol_info['bazaar'],
                'bazaar_group': symbol_info['bazaar_group'],
            })
    return found_symbols
# Label colors and names
# (entity type, begin-tag colour, inside-tag colour); inside shades are lighter.
_NER_TAG_SHADES = (
    ("PER", "#FF6B6B", "#FFB3B3"),
    ("ORG", "#4ECDC4", "#A7E9E4"),
    ("LOC", "#95E1D3", "#C7F0E8"),
    ("DAT", "#FFA07A", "#FFDAB9"),
    ("TIM", "#DDA0DD", "#E6D0E6"),
    ("MON", "#FFD700", "#FFEB99"),
    ("PCT", "#87CEEB", "#B3DFEF"),
)
# Maps "B-<TYPE>" / "I-<TYPE>" (and "STOCK") to a highlight colour.
label_colors = {
    f"{prefix}-{tag}": shade
    for tag, begin_shade, inside_shade in _NER_TAG_SHADES
    for prefix, shade in (("B", begin_shade), ("I", inside_shade))
}
label_colors["STOCK"] = "#00FA9A"

# Bilingual (Persian / English) display name per entity group.
label_names = dict(
    PER="شخص (Person)",
    ORG="سازمان (Organization)",
    LOC="مکان (Location)",
    DAT="تاریخ (Date)",
    TIM="زمان (Time)",
    MON="پول (Money)",
    PCT="درصد (Percent)",
    STOCK="نماد بورسی (Stock Symbol)",
)
def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
    """Combine NER and stock-symbol entities into one new list.

    All stock entities come first, in their given order. Each NER entity is
    appended after them only if its ``[start, end)`` span does not intersect
    any stock span; stock symbols therefore take precedence on overlap.
    """
    def _hits_a_stock(candidate):
        # Half-open intervals intersect iff each starts before the other ends.
        return any(
            candidate['start'] < stock['end'] and candidate['end'] > stock['start']
            for stock in stock_entities
        )

    merged = list(stock_entities)
    merged.extend(ent for ent in entities if not _hits_a_stock(ent))
    return merged
def highlight_entities(text, all_entities):
    """Render ``text`` as HTML with each entity wrapped in a coloured ``<span>``.

    All text, entity words and tooltips are HTML-escaped, so user input
    containing ``<``, ``&`` or quotes cannot inject markup or break the page.
    Entities are emitted left to right. An entity that starts inside an
    already-emitted span, or has an empty span, is skipped so character
    offsets stay valid.

    Args:
        text: The original input text; entity offsets index into it.
        all_entities: Dicts with ``start``, ``end``, ``entity_group`` and
            ``score``. STOCK entities may also carry ``company`` and ``bazaar``,
            which are shown in the tooltip.

    Returns:
        An HTML fragment string (the escaped text when there are no entities).
    """
    if not all_entities:
        return html.escape(text, quote=False)

    pieces = []
    cursor = 0  # end offset of the last emitted entity
    # Sort by start; on ties the longer span comes first.
    for entity in sorted(all_entities, key=lambda e: (e['start'], -e['end'])):
        start = entity['start']
        end = entity['end']
        if start < cursor or end <= start:
            continue
        label = entity['entity_group']
        score = entity['score']
        color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC")
        tooltip_info = f"{label} (confidence: {score:.2f})"
        if label == 'STOCK':
            company = entity.get('company', '')
            bazaar = entity.get('bazaar', '')
            if company:
                tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})"
        word = html.escape(text[start:end], quote=False)
        title = html.escape(tooltip_info, quote=True)
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(
            f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; '
            f'margin: 0 2px; display: inline-block;" title="{title}">{word} '
            f'<sup style="font-size: 0.7em; font-weight: bold;">[{html.escape(str(label))}]</sup></span>'
        )
        cursor = end
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
def perform_ner(text):
    """Run ParsBERT NER plus stock-symbol detection and render the results.

    Args:
        text: Raw text from the UI textbox; ``None`` is treated as empty.

    Returns:
        A tuple ``(html_output, markdown_output)``:
        - On success: the highlighted text in an RTL HTML block, and a Markdown
          entity table followed by summary statistics.
        - On blank input: a red prompt paragraph and ``""``.
        - On any exception during analysis: a red (HTML-escaped) error
          paragraph and the error message.
    """
    if not text or not text.strip():
        return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
    try:
        entities = ner_pipeline(text)
        stock_entities = find_stock_symbols_in_text(text)
        # Stock symbols take precedence over any NER span they overlap.
        all_entities = merge_overlapping_entities(entities, stock_entities)

        highlighted_html = f"""
<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2.5;
padding: 20px; border: 1px solid #ddd; border-radius: 5px;
background-color: #f9f9f9; font-family: Tahoma, Arial;'>
{highlight_entities(text, all_entities)}
</div>
"""
        if all_entities:
            table_lines = [
                "### موجودیتهای شناسایی شده (Detected Entities):\n\n",
                "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n",
                "|:------------|:-----------|:------------------|:---------------------|\n",
            ]
            all_entities.sort(key=lambda x: x['start'])
            for ent in all_entities:
                label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
                details = ""
                if ent['entity_group'] == 'STOCK':
                    details = f"{ent.get('company', '')}<br>{ent.get('bazaar', '')}<br>{ent.get('bazaar_group', '')}"
                # Show the exact span the user typed. The NER pipeline's 'word'
                # is re-decoded from tokens and may differ from the input.
                surface = text[ent['start']:ent['end']] or str(ent['word'])
                # An unescaped '|' would split the Markdown table cell.
                surface_cell = surface.replace("|", "\\|")
                details_cell = str(details).replace("|", "\\|")
                table_lines.append(
                    f"| **{surface_cell}** | {label_fa} | {details_cell} | {ent['score']:.2%} |\n"
                )
            entity_info = "".join(table_lines)
        else:
            entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"

        group_counts = Counter(e['entity_group'] for e in all_entities)
        stats = (
            "\n\n### آمار (Statistics):\n"
            f"- تعداد کل موجودیتها: {len(all_entities)}\n"
            f"- نمادهای بورسی: {group_counts['STOCK']}\n"
            f"- اشخاص: {group_counts['PER']}\n"
            f"- سازمانها: {group_counts['ORG']}\n"
            f"- مکانها: {group_counts['LOC']}\n"
        )
        return highlighted_html, entity_info + stats
    except Exception as e:
        return f"<p style='color: red;'>خطا (Error): {html.escape(str(e))}</p>", str(e)
# Examples
# One sentence per example; Gradio expects each example as a list of input values.
_EXAMPLE_SENTENCES = (
    "علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد.",
    "سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند.",
    "صنعت فولاد در اصفهان یکی از مهمترین صنایع کشور است.",
    "قیمت سهام شپنا در معاملات امروز ۵ درصد رشد داشت و به ۱۲۰۰۰ ریال رسید.",
    "بانک ملت اعلام کرد که سود سهام وبملت را در تاریخ ۱۵ خرداد ۱۴۰۳ پرداخت خواهد کرد.",
    "شرکت فولاد مبارکه با نماد فولاد در بورس تهران فعال است و محصولات فولادی تولید میکند.",
    "من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم.",
)
examples = [[sentence] for sentence in _EXAMPLE_SENTENCES]
# Gradio interface
# Right-to-left styling applied to the Persian widgets via elem_classes.
_RTL_CSS = """
.rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
"""

_HEADER_MARKDOWN = """
# 🏦 شناسایی هوشمند موجودیتها و نمادهای بورس ایران
## Persian Named Entity Recognition with Stock Symbol Detection
### Powered by Google Gemma-2-9B-IT
<div class="rtl-text">
این برنامه با استفاده از مدل قدرتمند Gemma-2-9B، متنهای فارسی را تحلیل کرده و موجودیتهای مختلف را شناسایی میکند.
</div>
"""

_USER_GUIDE_MARKDOWN = """
<div class="rtl-text">
## چگونه از این برنامه استفاده کنیم؟
1. **متن فارسی خود را در کادر بالا وارد کنید**
2. **دکمه «تحلیل متن» را بزنید**
3. **نتایج را در دو بخش مشاهده کنید:**
- متن با موجودیتهای رنگی شده
- جدول کامل موجودیتها با جزئیات
## انواع موجودیتهایی که شناسایی میشوند:
| رنگ | نوع | مثال |
|:---:|:----|:-----|
| 🔴 قرمز | **اشخاص** | علی احمدی، مریم رضایی |
| 🔵 آبی | **سازمانها** | شرکت ملی نفت، بانک ملت |
| 🟢 سبز | **مکانها** | تهران، اصفهان، ایران |
| 🟠 نارنجی | **تاریخها** | ۱۵ خرداد ۱۴۰۳ |
| 🟣 بنفش | **زمانها** | ساعت ۱۰ صبح |
| 🟡 زرد | **مبالغ پولی** | ۱۰۰۰ ریال، ۵ میلیارد تومان |
| 🔷 آبی آسمانی | **درصدها** | ۲۰ درصد، ۵٪ |
| 💚 سبز روشن | **نمادهای بورسی** | فولاد، وبملت، شپنا |
## ویژگی خاص: تشخیص هوشمند با Gemma-2-9B
این برنامه از **مدل Gemma-2-9B** گوگل استفاده میکند که:
- درک عمیق از زبان فارسی دارد
- متن را به صورت کامل تحلیل میکند
- بین نماد بورسی و کلمه عادی تمایز قائل میشود
**مثال:**
- «سهام **فولاد** در بورس معامله شد» ← فولاد = نماد بورسی ✅
- «صنعت **فولاد** در کشور مهم است» ← فولاد = کلمه عادی ❌
## نحوه تفسیر نتایج:
- **رنگها**: نوع موجودیت را نشان میدهند
- **برچسبها**: نوع موجودیت به صورت مختصر
- **درصد اطمینان**: میزان اطمینان سیستم (۰-۱۰۰٪)
- **جزئیات نمادها**: نام شرکت، بازار و گروه صنعت
## مدلهای استفاده شده:
- **ParsBERT NER**: شناسایی موجودیتهای عمومی
- **Google Gemma-2-9B-IT**: تحلیل هوشمند متن و تشخیص نمادهای بورسی
⚠️ **توجه**: مدل Gemma به دلیل حجم بالا (9 میلیارد پارامتر) ممکن است کمی کندتر باشد
</div>
"""

with gr.Blocks(
    title="Persian NER + Stock Symbols | شناسایی موجودیتها و نمادهای بورسی",
    theme=gr.themes.Soft(),
    css=_RTL_CSS,
) as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        # Left column: input box and action buttons.
        with gr.Column(scale=6):
            input_text = gr.Textbox(
                label="متن فارسی را وارد کنید (Enter Persian Text)",
                placeholder="مثال: سهام فولاد در بورس تهران معامله میشود...",
                lines=6,
                rtl=True,
                elem_classes=["rtl-text"],
            )
            with gr.Row():
                submit_btn = gr.Button("🔍 تحلیل متن", variant="primary", scale=2)
                clear_btn = gr.Button("🗑️ پاک کردن", scale=1)
        # Right column: the highlighted rendering of the analysed text.
        with gr.Column(scale=6):
            output_html = gr.HTML(label="نتیجه تحلیل (Analysis Result)", elem_classes=["rtl-text"])

    with gr.Row():
        output_entities = gr.Markdown(label="جدول موجودیتها (Entity Table)", elem_classes=["rtl-text"])

    gr.Examples(
        examples=examples,
        inputs=input_text,
        label="نمونههای آماده (Ready Examples)",
        examples_per_page=4,
    )

    # User guide
    with gr.Accordion("📖 راهنمای استفاده (User Guide)", open=True):
        gr.Markdown(_USER_GUIDE_MARKDOWN)

    # Event handlers, registered in the original order:
    # 1. The analyse button runs perform_ner.
    # 2. The clear button blanks the input and both outputs.
    # 3. The textbox's submit event runs perform_ner.
    _analysis_outputs = [output_html, output_entities]
    submit_btn.click(fn=perform_ner, inputs=input_text, outputs=_analysis_outputs)
    clear_btn.click(lambda: ("", "", ""), outputs=[input_text, output_html, output_entities])
    input_text.submit(fn=perform_ner, inputs=input_text, outputs=list(_analysis_outputs))
# Launch
if __name__ == "__main__":
    # Startup banner, printed line by line before the server starts.
    _startup_banner = (
        "Starting Persian NER + Stock Symbol Detection System...",
        f"Using device: {device}",
        f"Loaded {len(STOCK_SYMBOLS)} stock symbols",
        "Models:",
        " - NER: HooshvareLab/bert-base-parsbert-ner-uncased",
        " - Context Understanding: Google Gemma-2-9B-IT",
        "\nNote: Gemma-2-9B is a large model. First run may take time to download.",
        "For better performance, consider using GPU if available.",
    )
    for _banner_line in _startup_banner:
        print(_banner_line)
    demo.launch(share=False, debug=True)