| """ |
| Model loader for all Arabic NLP models. |
| Handles loading and inference with error handling. |
| """ |
|
|
| import os |
| import logging |
| from pathlib import Path |
| import json |
| import pickle |
| import difflib |
| import torch |
| import importlib.util |
| from transformers import ( |
| MBartForConditionalGeneration, |
| AutoTokenizer, |
| AutoConfig, |
| AutoModelForSeq2SeqLM, |
| AutoModelForCausalLM, |
| EncoderDecoderModel, |
| BertConfig, |
| EncoderDecoderConfig |
| ) |
|
|
| |
| |
# Spell-checker engine class. It is imported lazily (on first SpellingModel
# load) and cached here, so importing this module does not require ara_spell.
ArabicSpellChecker = None

logger = logging.getLogger(__name__)

# All model directories live under <parent of this file's directory>/models.
MODEL_BASE_PATH = Path(__file__).parent.parent.joinpath("models")
SUMMARIZATION_PATH = MODEL_BASE_PATH.joinpath("Summarization", "Model")
SPELLING_PATH = MODEL_BASE_PATH.joinpath("Spelling", "Model")
AUTOCOMPLETE_PATH = MODEL_BASE_PATH.joinpath("Autocomplete", "Model")
# NOTE(review): "Grammrar" looks like a typo, but it must match the directory
# name on disk; confirm the on-disk layout before renaming either side.
GRAMMAR_PATH = MODEL_BASE_PATH.joinpath("Grammrar", "Model")
PUNCTUATION_PATH = MODEL_BASE_PATH.joinpath("Punctuation", "Model")
|
|
|
|
class SummarizationModel:
    """Wrapper around the MBart-based Arabic summarization model.

    The model is loaded either from a local directory (which must contain
    ``config.json``, ``tokenizer.json`` and ``model.safetensors``) or from a
    remote source (an HTTP(S) URL or an ``org/name`` style repo id).
    When a generated summary looks unrelated to its input, ``summarize``
    falls back to an extractive summary built from the opening sentences.
    """

    # Characters that end a "sentence" for the extractive fallback. The Arabic
    # comma "،" is included, so comma-separated clauses also count as sentences.
    _SENTENCE_ENDINGS = ('.', '!', '?', '؟', '۔', '،')

    def __init__(self, model_path):
        """
        Initialize the model.

        Args:
            model_path: Local model directory, repo id (``org/name``) or URL.

        Raises:
            FileNotFoundError: If a local model directory or its required files are missing
            RuntimeError: If model loading fails
        """
        self.model_source = str(model_path)
        self._is_remote_source = self._looks_like_remote_source(self.model_source)
        self.model_path = None if self._is_remote_source else Path(model_path)
        self.model = None
        self.tokenizer = None
        self.device = None

        self._validate_path()
        self._load_model()

    @staticmethod
    def _looks_like_remote_source(source):
        """Return True if *source* looks like a remote repo id or an HTTP(S) URL.

        A string containing "/" that does not exist locally is treated as an
        ``org/name`` repo id, unless it is plainly a filesystem path (absolute,
        or starting with ".", "~", "/" or "\\"). Such paths stay local so a
        missing directory raises FileNotFoundError instead of triggering a
        confusing remote download attempt.
        """
        source = str(source)
        if source.startswith(("http://", "https://")):
            return True
        if "/" not in source or Path(source).exists():
            return False
        looks_local = source.startswith((".", "~", "/", "\\")) or Path(source).is_absolute()
        return not looks_local

    def _validate_path(self):
        """Check that a local model directory exists and holds the required files.

        Remote sources are not checked here; they are resolved at load time.

        Raises:
            FileNotFoundError: If the directory or any required file is missing.
        """
        if self._is_remote_source:
            logger.info(f"Using remote model source: {self.model_source}")
            return

        if not self.model_path.exists():
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        required_files = ['config.json', 'tokenizer.json', 'model.safetensors']
        missing_files = [name for name in required_files if not (self.model_path / name).exists()]
        if missing_files:
            raise FileNotFoundError(
                f"Missing required model files: {', '.join(missing_files)}"
            )

        logger.info(f"Model path validated: {self.model_path}")

    def _fix_generation_config(self):
        """Rewrite an explicit ``"early_stopping": null`` to ``true`` in the local config files.

        Patches ``config.json`` and ``generation_config.json`` on disk, in place.
        Only explicit nulls are rewritten. A missing key is left alone so the
        library default applies. Best effort: failures are logged, never raised.
        Does nothing for remote sources.
        """
        if self._is_remote_source:
            return

        for file_name in ("config.json", "generation_config.json"):
            self._replace_null_early_stopping(self.model_path / file_name)

    @staticmethod
    def _replace_null_early_stopping(config_path):
        """Set ``early_stopping`` to True in the JSON file *config_path* if it is explicitly null."""
        if not config_path.is_file():
            return
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            needs_fix = (
                isinstance(data, dict)
                and 'early_stopping' in data
                and data['early_stopping'] is None
            )
            if not needs_fix:
                return

            logger.info(f"Fixing early_stopping in {config_path.name} (was None)")
            data['early_stopping'] = True
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            logger.info(f"Fixed {config_path.name} - set early_stopping to True")
        except Exception as e:
            logger.warning(f"Could not fix generation config files: {str(e)}")

    def _from_pretrained_with_fallback(self, loader, what, **kwargs):
        """Call ``loader.from_pretrained`` on the model source, trying local files first.

        Args:
            loader: A class exposing ``from_pretrained`` (tokenizer or model class).
            what: Human-readable artifact name used in log messages.
            **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

        Returns:
            Whatever ``loader.from_pretrained`` returns.
        """
        try:
            return loader.from_pretrained(
                self.model_source,
                local_files_only=True,
                trust_remote_code=False,
                **kwargs,
            )
        except Exception as e:
            logger.warning(f"Failed to load {what} with local_files_only: {str(e)}")
            logger.info(f"Retrying {what} load without local_files_only...")
            return loader.from_pretrained(
                self.model_source,
                trust_remote_code=False,
                **kwargs,
            )

    def _load_mbart(self, dtype):
        """Load the MBart model in *dtype*, preferring a local config with early_stopping patched."""
        try:
            config = AutoConfig.from_pretrained(
                self.model_source,
                local_files_only=True,
                trust_remote_code=False,
            )
        except Exception as e:
            logger.warning(f"Failed to load config: {str(e)}")
        else:
            if hasattr(config, 'early_stopping') and config.early_stopping is None:
                logger.info("Fixing early_stopping in loaded config (was None)")
                config.early_stopping = True
            try:
                return MBartForConditionalGeneration.from_pretrained(
                    self.model_source,
                    config=config,
                    local_files_only=True,
                    trust_remote_code=False,
                    torch_dtype=dtype,
                )
            except Exception as e:
                logger.warning(f"Failed to load with config: {str(e)}")

        return self._from_pretrained_with_fallback(
            MBartForConditionalGeneration, "model", torch_dtype=dtype
        )

    def _load_model(self):
        """Pick a device, then load the tokenizer and the MBart model onto it.

        Each artifact is first loaded with ``local_files_only=True`` and, if
        that fails, retried with downloads allowed. Half precision is used only
        on CUDA. On CPU the weights are loaded in float32, because float16
        kernels are unsupported or very slow on CPU in many PyTorch builds.

        Raises:
            RuntimeError: If the tokenizer or model cannot be loaded.
        """
        try:
            use_cuda = torch.cuda.is_available()
            if use_cuda:
                self.device = torch.device('cuda')
                logger.info(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
            else:
                self.device = torch.device('cpu')
                logger.info("Using CPU device")
            dtype = torch.float16 if use_cuda else torch.float32

            logger.info("Loading tokenizer...")
            self.tokenizer = self._from_pretrained_with_fallback(AutoTokenizer, "tokenizer")
            logger.info("Tokenizer loaded successfully")

            logger.info("Loading model (this may take a while)...")
            self._fix_generation_config()
            self.model = self._load_mbart(dtype)

            self.model.to(self.device)
            self.model.eval()

            if use_cuda:
                torch.cuda.empty_cache()

            logger.info(f"Model loaded successfully on {self.device}")

        except Exception as e:
            logger.exception(f"Error loading model: {str(e)}")
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    def summarize(self, text, max_length=150, min_length=30, num_beams=1, **kwargs):
        """
        Summarize Arabic text.

        Args:
            text: Input Arabic text to summarize (truncated to 1024 tokens)
            max_length: Requested summary length in new tokens, clamped to [20, 160]
            min_length: Minimum new tokens, clamped to [0, effective max length]
            num_beams: Number of beams for beam search
            **kwargs: Additional ``generate`` parameters; they override the defaults

        Returns:
            str: The generated summary. If it looks unrelated to the input (see
            ``_needs_fallback``), an extractive summary of the opening sentences
            is returned instead. Blank input yields "".

        Raises:
            RuntimeError: If the model is not loaded or summarization fails
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded")

        # Blank input always ended in an empty extractive fallback; skip generation.
        if isinstance(text, str) and not text.strip():
            return ""

        try:
            inputs = self.tokenizer(
                text,
                max_length=1024,
                truncation=True,
                padding=True,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            max_new_tokens = max(20, min(max_length, 160))
            # Never ask for more minimum tokens than the maximum allows.
            min_new_tokens = min(max(0, min_length), max_new_tokens)

            generate_kwargs = dict(
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                num_beams=num_beams,
                do_sample=False,
                early_stopping=False,
                no_repeat_ngram_size=3,
                repetition_penalty=1.1,
            )
            generate_kwargs.update(kwargs)

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generate_kwargs)

            summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            if self._needs_fallback(text, summary):
                return self._extractive_fallback(text, max_words=max(12, min(max_length, 80)))
            return summary

        except Exception as e:
            logger.error(f"Error during summarization: {str(e)}")
            raise RuntimeError(f"Summarization failed: {str(e)}") from e

    def _needs_fallback(self, source_text, summary_text):
        """Return True when the generated summary appears too far from the source text.

        The summary is rejected if it is empty, if fewer than 35% of its words
        occur in the source, or if the difflib similarity of the first 500
        characters of source and summary is below 0.22.
        """
        if not summary_text:
            return True

        source_words = set(source_text.split())
        summary_words = summary_text.split()
        if not summary_words:
            return True

        overlap = sum(1 for word in summary_words if word in source_words)
        overlap_ratio = overlap / max(1, len(summary_words))
        ratio = difflib.SequenceMatcher(None, source_text[:500], summary_text[:500]).ratio()

        return overlap_ratio < 0.35 or ratio < 0.22

    def _extractive_fallback(self, source_text, max_words=40):
        """Build a conservative summary from the opening sentences of the source text.

        Whole sentences are taken in order until about *max_words* words are
        collected. The first sentence is always kept, even if it alone exceeds
        the budget. Returns "" for blank input.
        """
        text = source_text.strip()
        if not text:
            return text

        sentences = []
        current = []
        for chunk in text.replace('\n', ' ').split(' '):
            current.append(chunk)
            if chunk.endswith(self._SENTENCE_ENDINGS):
                sentence = ' '.join(current).strip()
                if sentence:
                    sentences.append(sentence)
                current = []

        if current:
            sentence = ' '.join(current).strip()
            if sentence:
                sentences.append(sentence)

        if not sentences:
            return ' '.join(text.split()[:max_words]).strip()

        chosen = []
        total_words = 0
        for sentence in sentences:
            sentence_words = sentence.split()
            if not sentence_words:
                continue
            if total_words + len(sentence_words) > max_words and chosen:
                break
            chosen.append(sentence)
            total_words += len(sentence_words)
            if total_words >= max_words:
                break

        if not chosen:
            return ' '.join(text.split()[:max_words]).strip()

        return ' '.join(chosen).strip()

    def __del__(self):
        """Drop the model reference and release cached CUDA memory (best effort)."""
        try:
            if getattr(self, 'model', None) is not None:
                del self.model
            if 'torch' in globals() and torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Finalizers may run during interpreter shutdown, when module
            # globals are partially torn down; cleanup must never raise.
            pass
|
|
|
|
|
|
class SpellingModel:
    """Wrapper around the Arabic spelling-correction model.

    A BERT2BERT ``EncoderDecoderModel`` is built from a fixed configuration,
    its weights are loaded from a ``*.pt`` checkpoint in the model directory,
    and correction is delegated to an ``ArabicSpellChecker`` engine.
    """

    def __init__(self, model_path=None):
        """
        Initialize the spelling model.

        Args:
            model_path: Path to the model directory (defaults to SPELLING_PATH)

        Raises:
            FileNotFoundError: If the directory or a ``*.pt`` checkpoint is missing
            RuntimeError: If loading the tokenizer, weights or engine fails
        """
        self.model_path = Path(model_path) if model_path else SPELLING_PATH
        self.model = None
        self.tokenizer = None
        self.engine = None
        self.device = None

        self._validate_path()
        self._load_model()

    def _find_checkpoint(self):
        """Return the first ``*.pt`` file in the model directory by name, or None.

        Sorting makes the choice deterministic when several checkpoints exist,
        since ``Path.glob`` yields entries in filesystem order.
        """
        checkpoints = sorted(self.model_path.glob("*.pt"))
        return checkpoints[0] if checkpoints else None

    def _validate_path(self):
        """Check that the model directory exists and contains a ``*.pt`` checkpoint.

        Raises:
            FileNotFoundError: If either is missing.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        if self._find_checkpoint() is None:
            raise FileNotFoundError(f"No .pt model file found in: {self.model_path}")

        logger.info(f"Spelling model path validated: {self.model_path}")

    @staticmethod
    def _import_spell_checker():
        """Import the ``ArabicSpellChecker`` class once and cache it in the module global.

        Tries ``ara_spell``, then ``src.ara_spell``, then a package-relative
        ``.ara_spell`` import.
        """
        global ArabicSpellChecker
        if ArabicSpellChecker is not None:
            return ArabicSpellChecker
        try:
            from ara_spell import ArabicSpellChecker as checker_cls
        except ImportError:
            try:
                from src.ara_spell import ArabicSpellChecker as checker_cls
            except ImportError:
                from .ara_spell import ArabicSpellChecker as checker_cls
        ArabicSpellChecker = checker_cls
        return checker_cls

    @staticmethod
    def _build_model():
        """Build the untrained BERT2BERT model whose shape the checkpoint must match.

        Encoder and decoder are both 12-layer, 768-hidden BERT configs with a
        64000-token vocabulary, and the decoder adds cross-attention.
        NOTE(review): 64000 presumably matches the arabertv02 tokenizer
        vocabulary; confirm before changing the tokenizer.
        """
        shared = dict(
            vocab_size=64000,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
        )
        config_encoder = BertConfig(**shared)
        config_decoder = BertConfig(**shared, is_decoder=True, add_cross_attention=True)
        config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
        return EncoderDecoderModel(config=config)

    def _load_model(self):
        """Load the tokenizer, build the model, load checkpoint weights and create the engine.

        Raises:
            RuntimeError: If any step fails (the underlying error is chained).
        """
        try:
            checker_cls = self._import_spell_checker()

            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            logger.info(f"Loading spelling model on {self.device}...")

            logger.info("Loading tokenizer for spelling model...")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv02")
            except Exception as e:
                logger.warning(f"Could not load aubmindlab/bert-base-arabertv02 tokenizer: {e}")
                raise RuntimeError(f"Tokenizer load failed: {e}") from e

            self.model = self._build_model()

            pt_file = self._find_checkpoint()
            if pt_file is None:
                raise FileNotFoundError(f"No .pt model file found in: {self.model_path}")
            logger.info(f"Loading weights from {pt_file}...")
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints. Recent PyTorch versions default to weights_only=True,
            # which may reject checkpoints that also store non-tensor objects.
            checkpoint = torch.load(pt_file, map_location=self.device)

            # Training checkpoints wrap the weights under "model_state_dict";
            # bare state dicts are used as-is.
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint

            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()

            # Generation needs these ids: decoding starts at CLS, ends at SEP.
            self.model.config.decoder_start_token_id = self.tokenizer.cls_token_id
            self.model.config.eos_token_id = self.tokenizer.sep_token_id
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
            self.model.config.vocab_size = self.model.config.encoder.vocab_size

            logger.info("Initializing ArabicSpellChecker engine...")
            self.engine = checker_cls(self.model, self.tokenizer, self.device, use_contextual=True)
            logger.info("ArabicSpellChecker engine initialized successfully")

            logger.info("Spelling model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading spelling model: {str(e)}")
            raise RuntimeError(f"Failed to load spelling model: {str(e)}") from e

    def correct(self, text):
        """
        Correct spelling in Arabic text.

        Reloads the model first if the engine is missing.

        Args:
            text: Input Arabic text

        Returns:
            str: Corrected text, or *text* unchanged if the engine raises

        Raises:
            RuntimeError: If a needed reload fails
        """
        if getattr(self, 'engine', None) is None:
            logger.warning("Engine not found, reloading spelling model...")
            self._load_model()

        try:
            return self.engine.correct(text)
        except Exception as e:
            logger.error(f"Error during spelling correction: {str(e)}")
            logger.warning("Returning original text due to error.")
            return text
|
|
|
|
class AutocompleteModel:
    """Wrapper around the GPT-2-based Arabic next-word autocomplete model.

    The model is disabled unless the ``LOAD_AUTOCOMPLETE`` environment
    variable is "true" (case-insensitive). The model code itself lives in
    ``hybrid_module.py`` inside the model directory and is imported dynamically.
    """

    def __init__(self, model_path=None, lazy=True):
        """
        Initialize the autocomplete model.

        Args:
            model_path: Path to the model directory (defaults to AUTOCOMPLETE_PATH)
            lazy: If True, models will be loaded on first use instead of initialization

        Raises:
            FileNotFoundError: If enabled and the model directory is missing
            RuntimeError: If enabled, not lazy, and loading fails
        """
        self.model_path = Path(model_path) if model_path else AUTOCOMPLETE_PATH

        self.bigram_model = None
        self.ngram_model = None
        self.unigrams = None
        self.bigrams = None
        self._hybrid = None
        self._gpt_tokenizer = None
        self._gpt_model = None
        self.lazy = lazy
        self.enabled = os.environ.get("LOAD_AUTOCOMPLETE", "false").lower() == "true"

        if not self.enabled:
            logger.info("Autocomplete model is disabled (LOAD_AUTOCOMPLETE=false)")
            return

        self._validate_path()
        if not self.lazy:
            self._load_model()

    def _validate_path(self):
        """Raise FileNotFoundError if the model directory does not exist."""
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        logger.info(f"Autocomplete model path validated: {self.model_path}")

    def _import_hybrid_module(self):
        """Import ``hybrid_module.py`` from the model directory as a fresh module.

        NOTE: this executes arbitrary Python from the model directory; only
        point ``model_path`` at trusted content.

        Raises:
            FileNotFoundError: If the file is missing.
            ImportError: If no import spec can be created for it.
        """
        hybrid_path = self.model_path / "hybrid_module.py"
        if not hybrid_path.exists():
            raise FileNotFoundError(f"hybrid_module.py not found at: {hybrid_path}")

        spec = importlib.util.spec_from_file_location("autocomplete_hybrid", str(hybrid_path))
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load spec for hybrid_module from: {hybrid_path}")

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def _load_model(self):
        """Load the GPT-2 autocomplete model on CPU via ``hybrid_module.load_gpt2``.

        The module, tokenizer and model are stored on the instance only after
        all of them loaded successfully.

        Raises:
            RuntimeError: If importing the module or loading GPT-2 fails.
        """
        try:
            logger.info("Loading autocomplete models...")

            hybrid = self._import_hybrid_module()
            logger.info("hybrid_module imported successfully")

            if not hasattr(hybrid, "load_gpt2"):
                raise RuntimeError("hybrid_module.load_gpt2 not found; cannot run GPT-2-only autocomplete")

            try:
                logger.info("Loading GPT-2 model for autocomplete (CPU, GPT-2 only mode)...")
                tokenizer, model = hybrid.load_gpt2()
                model.to(torch.device("cpu"))
                # Ensure inference mode (no dropout) whatever load_gpt2 returned.
                if hasattr(model, "eval"):
                    model.eval()
                logger.info("GPT-2 model loaded successfully for autocomplete (CPU)")
            except Exception as gpt_err:
                logger.error(f"Failed to load GPT-2 for autocomplete: {gpt_err}")
                raise

            self._hybrid = hybrid
            self._gpt_tokenizer = tokenizer
            self._gpt_model = model

            logger.info("Autocomplete GPT-2-only model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading autocomplete models: {str(e)}")
            raise RuntimeError(f"Failed to load autocomplete models: {str(e)}") from e

    @staticmethod
    def _select_suggestions(ranked, n):
        """Return up to *n* non-empty words from *ranked* ``(word, prob)`` pairs, in order.

        Empty words are skipped before counting, so they do not use up slots.
        A non-positive *n* yields an empty list.
        """
        if n <= 0:
            return []
        return [word for word, _prob in ranked if word][:n]

    def predict(self, text, n=5):
        """
        Predict next words for autocomplete.

        Args:
            text: Input Arabic text
            n: Number of suggestions to return

        Returns:
            list: Up to *n* suggested next words, most probable first. Empty
            when autocomplete is disabled, *n* <= 0, or lazy loading fails.

        Raises:
            RuntimeError: If the model is unavailable with lazy loading off,
            or prediction fails
        """
        if not self.enabled:
            return []

        if self._hybrid is None or self._gpt_model is None or self._gpt_tokenizer is None:
            if not self.lazy:
                raise RuntimeError("Model not loaded and lazy loading is off")
            try:
                self._load_model()
            except Exception as e:
                logger.error(f"Lazy loading of autocomplete failed: {str(e)}")
                # Stop retrying an expensive load on every keystroke.
                self.enabled = False
                return []

        try:
            if n <= 0:
                return []

            logger.info(
                f"[Autocomplete] predict called | mode=gpt2-only | text_tail='{text[-50:]}'"
            )

            if not hasattr(self._hybrid, "gpt2_next_token_probs"):
                raise RuntimeError("hybrid_module.gpt2_next_token_probs not found")

            # Over-fetch so enough non-empty candidates remain after filtering.
            prob_dict = self._hybrid.gpt2_next_token_probs(
                text,
                self._gpt_tokenizer,
                self._gpt_model,
                top_k=max(n * 2, 10),
            )

            preds = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)
            logger.info(f"[Autocomplete] raw GPT-2 predictions (top 5): {preds[:5]}")

            suggestions = self._select_suggestions(preds, n)
            logger.info(f"[Autocomplete] suggestions returned to API: {suggestions}")
            return suggestions
        except Exception as e:
            logger.error(f"Error during autocomplete prediction: {str(e)}")
            raise RuntimeError(f"Autocomplete prediction failed: {str(e)}") from e
