|
import uuid
from typing import Any, Dict, List, Union

from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline


if is_tf_available():
    import tensorflow as tf

if is_torch_available():
    import torch


logger = logging.get_logger(__name__)

class Conversation:
    """
    Utility class containing a conversation and its history. This class is meant to be used as an input to the
    [`ConversationalPipeline`]. The conversation contains several utility functions to manage the addition of new user
    inputs and generated model responses.

    Arguments:
        messages (Union[str, List[Dict[str, str]]], *optional*):
            The initial messages to start the conversation, either a string, or a list of dicts containing "role" and
            "content" keys. If a string is passed, it is interpreted as a single message with the "user" role.
        conversation_id (`uuid.UUID`, *optional*):
            Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
            conversation.

    Usage:

    ```python
    conversation = Conversation("Going to the movies tonight - any suggestions?")
    conversation.add_message({"role": "assistant", "content": "The Big Lebowski."})
    conversation.add_message({"role": "user", "content": "Is it good?"})
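    # The deprecated `past_user_inputs` / `generated_responses` kwargs (handled in `__init__`)
    # build an equivalent interleaved history; `legacy` is an illustrative name:
    legacy = Conversation(
        past_user_inputs=["Going to the movies tonight - any suggestions?"],
        generated_responses=["The Big Lebowski."],
    )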
|
    ```"""

    def __init__(
        self, messages: Union[str, List[Dict[str, str]]] = None, conversation_id: uuid.UUID = None, **deprecated_kwargs
    ):
        if not conversation_id:
            conversation_id = uuid.uuid4()

        if messages is None:
            text = deprecated_kwargs.pop("text", None)
            if text is not None:
                messages = [{"role": "user", "content": text}]
            else:
                messages = []
        elif isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # Rebuild the message history from the deprecated `past_user_inputs` /
        # `generated_responses` kwargs by interleaving them, oldest first.
        generated_responses = deprecated_kwargs.pop("generated_responses", None)
        past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
        if generated_responses is not None and past_user_inputs is None:
            raise ValueError("generated_responses cannot be passed without past_user_inputs!")
        if past_user_inputs is not None:
            legacy_messages = []
            if generated_responses is None:
                generated_responses = []

            # Iterate by index rather than with zip() because the two lists may differ in length by one
            for i in range(max(len(past_user_inputs), len(generated_responses))):
                if i < len(past_user_inputs):
                    legacy_messages.append({"role": "user", "content": past_user_inputs[i]})
                if i < len(generated_responses):
                    legacy_messages.append({"role": "assistant", "content": generated_responses[i]})
            messages = legacy_messages + messages

        self.uuid = conversation_id
        self.messages = messages

    def __eq__(self, other):
        if not isinstance(other, Conversation):
            return False
        return self.uuid == other.uuid or self.messages == other.messages

    def add_message(self, message: Dict[str, str]):
        if not set(message.keys()) == {"role", "content"}:
            raise ValueError("Message should contain only 'role' and 'content' keys!")
        if message["role"] not in ("user", "assistant", "system"):
            raise ValueError("Only 'user', 'assistant' and 'system' roles are supported for now!")
        self.messages.append(message)

    def add_user_input(self, text: str, overwrite: bool = False):
        """
        Add a user input to the conversation for the next round. This is a legacy method that assumes that inputs must
        alternate user/assistant/user/assistant, and so will not add multiple user messages in succession. We recommend
        just using `add_message` with role "user" instead.
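
        A minimal sketch of the legacy behavior (variable names are illustrative):

        ```python
        conversation = Conversation()
        conversation.add_user_input("Hi there!")
        # A second consecutive user input is ignored unless `overwrite=True`:
        conversation.add_user_input("Hello instead!", overwrite=True)
        ```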
|
""" |
|
        if len(self) > 0 and self[-1]["role"] == "user":
            if overwrite:
                logger.warning(
                    f'New user input added while an unprocessed input already existed: "{self[-1]["content"]}" was '
                    f'overwritten with: "{text}".'
                )
                self[-1]["content"] = text
            else:
                logger.warning(
                    f'New user input added while an unprocessed input already existed: "{self[-1]["content"]}". The '
                    f'new input "{text}" was ignored. Set `overwrite=True` to overwrite the unprocessed user input.'
                )
        else:
            self.messages.append({"role": "user", "content": text})

    def append_response(self, response: str):
        """
        This is a legacy method. We recommend just using `add_message` with an appropriate role instead.
        """
        self.messages.append({"role": "assistant", "content": response})

    def mark_processed(self):
        """
        This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
        processed and unprocessed user input.
        """
        pass

    def __iter__(self):
        for message in self.messages:
            yield message

    def __getitem__(self, item):
        return self.messages[item]

    def __setitem__(self, key, value):
        self.messages[key] = value

    def __len__(self):
        return len(self.messages)

    def __repr__(self):
        """
        Generates a string representation of the conversation.

        Returns:
            `str`: The conversation id followed by one `role: content` line per message.

        Example:
            Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
            user: Going to the movies tonight - any suggestions?
            assistant: The Big Lebowski
        """
        output = f"Conversation id: {self.uuid}\n"
        for message in self.messages:
            output += f"{message['role']}: {message['content']}\n"
        return output

    def iter_texts(self):
        """
        This is a legacy method. Iterates over the messages as `(is_user, content)` tuples. We recommend just
        iterating over `self.messages` directly instead.
        """
        for message in self.messages:
            yield message["role"] == "user", message["content"]

    @property
    def past_user_inputs(self):
        """
        This is a legacy property. Returns the contents of all "user" messages, in order. We recommend just
        accessing `self.messages` directly instead.
        """
        return [message["content"] for message in self.messages if message["role"] == "user"]

    @property
    def generated_responses(self):
        """
        This is a legacy property. Returns the contents of all "assistant" messages, in order. We recommend just
        accessing `self.messages` directly instead.
        """
        return [message["content"] for message in self.messages if message["role"] == "assistant"]


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        min_length_for_response (`int`, *optional*, defaults to 32):
            The minimum length (in number of tokens) for a response.
        minimum_tokens (`int`, *optional*, defaults to 10):
            The minimum number of tokens to leave for a response.
    """,
)
class ConversationalPipeline(Pipeline):
    """
    Multi-turn conversational pipeline.

    Example:

    ```python
    >>> from transformers import pipeline, Conversation

    >>> chatbot = pipeline(model="microsoft/DialoGPT-medium")
    >>> conversation = Conversation("Going to the movies tonight - any suggestions?")
    >>> conversation = chatbot(conversation)
    >>> conversation.generated_responses[-1]
    'The Big Lebowski'
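
    >>> # `minimum_tokens` and `max_length` are handled by the pipeline itself; any other
    >>> # generation kwargs are forwarded to `generate` (values here are illustrative):
    >>> conversation = chatbot(conversation, minimum_tokens=20, max_length=1000)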
|
|
|
    >>> conversation.add_user_input("Is it an action movie?")
    >>> conversation = chatbot(conversation)
    >>> conversation.generated_responses[-1]
    "It's a comedy."
|
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This conversational pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"conversational"`.

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
    currently: *'microsoft/DialoGPT-small'*, *'microsoft/DialoGPT-medium'*, *'microsoft/DialoGPT-large'*. See the
    up-to-date list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=conversational).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Many conversational models have no pad token; fall back to EOS so that batching and generation work.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _sanitize_parameters(
        self, min_length_for_response=None, minimum_tokens=None, clean_up_tokenization_spaces=None, **generate_kwargs
    ):
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}

        if min_length_for_response is not None:
            preprocess_params["min_length_for_response"] = min_length_for_response
        if minimum_tokens is not None:
            forward_params["minimum_tokens"] = minimum_tokens

        if "max_length" in generate_kwargs:
            forward_params["max_length"] = generate_kwargs["max_length"]

        if clean_up_tokenization_spaces is not None:
            postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

        if generate_kwargs:
            forward_params.update(generate_kwargs)
        return preprocess_params, forward_params, postprocess_params

    def __call__(self, conversations: Union[Conversation, List[Conversation]], num_workers=0, **kwargs):
        r"""
        Generate responses for the conversation(s) given as inputs.

        Args:
            conversations (a [`Conversation`] or a list of [`Conversation`]):
                Conversations to generate responses for.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework [here](./model#generative-models)).

        Returns:
            [`Conversation`] or a list of [`Conversation`]: Conversation(s) with updated generated responses for those
            containing a new user input.
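
        Example (a sketch; assumes a `chatbot` conversational pipeline as in the class docstring):

        ```python
        conversations = [Conversation("Hi there!"), Conversation("Any movie suggestions?")]
        conversations = chatbot(conversations)  # a list in, a list of updated Conversations out
        ```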
|
""" |
|
|
|
|
|
|
|
|
|
        # num_workers defaults to 0 so that Conversation objects are mutated in the calling thread
        # rather than being copied into worker threads.
        outputs = super().__call__(conversations, num_workers=num_workers, **kwargs)
        if isinstance(outputs, list) and len(outputs) == 1:
            return outputs[0]
        return outputs

    def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]:
        # Render the conversation with the tokenizer's chat template, appending the generation prompt.
        input_ids = self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True)

        if self.framework == "pt":
            input_ids = torch.LongTensor([input_ids])
        elif self.framework == "tf":
            input_ids = tf.constant([input_ids])
        return {"input_ids": input_ids, "conversation": conversation}

    def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
        max_length = generate_kwargs.get("max_length", self.model.config.max_length)

        n = model_inputs["input_ids"].shape[1]
        if max_length - minimum_tokens < n:
            logger.warning(
                f"Conversation input is too long ({n} tokens), trimming it to {max_length - minimum_tokens} tokens. "
                f"Consider increasing `max_length` to avoid truncation."
            )
            # Keep only the most recent tokens, leaving room for at least `minimum_tokens` of generation.
            # For example, with max_length=1000 and minimum_tokens=10, a 1200-token input keeps its last 990 tokens.
            trim = max_length - minimum_tokens
            model_inputs["input_ids"] = model_inputs["input_ids"][:, -trim:]
            if "attention_mask" in model_inputs:
                model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -trim:]
        conversation = model_inputs.pop("conversation")
        generate_kwargs["max_length"] = max_length
        output_ids = self.model.generate(**model_inputs, **generate_kwargs)
        # Encoder-decoder models return only the newly generated tokens (after the decoder start token);
        # decoder-only models return the prompt followed by the new tokens, so slice the prompt off.
        if self.model.config.is_encoder_decoder:
            start_position = 1
        else:
            start_position = n
        return {"output_ids": output_ids[:, start_position:], "conversation": conversation}

    def postprocess(self, model_outputs, clean_up_tokenization_spaces=True):
        output_ids = model_outputs["output_ids"]
        answer = self.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        conversation = model_outputs["conversation"]
        conversation.add_message({"role": "assistant", "content": answer})
        return conversation