Pratyush Maini committed on
Commit
0770dbf
·
1 Parent(s): bc77f98

Revert "Fix: Use public base models that are guaranteed to work (GPT-2, DistilGPT-2, DialoGPT)"

Browse files

This reverts commit bc77f988a44886e25d2c45aad4dfc1b904a8fe94.

Files changed (1) hide show
  1. app.py +23 -33
app.py CHANGED
@@ -9,10 +9,9 @@ os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.huggingface/transformers")
9
 
10
  # Define available base models (for local inference)
11
  model_list = {
12
- "GPT-2": "gpt2",
13
- "GPT-2 Medium": "gpt2-medium",
14
- "DistilGPT-2": "distilgpt2",
15
- "DialoGPT Small": "microsoft/DialoGPT-small",
16
  }
17
 
18
  # Use token from environment variables (HF Spaces) or keys.py (local)
@@ -26,33 +25,22 @@ def load_model(model_name):
26
  """Load model and tokenizer, cache them for reuse"""
27
  if model_name not in model_cache:
28
  print(f"Loading model: {model_name}")
29
- try:
30
- # Try loading with auth token if available
31
- tokenizer = AutoTokenizer.from_pretrained(
32
- model_name,
33
- token=HF_TOKEN if HF_TOKEN else None,
34
- trust_remote_code=True
35
- )
36
- model = AutoModelForCausalLM.from_pretrained(
37
- model_name,
38
- torch_dtype=torch.float32, # Use float32 for CPU
39
- device_map="cpu",
40
- low_cpu_mem_usage=True,
41
- token=HF_TOKEN if HF_TOKEN else None,
42
- trust_remote_code=True
43
- )
44
- # Add padding token if it doesn't exist
45
- if tokenizer.pad_token is None:
46
- tokenizer.pad_token = tokenizer.eos_token
47
-
48
- model_cache[model_name] = {
49
- 'tokenizer': tokenizer,
50
- 'model': model
51
- }
52
- print(f"Model {model_name} loaded successfully")
53
- except Exception as e:
54
- print(f"Error loading model {model_name}: {str(e)}")
55
- raise e
56
 
57
  return model_cache[model_name]
58
 
@@ -60,7 +48,7 @@ def load_model(model_name):
60
  def respond(message, history, max_tokens, temperature, top_p, selected_model):
61
  try:
62
  # Get the model ID from the model list
63
- model_id = model_list.get(selected_model, "gpt2")
64
 
65
  # Load the model and tokenizer
66
  try:
@@ -191,6 +179,8 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
191
  </div>
192
  """)
193
 
 
 
194
  with gr.Row():
195
  # Left sidebar: Model selector
196
  with gr.Column(scale=1):
@@ -198,7 +188,7 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
198
  model_dropdown = gr.Dropdown(
199
  choices=list(model_list.keys()),
200
  label="Select Model",
201
- value="GPT-2",
202
  elem_classes=["model-select"]
203
  )
204
  # Quick test prompts for safety testing
 
9
 
10
  # Define available base models (for local inference)
11
  model_list = {
12
+ "SafeLM 1.7B": "locuslab/safelm-1.7b",
13
+ "SmolLM2 1.7B": "HuggingFaceTB/SmolLM2-1.7B",
14
+ "Llama 3.2 1B": "meta-llama/Llama-3.2-1B",
 
15
  }
16
 
17
  # Use token from environment variables (HF Spaces) or keys.py (local)
 
25
  """Load model and tokenizer, cache them for reuse"""
26
  if model_name not in model_cache:
27
  print(f"Loading model: {model_name}")
28
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
29
+ model = AutoModelForCausalLM.from_pretrained(
30
+ model_name,
31
+ torch_dtype=torch.float32, # Use float32 for CPU
32
+ device_map="cpu",
33
+ low_cpu_mem_usage=True
34
+ )
35
+ # Add padding token if it doesn't exist
36
+ if tokenizer.pad_token is None:
37
+ tokenizer.pad_token = tokenizer.eos_token
38
+
39
+ model_cache[model_name] = {
40
+ 'tokenizer': tokenizer,
41
+ 'model': model
42
+ }
43
+ print(f"Model {model_name} loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  return model_cache[model_name]
46
 
 
48
  def respond(message, history, max_tokens, temperature, top_p, selected_model):
49
  try:
50
  # Get the model ID from the model list
51
+ model_id = model_list.get(selected_model, "locuslab/safelm-1.7b")
52
 
53
  # Load the model and tokenizer
54
  try:
 
179
  </div>
180
  """)
181
 
182
+ # Status message for local inference
183
+
184
  with gr.Row():
185
  # Left sidebar: Model selector
186
  with gr.Column(scale=1):
 
188
  model_dropdown = gr.Dropdown(
189
  choices=list(model_list.keys()),
190
  label="Select Model",
191
+ value="SafeLM 1.7B",
192
  elem_classes=["model-select"]
193
  )
194
  # Quick test prompts for safety testing