Connexus committed
Commit 9998352 · verified · 1 parent: 549488a

Upload 2 files

Files changed (1)
  1. services/grammar_service.py +68 -27
services/grammar_service.py CHANGED
@@ -2,69 +2,110 @@ import os
 import nltk
 import torch
 from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
-from huggingface_hub import snapshot_download  # We will use this for a more robust download
+from huggingface_hub import snapshot_download
 
 class GrammarService:
+    """
+    Final, definitive service class.
+    - Models are downloaded from a private Hugging Face Hub.
+    - NLTK data is expected to be pre-installed by the Docker build process.
+    """
     _models = {}
-    _hf_repo_name = "Connexus/grammar-genie-models"
+
+    # --- CONFIGURATION ---
+    _hf_repo_name = "Connexus/grammar-genie-models"  # Your specific repo name
     _hf_token = os.environ.get("HUGGING_FACE_TOKEN")
 
     @classmethod
     def load_models(cls):
+        """
+        Loads all available models from the private Hugging Face repository into memory.
+        """
         print("="*50)
-        print(f"BULLETPROOF STARTUP: Loading models from '{cls._hf_repo_name}'...")
+        print(f"Final Version Startup: Loading models from '{cls._hf_repo_name}'...")
 
-        # --- NLTK Setup (already correct) ---
-        local_nltk_data_path = os.path.join(os.getcwd(), "nltk_data")
-        if not os.path.exists(local_nltk_data_path):
-            os.makedirs(local_nltk_data_path)
-        nltk.data.path.append(local_nltk_data_path)
+        # --- FINAL NLTK SETUP ---
+        # The Dockerfile is now responsible for the download.
+        # This code just verifies that the data is present where NLTK can find it.
         try:
            nltk.data.find('tokenizers/punkt')
-            print(f" > NLTK 'punkt' tokenizer found.")
+            print(" > NLTK 'punkt' tokenizer found successfully.")
         except LookupError:
-            print(f" > NLTK 'punkt' not found. Downloading to: {local_nltk_data_path}")
-            nltk.download('punkt', download_dir=local_nltk_data_path)
+            print(" > [FATAL ERROR] NLTK 'punkt' not found. The Docker build may have failed to download it.")
+            # Stop the application if NLTK data is missing, as it cannot function.
+            return
+        # --- END OF NLTK SETUP ---
 
         supported_languages = ["english", "french"]
 
         if not cls._hf_token:
-            print(" > [FATAL ERROR] HUGGING_FACE_TOKEN not set.")
+            print(" > [FATAL ERROR] HUGGING_FACE_TOKEN environment variable not set.")
             return
 
         for lang in supported_languages:
             model_subfolder = lang
             print(f" > Processing model for '{lang}'...")
             try:
-                # --- NEW, MORE ROBUST DOWNLOAD STEP ---
-                print(f" > Step 1: Downloading all files from subfolder '{model_subfolder}'...")
-                # snapshot_download is the most reliable way to download a whole folder
-                # It will use the token and save files to a local cache directory
+                print(f" - Step 1: Downloading files from subfolder '{model_subfolder}'...")
                 local_model_dir = snapshot_download(
                     repo_id=cls._hf_repo_name,
-                    allow_patterns=f"{model_subfolder}/*",  # Only download files from this subfolder
+                    allow_patterns=f"{model_subfolder}/*",
                     use_auth_token=cls._hf_token,
                     repo_type="model"
                 )
-                print(f" > Download complete. Files are cached locally.")
+                print(f" - Download complete.")
 
-                # --- The pipeline now loads from the local cache, not the internet ---
-                print(f" > Step 2: Loading pipeline from local cache...")
-                device_num = 0 if torch.cuda.is_available() else -1
-                # We point the pipeline to the specific subfolder inside the cache
+                print(f" - Step 2: Loading pipeline from local cache...")
                 final_model_path = os.path.join(local_model_dir, model_subfolder)
                 cls._models[lang] = pipeline(
                     "text2text-generation",
-                    model=final_model_path,  # Load from the specific local directory
-                    device=device_num
+                    model=final_model_path,
+                    device=-1  # Force CPU
                 )
                 print(f" > Model for '{lang}' loaded successfully into memory.")
 
             except Exception as e:
-                print(f" > [FATAL ERROR] during processing for '{lang}'.")
-                print(f" > Details: {e}")
+                print(f" > [FATAL ERROR] during processing for '{lang}'. Details: {e}")
+                return
 
         print("Model loading complete.")
         print("="*50)
 
-        # ... (correct_paragraph method is unchanged) ...
+    @classmethod
+    def correct_paragraph(cls, paragraph: str, language: str) -> str:
+        """
+        Corrects the grammar of a paragraph for a specified language.
+        """
+        if language not in cls._models:
+            return f"Error: Language '{language}' is not supported or its model failed to load."
+
+        corrector = cls._models[language]
+        sentences = nltk.sent_tokenize(paragraph)
+
+        if language == 'english':
+            prefix = "fix grammatical errors in the following text: "
+        elif language == 'french':
+            prefix = ""
+        else:
+            prefix = "correct grammar: "
+
+        corrected_sentences = []
+        for sentence in sentences:
+            input_text = f"{prefix}{sentence}"
+            try:
+                results = corrector(input_text, max_length=256, num_beams=5)
+                raw_output = results[0]['generated_text']
+
+                if prefix and raw_output.startswith(prefix):
+                    clean_sentence = raw_output.replace(prefix, "", 1).strip()
+                else:
+                    clean_sentence = raw_output.strip()
+
+                corrected_sentences.append(clean_sentence)
+            except Exception as e:
+                print(f" > [WARNING] Failed to process a sentence. Using original. Error: {e}")
+                corrected_sentences.append(sentence)
+
+        return " ".join(corrected_sentences)
+
+### Next Steps
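Note on the NLTK change in this commit: `load_models()` no longer downloads `punkt` itself and only verifies that it is present, so the image build must place the data on NLTK's default search path (typically via a `RUN python -m nltk.downloader ...` step). The Dockerfile itself is not part of this commit, so the following build-time bootstrap is only a sketch; the target directory is an assumption chosen because it is on NLTK's default search path.

```python
# Illustrative build-time bootstrap; the actual Dockerfile is not in this commit.
# /usr/local/share/nltk_data is one of NLTK's default search locations, so the
# runtime check nltk.data.find('tokenizers/punkt') in load_models() will succeed.
import nltk

nltk.download("punkt", download_dir="/usr/local/share/nltk_data")
```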
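For reference, a minimal sketch of the intended call sequence; the application entry point is not shown in this commit, so the import path and sample text below are assumptions based on the file location `services/grammar_service.py`.

```python
# Usage sketch; assumes HUGGING_FACE_TOKEN is set in the environment and that
# the private repo "Connexus/grammar-genie-models" has 'english' and 'french'
# subfolders, as load_models() expects.
from services.grammar_service import GrammarService

GrammarService.load_models()  # one-time startup: snapshot_download + pipeline load

sample = "She go to school every days."
print(GrammarService.correct_paragraph(sample, "english"))
```

One caveat: newer huggingface_hub releases deprecate the `use_auth_token` argument of `snapshot_download` in favor of `token`, so the pinned dependency version matters for the code above.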