Commit
•
6065b7a
1
Parent(s):
b161793
Fix generation when `repetition_penalty` is activated (#57)
Browse files- make sure input_ids do not contain negative numbers (indicating images) after they are no longer needed (5905c926df4db18660da263a9777998ca66a14fe)
Co-authored-by: Yen-Chun Chen <YenChunChen@users.noreply.huggingface.co>
- image_embedding_phi3_v.py +10 -1
image_embedding_phi3_v.py
CHANGED
@@ -12,6 +12,7 @@
|
|
12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
|
|
15 |
|
16 |
import torch
|
17 |
from torch import nn
|
@@ -191,7 +192,15 @@ class Phi3ImageEmbedding(nn.Module):
|
|
191 |
# positions for image tokens
|
192 |
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
|
193 |
has_image = len(positions[0].tolist()) > 0
|
194 |
-
input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
hidden_states = self.wte(input_ids)
|
196 |
|
197 |
if has_image:
|
|
|
12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
+
import warnings
|
16 |
|
17 |
import torch
|
18 |
from torch import nn
|
|
|
192 |
# positions for image tokens
|
193 |
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
|
194 |
has_image = len(positions[0].tolist()) > 0
|
195 |
+
# input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
|
196 |
+
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
|
197 |
+
warnings.warn(
|
198 |
+
"Phi-3-V modifies `input_ids` in-place and the tokens indicating images will be "
|
199 |
+
"removed after model forward. If your workflow requires multiple forward passes on "
|
200 |
+
"the same `input_ids`, please make a copy of `input_ids` before passing it to the "
|
201 |
+
"model."
|
202 |
+
)
|
203 |
+
|
204 |
hidden_states = self.wte(input_ids)
|
205 |
|
206 |
if has_image:
|