Upload folder using huggingface_hub

- modeling_internlm2.py  +4 -181
- modeling_internvl_chat.py  +2 -2
modeling_internlm2.py
CHANGED
@@ -39,18 +39,6 @@ try:
     from transformers.generation.streamers import BaseStreamer
 except: # noqa # pylint: disable=bare-except
     BaseStreamer = None
-from typing import Any, List, Optional, Tuple, Union
-import torch.distributed as dist
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers.generation.logits_process import LogitsProcessorList
-from transformers.generation.stopping_criteria import StoppingCriteriaList
-from transformers.generation.streamers import BaseStreamer
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import ModelOutput, logging
-from transformers.generation.utils import GreedySearchOutput, validate_stopping_criteria, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput
 
 from .configuration_internlm2 import InternLM2Config
 
@@ -1094,13 +1082,16 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-        return CausalLMOutputWithPast(
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        output = CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+        output['logits'] = output['logits'].to(device)
+        return output
 
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
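Note on the hunk above: the non-tuple return path of `forward` now builds the `CausalLMOutputWithPast` first, moves its `logits` onto the device of `input_ids` (or `inputs_embeds`), and only then returns it. A minimal sketch of that pattern in isolation, not the repository's code, using small stand-in tensors so it runs anywhere:

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# Stand-in tensors; in the model these are the lm_head output and the prompt ids.
logits = torch.randn(1, 4, 32)
input_ids = torch.randint(0, 32, (1, 4))

# The real code falls back to inputs_embeds.device when input_ids is None.
device = input_ids.device
output = CausalLMOutputWithPast(loss=None, logits=logits)
# ModelOutput supports item assignment, so one field can be rewritten in place
# without rebuilding the whole output object.
output['logits'] = output['logits'].to(device)

assert output.logits.device == input_ids.device

This is mainly useful when the model is sharded across devices (for example with `device_map='auto'`), so callers can rely on `outputs.logits` being on the same device as the ids they passed in.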
@@ -1284,174 +1275,6 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
 
         return consumer()
 
-    def greedy_search(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_scores: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        synced_gpus: bool = False,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ) -> Union[GreedySearchOutput, torch.LongTensor]:
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                UserWarning,
-            )
-            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-        output_attentions = (
-            output_attentions if output_attentions is not None else self.generation_config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
-        )
-        return_dict_in_generate = (
-            return_dict_in_generate
-            if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
-        )
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
-
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_tokens_scores,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1).to(device=input_ids.device)
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-            )
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            if stopping_criteria(input_ids, scores):
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            if self.config.is_encoder_decoder:
-                return GreedySearchEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GreedySearchDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
-
 
 # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
 @add_start_docstrings(
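With the duplicate imports and the copied-in `greedy_search` override removed above, `InternLM2ForCausalLM` no longer ships its own decoding loop; `generate()` falls back to the stock `GenerationMixin` implementation in `transformers`. A minimal usage sketch, assuming `model` and `tokenizer` have already been loaded from an InternLM2-based checkpoint with `trust_remote_code=True` (the prompt and generation settings below are placeholders):

import torch

# Sketch only: `model` is an already-loaded InternLM2ForCausalLM (or a chat model
# wrapping it) and `tokenizer` is its tokenizer.
inputs = tokenizer('Briefly introduce InternLM2.', return_tensors='pt').to(model.device)

with torch.no_grad():
    # do_sample=False requests greedy decoding, which now runs through
    # transformers' built-in generation loop rather than the removed override.
    output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))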
modeling_internvl_chat.py
CHANGED
@@ -313,8 +313,8 @@ class InternVLChatModel(PreTrainedModel):
         if return_history:
             return response, history
         else:
-
-
+            query_to_print = query.replace(image_tokens, '<image>')
+            print(query_to_print, response)
             return response
         return response
 
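The two added lines make `chat()` print the prompt and reply when `return_history` is False, with the expanded image token block collapsed back to `<image>` for readability. A rough sketch of the effect, with hypothetical stand-ins for `query`, `image_tokens`, and `response` (the real method derives them from the conversation template and the number of vision tokens):

# Hypothetical stand-ins for values built inside chat(); token strings and the
# context-token count differ in the real model.
image_tokens = '<img>' + '<IMG_CONTEXT>' * 4 + '</img>'
query = f'{image_tokens}\nDescribe this image in detail.'
response = 'A cat is sitting on a windowsill.'

# Same pattern as the added lines: collapse the image placeholder before printing.
query_to_print = query.replace(image_tokens, '<image>')
print(query_to_print, response)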