czczup committed on
Commit 44e41e8
1 Parent(s): 0fbe95a

Upload folder using huggingface_hub

Files changed (2)
  1. modeling_internlm2.py +4 -181
  2. modeling_internvl_chat.py +2 -2
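
The commit message indicates the files were pushed with `huggingface_hub`'s folder upload. For context, a minimal sketch of that workflow, assuming prior authentication; the repo id and local path below are placeholders, not values taken from this commit:

```python
from huggingface_hub import HfApi

api = HfApi()  # assumes prior `huggingface-cli login` or an HF_TOKEN env var

# Push every file under a local directory to the Hub as a single commit.
api.upload_folder(
    repo_id="your-org/your-model",      # placeholder repo id
    folder_path="./local_model_dir",    # placeholder local path
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```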
modeling_internlm2.py CHANGED
@@ -39,18 +39,6 @@ try:
     from transformers.generation.streamers import BaseStreamer
 except: # noqa # pylint: disable=bare-except
     BaseStreamer = None
-from typing import Any, List, Optional, Tuple, Union
-import torch.distributed as dist
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers.generation.logits_process import LogitsProcessorList
-from transformers.generation.stopping_criteria import StoppingCriteriaList
-from transformers.generation.streamers import BaseStreamer
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import ModelOutput, logging
-from transformers.generation.utils import GreedySearchOutput, validate_stopping_criteria, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput

 from .configuration_internlm2 import InternLM2Config

@@ -1094,13 +1082,16 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output

-        return CausalLMOutputWithPast(
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        output = CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+        output['logits'] = output['logits'].to(device)
+        return output

     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
@@ -1284,174 +1275,6 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):

         return consumer()

-    def greedy_search(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_scores: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        synced_gpus: bool = False,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ) -> Union[GreedySearchOutput, torch.LongTensor]:
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                UserWarning,
-            )
-            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-        output_attentions = (
-            output_attentions if output_attentions is not None else self.generation_config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
-        )
-        return_dict_in_generate = (
-            return_dict_in_generate
-            if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
-        )
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
-
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_tokens_scores,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1).to(device=input_ids.device)
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-            )
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
-                # stop when each sentence is finished
-                if unfinished_sequences.max() == 0:
-                    this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            if stopping_criteria(input_ids, scores):
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            if self.config.is_encoder_decoder:
-                return GreedySearchEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GreedySearchDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
-

 # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
 @add_start_docstrings(
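
In the second hunk, `forward()` still builds a `CausalLMOutputWithPast`, but now moves its `logits` onto the device of `input_ids` (falling back to `inputs_embeds.device`) before returning. A minimal CPU-only sketch of that pattern, using placeholder tensors rather than a real model; with `device_map='auto'` the output head can sit on a different GPU than the inputs, which is when the `.to(device)` matters:

```python
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# Placeholder tensors standing in for a real forward() pass; in a sharded model
# `logits` could live on a different device than `input_ids`.
input_ids = torch.randint(0, 100, (1, 8))
logits = torch.randn(1, 8, 100)

device = input_ids.device  # the commit falls back to inputs_embeds.device when input_ids is None
output = CausalLMOutputWithPast(loss=None, logits=logits)
output['logits'] = output['logits'].to(device)  # same move the commit adds before returning
assert output['logits'].device == input_ids.device
```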
modeling_internvl_chat.py CHANGED
@@ -313,8 +313,8 @@ class InternVLChatModel(PreTrainedModel):
         if return_history:
             return response, history
         else:
-            # query_to_print = query.replace(image_tokens, '<image>')
-            # print(query_to_print, response)
+            query_to_print = query.replace(image_tokens, '<image>')
+            print(query_to_print, response)
             return response
         return response

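
The `modeling_internvl_chat.py` change re-enables a debug print in `chat()` that shows the query with the long run of image placeholder tokens collapsed back to `<image>`. A rough illustration of what that `replace` does, using made-up token strings and counts rather than the model's real configuration:

```python
# Made-up stand-ins for the special tokens and counts assembled by chat();
# the real values come from the model/tokenizer config, not from this diff.
IMG_START_TOKEN, IMG_CONTEXT_TOKEN, IMG_END_TOKEN = '<img>', '<IMG_CONTEXT>', '</img>'
num_image_token, num_patches = 256, 1

image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN
query = image_tokens + '\nDescribe this image in detail.'
response = 'A cat sitting on a windowsill.'

# Same collapse as the uncommented lines in the diff, so the logged query stays readable.
query_to_print = query.replace(image_tokens, '<image>')
print(query_to_print, response)
```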