PearlIsa committed on
Commit
1d35a0c
·
verified ·
1 Parent(s): ec1b642

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -47
app.py CHANGED
@@ -82,47 +82,23 @@ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
82
  logging.basicConfig(level=logging.INFO)
83
  logger = logging.getLogger(__name__)
84
 
85
- # Define the path for the zipped model
86
- model_zip_path = "./checkpoint-500.zip"
87
- extracted_model_dir = "./checkpoint-500"
88
-
89
- # Unzip the file if it’s not already extracted
90
- if not os.path.exists(extracted_model_dir):
91
- with zipfile.ZipFile(model_zip_path, 'r') as zip_ref:
92
- zip_ref.extractall(extracted_model_dir)
93
-
94
- # Load the model from the extracted directory
95
- self.model = AutoModelForCausalLM.from_pretrained(
96
- extracted_model_dir,
97
- device_map="auto",
98
- load_in_8bit=True,
99
- torch_dtype=torch.float16,
100
- low_cpu_mem_usage=True
101
- )
102
 
103
  class ModelManager:
104
  """Handles model loading and resource management"""
105
 
106
  @staticmethod
107
- def verify_model_path(checkpoint_path: str) -> str:
108
- """Verify and return valid model path"""
109
- if os.path.exists(checkpoint_path):
110
- return checkpoint_path
 
 
 
 
 
111
 
112
- alternate_paths = [
113
- f"{os.getcwd()}/checkpoint-500.zip",
114
- "./checkpoint-500.zip",
115
- "../checkpoint-500.zip"
116
- ]
117
-
118
- for path in alternate_paths:
119
- if os.path.exists(path):
120
- return path
121
-
122
- raise FileNotFoundError(
123
- f"Model checkpoint not found in any of these locations: "
124
- f"{[checkpoint_path] + alternate_paths}"
125
- )
126
 
127
  @staticmethod
128
  def clear_gpu_memory():
@@ -132,8 +108,9 @@ class ModelManager:
132
  gc.collect()
133
 
134
  class PearlyBot:
135
- def __init__(self, model_path: str = "./checkpoint-500.zip"):
136
- self.setup_model(model_path)
 
137
  self.setup_rag()
138
  self.conversation_history = []
139
  self.last_interaction_time = time.time()
@@ -145,17 +122,9 @@ class PearlyBot:
145
  logger.info("Starting model initialization...")
146
  ModelManager.clear_gpu_memory()
147
 
148
- # Verify model path
149
- verified_path = ModelManager.verify_model_path(model_path)
150
- logger.info(f"Using model checkpoint from: {verified_path}")
151
-
152
- # Base model configuration
153
- base_model_id = "google/gemma-2b"
154
- logger.info(f"Loading base model: {base_model_id}")
155
-
156
  # Load tokenizer
157
  try:
158
- self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
159
  self.tokenizer.pad_token = self.tokenizer.eos_token
160
  logger.info("Tokenizer loaded successfully")
161
  except Exception as e:
@@ -165,7 +134,7 @@ class PearlyBot:
165
  # Load model
166
  try:
167
  self.model = AutoModelForCausalLM.from_pretrained(
168
- verified_path,
169
  device_map="auto",
170
  load_in_8bit=True,
171
  torch_dtype=torch.float16,
 
82
  logging.basicConfig(level=logging.INFO)
83
  logger = logging.getLogger(__name__)
84
 
85
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  class ModelManager:
88
  """Handles model loading and resource management"""
89
 
90
  @staticmethod
91
+ def verify_and_extract_model(checkpoint_zip_path: str, extracted_model_dir: str) -> str:
92
+ """Verify and extract the model if it's not already extracted"""
93
+ if not os.path.exists(extracted_model_dir):
94
+ # Unzip the model if it hasn’t been extracted yet
95
+ with zipfile.ZipFile(checkpoint_zip_path, 'r') as zip_ref:
96
+ zip_ref.extractall(extracted_model_dir)
97
+ logger.info(f"Extracted model to: {extracted_model_dir}")
98
+ else:
99
+ logger.info(f"Model already extracted: {extracted_model_dir}")
100
 
101
+ return extracted_model_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  @staticmethod
104
  def clear_gpu_memory():
 
108
  gc.collect()
109
 
110
  class PearlyBot:
111
+ def __init__(self, model_zip_path: str = "./checkpoint-500.zip", model_dir: str = "./checkpoint-500"):
112
+ self.model_dir = ModelManager.verify_and_extract_model(model_zip_path, model_dir)
113
+ self.setup_model(self.model_dir)
114
  self.setup_rag()
115
  self.conversation_history = []
116
  self.last_interaction_time = time.time()
 
122
  logger.info("Starting model initialization...")
123
  ModelManager.clear_gpu_memory()
124
 
 
 
 
 
 
 
 
 
125
  # Load tokenizer
126
  try:
127
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
128
  self.tokenizer.pad_token = self.tokenizer.eos_token
129
  logger.info("Tokenizer loaded successfully")
130
  except Exception as e:
 
134
  # Load model
135
  try:
136
  self.model = AutoModelForCausalLM.from_pretrained(
137
+ model_path,
138
  device_map="auto",
139
  load_in_8bit=True,
140
  torch_dtype=torch.float16,