abdull4h committed
Commit ff2315d · verified · 1 Parent(s): ca20ed2

Update app.py

Files changed (1)
  1. app.py +19 -16
app.py CHANGED
@@ -95,16 +95,21 @@ def generate_text(prompt, max_length=100, temperature=0.7, force_arabic=True):
        else:
            enhanced_prompt = prompt

-       # Create input for the model
+       # Create input for the model using proper tokenization with attention mask
        print(f"Generating response for: {enhanced_prompt[:50]}...")

-       # Use direct input approach instead of chat template
-       input_ids = tokenizer.encode(enhanced_prompt, return_tensors="pt").to(model.device)
+       # Use the tokenizer to get both input_ids and attention_mask
+       encoding = tokenizer(enhanced_prompt, return_tensors="pt", padding=True)
+       input_ids = encoding.input_ids.to(model.device)
+       attention_mask = encoding.attention_mask.to(model.device)

-       # Generate with all compiler features disabled
+       print(f"Input shape: {input_ids.shape}, Attention mask shape: {attention_mask.shape}")
+
+       # Generate with all compiler features disabled and proper attention mask
        with torch.inference_mode():
            gen_tokens = model.generate(
                input_ids,
+               attention_mask=attention_mask,  # Add attention mask
                max_new_tokens=int(max_length),
                do_sample=True,
                temperature=float(temperature),
@@ -113,20 +118,18 @@ def generate_text(prompt, max_length=100, temperature=0.7, force_arabic=True):
                pad_token_id=tokenizer.eos_token_id
            )

-       # Decode full output
-       full_output = tokenizer.decode(gen_tokens[0], skip_special_tokens=False)
-       print(f"Raw output: {full_output[:100]}...")
+       # Get only the generated part (exclude the prompt)
+       input_length = input_ids.shape[1]
+       generated_tokens = gen_tokens[0][input_length:]

-       # Clean the response
-       cleaned_response = clean_response(full_output)
+       # Decode just the generated part
+       generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+       print(f"Generated text (after input): {generated_text[:100]}...")

-       # If response still starts with the prompt, remove it
-       if cleaned_response.startswith(enhanced_prompt):
-           cleaned_response = cleaned_response[len(enhanced_prompt):].strip()
-       elif cleaned_response.startswith(prompt):
-           cleaned_response = cleaned_response[len(prompt):].strip()
-
+       # Clean any remaining special tokens
+       cleaned_response = clean_response(generated_text)
        print(f"Final cleaned response: {cleaned_response[:100]}...")
+
        return cleaned_response

    except Exception as e:
@@ -163,7 +166,7 @@ with gr.Blocks(title="Cohere Arabic Model Demo") as demo:
            gr.Button(example).click(
                fn=lambda e=example: e,
                inputs=[],
-               outputs=prompt
+               outputs=[prompt]
            )

        # Parameters
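
The substance of this commit is the generation fix: rather than decoding the whole output and string-stripping the prompt, the new code passes an explicit attention mask to generate() and then decodes only the tokens produced after the prompt. A minimal, self-contained sketch of that pattern follows; the checkpoint name and prompt are placeholders, not the app's actual configuration:

# Sketch of the encode -> generate -> slice-off-prompt pattern from the diff.
# "gpt2" is only a placeholder checkpoint; app.py loads its own model/tokenizer.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Write a one-line greeting."
# Tokenize once to get both input_ids and attention_mask
encoding = tokenizer(prompt, return_tensors="pt")
input_ids = encoding.input_ids.to(model.device)
attention_mask = encoding.attention_mask.to(model.device)

with torch.inference_mode():
    gen_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,   # silences the missing-mask warning
        max_new_tokens=50,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

# Keep only the newly generated tokens so the prompt is never echoed back.
new_tokens = gen_tokens[0][input_ids.shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))

Slicing at input_ids.shape[1] works because a causal LM's output sequence starts with the prompt tokens and continues with the generated ones, which is why the old startswith()-based prompt stripping is no longer needed.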