|
|
|
|
class GrammarModel:
    """Wrapper around the causal-LM Arabic grammar-correction model (CPU only).

    The input text is sent as a single user chat message. The first non-empty
    line of the model's reply is taken as the correction.
    """

    # Canned "rewrite it" replies the model sometimes gives instead of a
    # correction; any output containing one of these is discarded.
    _GENERIC_PHRASES = (
        "أعد كتابتها", "أعد كتابته", "أعد كتابتها.", "أعد كتابته.",
        "اعيد كتابتها", "اعيد كتابته", "أعد كتابة", "اعيد كتابة",
        "أعد كتابتها فقط", "أعد كتابته فقط",
    )

    # Prompt budget (templated prompt, in tokens). Longer prompts are not truncated.
    _MAX_PROMPT_TOKENS = 256

    def __init__(self, model_path=None):
        """
        Initialize the grammar model.

        Args:
            model_path: Path to the model directory (defaults to GRAMMAR_PATH)

        Raises:
            FileNotFoundError: If the directory or required files are missing
            RuntimeError: If model loading fails
        """
        self.model_path = Path(model_path) if model_path else GRAMMAR_PATH
        self.model = None
        self.tokenizer = None
        self.device = None

        self._validate_path()
        self._load_model()

    def _validate_path(self):
        """Check that the model directory exists and contains the required files."""
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        required_files = ['config.json', 'tokenizer.json', 'model.safetensors']
        missing_files = [f for f in required_files if not (self.model_path / f).exists()]

        if missing_files:
            raise FileNotFoundError(f"Missing required files: {', '.join(missing_files)}")

        logger.info(f"Grammar model path validated: {self.model_path}")

    def _load_model(self):
        """Load the grammar model and tokenizer from local files onto the CPU.

        NOTE(review): ``trust_remote_code=True`` executes Python shipped with
        the model directory; keep that directory trusted.
        """
        try:
            self.device = torch.device('cpu')
            logger.info("Loading grammar model on CPU (GPU disabled by design)...")

            self.tokenizer = AutoTokenizer.from_pretrained(
                str(self.model_path),
                local_files_only=True
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                str(self.model_path),
                local_files_only=True,
                trust_remote_code=True,
                torch_dtype=torch.float32
            )

            self.model.to(self.device)
            self.model.eval()

            logger.info("Grammar model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading grammar model: {str(e)}")
            raise RuntimeError(f"Failed to load grammar model: {str(e)}") from e

    @classmethod
    def _clean_output(cls, corrected):
        """Return the first non-empty line of *corrected*, or None if the output is unusable.

        Output is rejected when, after stripping, it is shorter than 2
        characters or contains any of ``_GENERIC_PHRASES``.
        """
        corrected = corrected.strip()
        if len(corrected) < 2:
            return None
        if any(phrase in corrected for phrase in cls._GENERIC_PHRASES):
            return None

        lines = [line.strip() for line in corrected.split('\n') if line.strip()]
        if lines and len(lines[0]) >= 2:
            return lines[0]
        return None

    def _generate_correction(self, text):
        """Run the model on *text* and return the cleaned correction, or None if unusable."""
        messages = [{"role": "user", "content": text}]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )
        input_length = inputs['input_ids'].shape[1]
        if input_length > self._MAX_PROMPT_TOKENS:
            # Right-truncation would drop the assistant header appended by
            # add_generation_prompt, and the model would continue the user
            # text instead of answering. Leave such inputs unchanged.
            logger.warning(
                f"[Grammar] Prompt is {input_length} tokens "
                f"(limit {self._MAX_PROMPT_TOKENS}) — returning original text"
            )
            return None

        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=64,
                do_sample=False,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = outputs[0][input_length:]
        decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return self._clean_output(decoded)

    def correct(self, text, timeout=30):
        """
        Correct grammar in Arabic text, giving up after *timeout* seconds.

        Generation runs in a daemon worker thread, so a stuck ``generate`` call
        cannot block interpreter shutdown. Threads cannot be cancelled: after a
        timeout the worker keeps running in the background and its result is
        discarded.

        Args:
            text: Input Arabic text
            timeout: Seconds to wait for generation before giving up

        Returns:
            str: Grammar-corrected text, or *text* unchanged when the input is
            blank or too long, generation fails or times out, or the model's
            output looks unusable.

        Raises:
            RuntimeError: If the model or tokenizer is not loaded
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded")

        if isinstance(text, str) and not text.strip():
            return text

        import threading

        result = [text]
        error = [None]

        def _generate():
            try:
                corrected = self._generate_correction(text)
                if corrected is not None:
                    result[0] = corrected
            except Exception as e:
                error[0] = e

        worker = threading.Thread(target=_generate, daemon=True)
        worker.start()
        worker.join(timeout=timeout)

        if worker.is_alive():
            logger.warning(f"[Grammar] Timed out after {timeout}s — returning original text")
            return text

        if error[0] is not None:
            logger.error(f"Error during grammar correction: {str(error[0])}")
            logger.warning("Returning original text due to grammar error.")
            return text

        return result[0]
|
|
|
|
class PunctuationModel:
    """Wrapper around the seq2seq Arabic punctuation-restoration model."""

    def __init__(self, model_path=None):
        """
        Initialize the punctuation model.

        Args:
            model_path: Path to the model directory (defaults to PUNCTUATION_PATH)

        Raises:
            FileNotFoundError: If the directory or required files are missing
            RuntimeError: If model loading fails
        """
        self.model_path = Path(model_path) if model_path else PUNCTUATION_PATH
        self.model = None
        self.tokenizer = None
        self.device = None

        self._validate_path()
        self._load_model()

    def _validate_path(self):
        """Check that the model directory exists and contains the required files."""
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        required_files = ['config.json', 'tokenizer.json', 'model.safetensors']
        missing_files = [f for f in required_files if not (self.model_path / f).exists()]

        if missing_files:
            raise FileNotFoundError(f"Missing required files: {', '.join(missing_files)}")

        logger.info(f"Punctuation model path validated: {self.model_path}")

    def _load_model(self):
        """Load the punctuation model and tokenizer from local files (CUDA if available)."""
        try:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            logger.info(f"Loading punctuation model on {self.device}...")

            self.tokenizer = AutoTokenizer.from_pretrained(
                str(self.model_path),
                local_files_only=True
            )

            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                str(self.model_path),
                local_files_only=True
            )

            self.model.to(self.device)
            self.model.eval()

            logger.info("Punctuation model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading punctuation model: {str(e)}")
            raise RuntimeError(f"Failed to load punctuation model: {str(e)}") from e

    @staticmethod
    def _special_token_kwargs(tokenizer):
        """Build ``generate`` special-token overrides from *tokenizer*, skipping undefined ids.

        Decoding starts at the tokenizer's CLS id and stops at its SEP id. Ids
        the tokenizer does not define (None) are left out, so the model's own
        generation config applies instead of an explicit None.
        """
        candidates = {
            "decoder_start_token_id": getattr(tokenizer, "cls_token_id", None),
            "eos_token_id": getattr(tokenizer, "sep_token_id", None),
            "pad_token_id": getattr(tokenizer, "pad_token_id", None),
        }
        return {key: value for key, value in candidates.items() if value is not None}

    def add_punctuation(self, text):
        """
        Add punctuation to Arabic text.

        NOTE(review): the input is truncated to 512 tokens and the output is
        capped at 512 tokens, so very long texts lose their tail.

        Args:
            text: Input Arabic text without punctuation

        Returns:
            str: Text with punctuation added, or *text* unchanged if it is
            blank or generation fails

        Raises:
            RuntimeError: If the model or tokenizer is not loaded
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded")

        if isinstance(text, str) and not text.strip():
            return text

        try:
            inputs = self.tokenizer(
                text,
                max_length=512,
                truncation=True,
                padding=True,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=512,
                    **self._special_token_kwargs(self.tokenizer),
                )

            punctuated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return punctuated.strip()

        except Exception as e:
            logger.error(f"Error during punctuation: {str(e)}")
            logger.warning("Returning original text due to punctuation error.")
            return text
|
|