abdull4h committed ffd2a10 · verified · 1 Parent(s): bc39f18

Update app.py

Files changed (1):
  1. app.py (+73, -54)

app.py CHANGED
@@ -1,8 +1,37 @@
 import os
+import sys
 import gradio as gr
 from huggingface_hub import login
 import spaces
 
+# CRITICAL: Disable PyTorch compiler BEFORE importing torch
+os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+os.environ["TORCH_INDUCTOR_DISABLE"] = "1"
+os.environ["TORCHINDUCTOR_DISABLE_CUDAGRAPHS"] = "1"
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["TORCH_USE_CUDA_DSA"] = "0"
+
+# Now import torch and disable its compiler features
+import torch
+if hasattr(torch, "_dynamo"):
+    if hasattr(torch._dynamo, "config"):
+        torch._dynamo.config.suppress_errors = True
+    if hasattr(torch._dynamo, "disable"):
+        torch._dynamo.disable()
+        print("Disabled torch._dynamo")
+
+# Disable JIT functionality safely
+if hasattr(torch, "_C") and hasattr(torch._C, "_jit_set_profiling_executor"):
+    torch._C._jit_set_profiling_executor(False)
+    print("Disabled JIT profiling executor")
+if hasattr(torch, "_C") and hasattr(torch._C, "_jit_set_profiling_mode"):
+    torch._C._jit_set_profiling_mode(False)
+    print("Disabled JIT profiling mode")
+if hasattr(torch, "_C") and hasattr(torch._C, "_set_graph_executor_optimize"):
+    torch._C._set_graph_executor_optimize(False)
+    print("Disabled graph executor optimization")
+
 # Model ID
 model_id = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
 
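[Editor's note] Several of the switches above are defensive and version-dependent; some, like TORCH_INDUCTOR_DISABLE, do not appear among PyTorch's documented environment variables. On PyTorch 2.x the supported way to keep TorchDynamo out of a code path is `torch.compiler.disable`; a minimal sketch, assuming PyTorch >= 2.1 (the example function is illustrative, not from this commit):

    import torch

    # torch.compiler.disable keeps TorchDynamo from tracing or
    # compiling the decorated function; everything runs eagerly.
    @torch.compiler.disable
    def eager_generate(model, input_ids):
        with torch.inference_mode():
            return model.generate(input_ids, max_new_tokens=100)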
@@ -15,7 +44,6 @@ else:
     print("No HF_TOKEN found. Please set the HF_TOKEN environment variable.")
 
 # Import libraries at the module level
-import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 # Pre-load tokenizer at module level
@@ -27,50 +55,34 @@ except Exception as e:
     print(f"Failed to load tokenizer: {str(e)}")
     tokenizer = None
 
-# To track if model was loaded
-model_loaded = False
-print(f"Initial model_loaded state: {model_loaded}")
-
 # Single combined function that handles both loading and generation
 @spaces.GPU
-def load_and_generate(prompt, max_length=100, temperature=0.3, force_reload=False):
-    global model_loaded
-
-    # First make sure model is loaded
-    if not model_loaded or force_reload:
-        print(f"Loading model (current state: {model_loaded}, force_reload: {force_reload})...")
-        try:
-            # Load model with GPU acceleration
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                token=hf_token,
-                torch_dtype=torch.float16,
-                device_map="auto"
-            )
-            model_loaded = True
-            print("Model loaded successfully within the function!")
-        except Exception as e:
-            import traceback
-            error_details = traceback.format_exc()
-            print(f"Error loading model: {str(e)}\n{error_details}")
-            return f"Failed to load model: {str(e)}"
-    else:
-        print("Model was already loaded")
+def generate_text(prompt, max_length=100, temperature=0.3):
+    # Load model with compiler disabled
+    try:
+        # Configure the model loading to avoid compiler
+        print("Loading model with compiler disabled...")
 
-    # We still need to load the model within this function call due to ZeroGPU isolation
-    try:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            token=hf_token,
-            torch_dtype=torch.float16,
-            device_map="auto"
-        )
-        print("Model reloaded for this function call")
-    except Exception as e:
-        print(f"Error reloading model: {str(e)}")
-        return f"Error reloading model: {str(e)}"
+        # Load model with no optimizations
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            token=hf_token,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            # Disable features that might trigger compiler
+            use_cache=True,
+            use_flash_attention_2=False,
+            _attn_implementation="eager"
+        )
+        print(f"Model loaded successfully on {next(model.parameters()).device}")
+
+    except Exception as e:
+        import traceback
+        error_details = traceback.format_exc()
+        print(f"Error loading model: {str(e)}\n{error_details}")
+        return f"Failed to load model: {str(e)}"
 
-    # Now generate text with the loaded model
+    # Generate text with the loaded model
     if not prompt.strip():
         return "Please enter a prompt."
 
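[Editor's note] `use_flash_attention_2` is deprecated in recent transformers releases, and `_attn_implementation` is a private field. Assuming transformers >= 4.36, the documented way to force plain attention is the public `attn_implementation` argument; a sketch using the same `model_id`/`hf_token` as app.py:

    import torch
    from transformers import AutoModelForCausalLM

    # attn_implementation="eager" selects plain attention instead of
    # FlashAttention-2 or SDPA kernels (public API, transformers >= 4.36).
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="eager",
    )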
@@ -89,14 +101,21 @@ def load_and_generate(prompt, max_length=100, temperature=0.3, force_reload=Fals
     # Move to model device
     input_ids = input_ids.to(model.device)
 
-    # Generate
-    gen_tokens = model.generate(
-        input_ids,
-        max_new_tokens=int(max_length),
-        do_sample=True if temperature > 0 else False,
-        temperature=float(temperature) if temperature > 0 else None,
-        top_p=0.95 if temperature > 0 else None
-    )
+    # Generate with compiler completely disabled
+    with torch.inference_mode():
+        # Force eager execution
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+
+        # Safe generation
+        gen_tokens = model.generate(
+            input_ids,
+            max_new_tokens=int(max_length),
+            do_sample=True if temperature > 0 else False,
+            temperature=float(temperature) if temperature > 0 else None,
+            top_p=0.95 if temperature > 0 else None,
+            use_cache=True
+        )
 
     # Decode and return
     gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
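[Editor's note] The `None` guards keep transformers from warning that sampling parameters are set while `do_sample=False`. An equivalent pattern that keeps the call site tidy is to assemble the keyword arguments conditionally; a minimal sketch using the same variables:

    # Build sampling kwargs only when sampling is enabled, so greedy
    # decoding never receives temperature/top_p at all.
    gen_kwargs = {"max_new_tokens": int(max_length), "use_cache": True}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.95)

    gen_tokens = model.generate(input_ids, **gen_kwargs)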
@@ -141,10 +160,11 @@ with gr.Blocks(title="Cohere Arabic Model Demo") as demo:
         with gr.Row():
             for example in example_prompts[i:i+2]:
                 if example:  # Make sure example exists
-                    def create_click_handler(ex):
+                    # This is a workaround for closure binding in loops
+                    def make_click_handler(ex):
                         return lambda: ex
                     gr.Button(example).click(
-                        fn=create_click_handler(example),
+                        fn=make_click_handler(example),
                         inputs=[],
                         outputs=[prompt]
                     )
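[Editor's note] The helper exists because Python closures bind late: a bare `lambda: example` created inside the loop would return whatever `example` holds after the loop ends, so every button would insert the last prompt. A standalone illustration:

    # Late binding: all three lambdas share the loop variable...
    handlers = [lambda: x for x in range(3)]
    print([h() for h in handlers])  # [2, 2, 2]

    # ...while a factory function captures each value at creation time.
    def make_handler(value):
        return lambda: value

    handlers = [make_handler(x) for x in range(3)]
    print([h() for h in handlers])  # [0, 1, 2]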
@@ -153,7 +173,6 @@ with gr.Blocks(title="Cohere Arabic Model Demo") as demo:
     with gr.Accordion("Parameters", open=False):
         max_tokens = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max New Tokens")
         temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Temperature")
-        force_reload = gr.Checkbox(label="Force reload model (use only if needed)", value=False)
 
     # Action buttons
     with gr.Row():
@@ -166,8 +185,8 @@ with gr.Blocks(title="Cohere Arabic Model Demo") as demo:
 
     # Set up event handlers
     submit_btn.click(
-        fn=load_and_generate,
-        inputs=[prompt, max_tokens, temp, force_reload],
+        fn=generate_text,
+        inputs=[prompt, max_tokens, temp],
         outputs=[output]
    )
    clear_btn.click(fn=lambda: "", inputs=[], outputs=[prompt, output])
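[Editor's note] With the reload checkbox gone, the model is still re-instantiated from the disk cache on every call to `generate_text`. The ZeroGPU documentation shows an alternative layout that loads the model once at module import and reserves the GPU only inside the decorated function; a sketch, assuming that pattern applies to this Space (`model_id` and `hf_token` as in app.py):

    import spaces
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Module-level load, done once at startup.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=hf_token, torch_dtype=torch.float16
    )
    model.to("cuda")  # per the ZeroGPU docs, CUDA placement happens at module level

    @spaces.GPU
    def generate(prompt, max_new_tokens=100):
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        with torch.inference_mode():
            out = model.generate(input_ids, max_new_tokens=max_new_tokens)
        return tokenizer.decode(out[0], skip_special_tokens=True)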