Pratyush Maini committed
Commit bc77f98 · 1 Parent(s): 0ac35c2

Fix: Use public base models that are guaranteed to work (GPT-2, DistilGPT-2, DialoGPT)

Files changed (1):
  app.py  +33 -23
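For context: the previous list pointed at checkpoints a public Space cannot reliably pull (meta-llama/Llama-3.2-1B, for instance, is gated behind a license acceptance), while the four replacements are openly downloadable. A minimal sketch of how one might verify that each new ID resolves anonymously; this helper script is illustrative and not part of app.py:

```python
# check_models.py - illustrative helper, not part of app.py.
# Verifies that each model ID in the new list is reachable without credentials.
from huggingface_hub import model_info
from huggingface_hub.utils import HfHubHTTPError

model_list = {
    "GPT-2": "gpt2",
    "GPT-2 Medium": "gpt2-medium",
    "DistilGPT-2": "distilgpt2",
    "DialoGPT Small": "microsoft/DialoGPT-small",
}

for name, repo_id in model_list.items():
    try:
        # token=False forces an anonymous request, mirroring a Space without HF_TOKEN
        info = model_info(repo_id, token=False)
        print(f"{name}: OK ({info.id}, gated={info.gated})")
    except HfHubHTTPError as e:
        print(f"{name}: NOT accessible anonymously - {e}")
```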
app.py CHANGED

@@ -9,9 +9,10 @@ os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.huggingface/transformers")
 
 # Define available base models (for local inference)
 model_list = {
-    "SafeLM 1.7B": "locuslab/safelm-1.7b",
-    "SmolLM2 1.7B": "HuggingFaceTB/SmolLM2-1.7B",
-    "Llama 3.2 1B": "meta-llama/Llama-3.2-1B",
+    "GPT-2": "gpt2",
+    "GPT-2 Medium": "gpt2-medium",
+    "DistilGPT-2": "distilgpt2",
+    "DialoGPT Small": "microsoft/DialoGPT-small",
 }
 
 # Use token from environment variables (HF Spaces) or keys.py (local)
@@ -25,22 +26,33 @@ def load_model(model_name):
     """Load model and tokenizer, cache them for reuse"""
     if model_name not in model_cache:
         print(f"Loading model: {model_name}")
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.float32,  # Use float32 for CPU
-            device_map="cpu",
-            low_cpu_mem_usage=True
-        )
-        # Add padding token if it doesn't exist
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
-        model_cache[model_name] = {
-            'tokenizer': tokenizer,
-            'model': model
-        }
-        print(f"Model {model_name} loaded successfully")
+        try:
+            # Try loading with auth token if available
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                token=HF_TOKEN if HF_TOKEN else None,
+                trust_remote_code=True
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch.float32,  # Use float32 for CPU
+                device_map="cpu",
+                low_cpu_mem_usage=True,
+                token=HF_TOKEN if HF_TOKEN else None,
+                trust_remote_code=True
+            )
+            # Add padding token if it doesn't exist
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            model_cache[model_name] = {
+                'tokenizer': tokenizer,
+                'model': model
+            }
+            print(f"Model {model_name} loaded successfully")
+        except Exception as e:
+            print(f"Error loading model {model_name}: {str(e)}")
+            raise e
 
     return model_cache[model_name]
 
@@ -48,7 +60,7 @@ def load_model(model_name):
 def respond(message, history, max_tokens, temperature, top_p, selected_model):
     try:
         # Get the model ID from the model list
-        model_id = model_list.get(selected_model, "locuslab/safelm-1.7b")
+        model_id = model_list.get(selected_model, "gpt2")
 
         # Load the model and tokenizer
         try:
@@ -179,8 +191,6 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
     </div>
     """)
 
-    # Status message for local inference
-
     with gr.Row():
         # Left sidebar: Model selector
         with gr.Column(scale=1):
@@ -188,7 +198,7 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
             model_dropdown = gr.Dropdown(
                 choices=list(model_list.keys()),
                 label="Select Model",
-                value="SafeLM 1.7B",
+                value="GPT-2",
                 elem_classes=["model-select"]
            )
            # Quick test prompts for safety testing
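To confirm the new default generates on CPU, here is a standalone sketch of the same loading path used in load_model above; the prompt and sampling values are arbitrary choices, not taken from the app:

```python
# smoke_test.py - standalone sketch mirroring app.py's loading path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # smallest entry in the new model_list
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # CPU-friendly, as in app.py
    device_map="cpu",
    low_cpu_mem_usage=True,
)
# GPT-2 ships without a pad token, hence the same fallback app.py applies
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer("Hello, I am", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Running this in an environment with no HF_TOKEN set approximates what an anonymous Space sees, which is exactly the failure mode this commit works around.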