Spaces:
Runtime error
import torch


def get_output_batch(
    model, tokenizer, prompts, generation_config, device='cuda'
):
    """Generate completions for a batch of prompts and decode them to text.

    Args:
        model: a Hugging Face causal LM exposing ``generate``.
        tokenizer: the matching tokenizer; must support ``__call__`` with
            ``return_tensors="pt"`` and ``batch_decode``.
        prompts: list of prompt strings (length >= 1).
        generation_config: ``GenerationConfig`` forwarded to ``generate``.
        device: device the input tensors are moved to (default ``'cuda'``).

    Returns:
        list[str]: one decoded string per prompt. NOTE: the decoded text
        still includes the prompt — the original code passed
        ``skip_prompt=True`` to ``batch_decode``, but that parameter belongs
        to ``TextStreamer``, not ``batch_decode``, and was silently ignored.
    """
    # Pad only when batching more than one prompt: padding is unnecessary
    # for a single sequence, and some tokenizers have no pad token and would
    # raise if padding were requested.
    encodings = tokenizer(
        prompts, padding=len(prompts) > 1, return_tensors="pt"
    ).to(device)
    # Pass the whole encoding so attention_mask reaches generate() — the
    # original single-prompt branch forwarded only input_ids.
    generated_ids = model.generate(
        **encodings,
        generation_config=generation_config,
    )
    decoded = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )
    # Drop references and release cached GPU memory before returning
    # (empty_cache is a no-op when CUDA was never initialized).
    del encodings, generated_ids
    torch.cuda.empty_cache()
    return decoded