Update app.py
app.py CHANGED

@@ -95,16 +95,21 @@ def generate_text(prompt, max_length=100, temperature=0.7, force_arabic=True):
         else:
             enhanced_prompt = prompt
 
-        # Create input for the model
+        # Create input for the model using proper tokenization with attention mask
         print(f"Generating response for: {enhanced_prompt[:50]}...")
 
-        # Use
-
+        # Use the tokenizer to get both input_ids and attention_mask
+        encoding = tokenizer(enhanced_prompt, return_tensors="pt", padding=True)
+        input_ids = encoding.input_ids.to(model.device)
+        attention_mask = encoding.attention_mask.to(model.device)
 
-
+        print(f"Input shape: {input_ids.shape}, Attention mask shape: {attention_mask.shape}")
+
+        # Generate with all compiler features disabled and proper attention mask
         with torch.inference_mode():
             gen_tokens = model.generate(
                 input_ids,
+                attention_mask=attention_mask,  # Add attention mask
                 max_new_tokens=int(max_length),
                 do_sample=True,
                 temperature=float(temperature),
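
For reference, a minimal standalone sketch of the pattern this hunk adopts: tokenize once to get both input_ids and attention_mask, move both tensors to the model's device, and pass the mask to generate(). The "gpt2" checkpoint below is a placeholder, not the model this Space actually loads.

# Sketch of the tokenize-then-generate pattern above.
# "gpt2" is a placeholder checkpoint, not the Space's actual model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token; reuse EOS

encoding = tokenizer("Hello", return_tensors="pt", padding=True)
input_ids = encoding.input_ids.to(model.device)
attention_mask = encoding.attention_mask.to(model.device)

with torch.inference_mode():
    gen_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,  # explicit mask, so generate() need not infer one
        max_new_tokens=50,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

Passing the mask explicitly matters because transformers otherwise tries to infer it from pad_token_id, which is ambiguous when the pad and EOS tokens share an id, as they do here.
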
@@ -113,20 +118,18 @@ def generate_text(prompt, max_length=100, temperature=0.7, force_arabic=True):
                 pad_token_id=tokenizer.eos_token_id
             )
 
-        #
-
-
+        # Get only the generated part (exclude the prompt)
+        input_length = input_ids.shape[1]
+        generated_tokens = gen_tokens[0][input_length:]
 
-        #
-
+        # Decode just the generated part
+        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        print(f"Generated text (after input): {generated_text[:100]}...")
 
-        #
-        if cleaned_response.startswith(enhanced_prompt):
-            cleaned_response = cleaned_response[len(enhanced_prompt):].strip()
-        elif cleaned_response.startswith(prompt):
-            cleaned_response = cleaned_response[len(prompt):].strip()
-
+        # Clean any remaining special tokens
+        cleaned_response = clean_response(generated_text)
         print(f"Final cleaned response: {cleaned_response[:100]}...")
+
         return cleaned_response
 
     except Exception as e:
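
The token-level slicing introduced in this hunk replaces the removed string-based stripping (the startswith/len lines), which can miss when decoding normalizes whitespace or special tokens. A sketch of the idiom as a helper; continuation_only is an illustrative name, not a function from app.py:

# Decode only the continuation, dropping the echoed prompt.
# continuation_only is a hypothetical helper, not part of app.py.
import torch

def continuation_only(gen_tokens: torch.Tensor, input_ids: torch.Tensor, tokenizer) -> str:
    input_length = input_ids.shape[1]          # number of prompt tokens
    new_tokens = gen_tokens[0][input_length:]  # everything after the prompt
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
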
@@ -163,7 +166,7 @@ with gr.Blocks(title="Cohere Arabic Model Demo") as demo:
             gr.Button(example).click(
                 fn=lambda e=example: e,
                 inputs=[],
-                outputs=prompt
+                outputs=[prompt]
             )
 
     # Parameters
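
The last hunk only wraps the click target in a list; Gradio accepts both a bare component and a list for outputs, so this is a consistency fix rather than a bug fix. A standalone sketch of the button wiring, with placeholder labels instead of the Space's real examples:

# Sketch of the example-button wiring; labels are placeholders.
import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    for example in ["First example", "Second example"]:
        # The default argument e=example freezes the current value; a plain
        # `lambda: example` would capture the loop variable and always
        # return the last label.
        gr.Button(example).click(
            fn=lambda e=example: e,
            inputs=[],
            outputs=[prompt],
        )

demo.launch()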