import json
import requests

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator

from text_generation.types import (
    StreamResponse,
    Response,
    Request,
    Parameters,
)
from text_generation.errors import parse_error


class Client:
    """Client to make calls to a text-generation-inference instance

    Example:

    ```python
    >>> from text_generation import Client

    >>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
    >>> client.generate("Why is the sky blue?").generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> for response in client.generate_stream("Why is the sky blue?"):
    ...     if not response.token.special:
    ...         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        """
        Args:
            base_url (`str`):
                text-generation-inference instance base url
            headers (`Optional[Dict[str, str]]`):
                Additional headers
            cookies (`Optional[Dict[str, str]]`):
                Cookies to include in the requests
            timeout (`int`):
                Timeout in seconds
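
        Example:
            A minimal construction sketch; the URL and the bearer token below are
            illustrative assumptions, not values shipped with the library:

            ```python
            >>> from text_generation import Client

            >>> client = Client(
            ...     "http://127.0.0.1:8080",
            ...     headers={"Authorization": "Bearer <your-token>"},
            ...     timeout=30,
            ... )
            ```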
""" | |
self.base_url = base_url | |
self.headers = headers | |
self.cookies = cookies | |
self.timeout = timeout | |

    def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        """
        Given a prompt, generate the following text

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate `best_of` sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass.
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:
            Response: generated response
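
        Example:
            A minimal usage sketch; the endpoint URL is an illustrative assumption and
            requires a running text-generation-inference server:

            ```python
            >>> from text_generation import Client

            >>> client = Client("http://127.0.0.1:8080", timeout=30)
            >>> response = client.generate(
            ...     "Why is the sky blue?",
            ...     max_new_tokens=32,
            ...     do_sample=True,
            ...     temperature=0.7,
            ...     top_p=0.95,
            ...     seed=42,
            ... )
            >>> response.generated_text
            ```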
""" | |
# Validate parameters | |
parameters = Parameters( | |
best_of=best_of, | |
details=True, | |
do_sample=do_sample, | |
max_new_tokens=max_new_tokens, | |
repetition_penalty=repetition_penalty, | |
return_full_text=return_full_text, | |
seed=seed, | |
stop=stop_sequences if stop_sequences is not None else [], | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
truncate=truncate, | |
typical_p=typical_p, | |
watermark=watermark, | |
) | |
request = Request(inputs=prompt, stream=False, parameters=parameters) | |
resp = requests.post( | |
self.base_url, | |
json=request.dict(), | |
headers=self.headers, | |
cookies=self.cookies, | |
timeout=self.timeout, | |
) | |
payload = resp.json() | |
if resp.status_code != 200: | |
raise parse_error(resp.status_code, payload) | |
return Response(**payload[0]) | |

    def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Iterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass.
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:
            Iterator[StreamResponse]: stream of generated tokens
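
        Example:
            A minimal usage sketch; the endpoint URL is an illustrative assumption and
            requires a running text-generation-inference server:

            ```python
            >>> from text_generation import Client

            >>> client = Client("http://127.0.0.1:8080")
            >>> for response in client.generate_stream("Why is the sky blue?", max_new_tokens=32):
            ...     if not response.token.special:
            ...         print(response.token.text, end="", flush=True)
            ```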
""" | |
# Validate parameters | |
parameters = Parameters( | |
best_of=None, | |
details=True, | |
do_sample=do_sample, | |
max_new_tokens=max_new_tokens, | |
repetition_penalty=repetition_penalty, | |
return_full_text=return_full_text, | |
seed=seed, | |
stop=stop_sequences if stop_sequences is not None else [], | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
truncate=truncate, | |
typical_p=typical_p, | |
watermark=watermark, | |
) | |
request = Request(inputs=prompt, stream=True, parameters=parameters) | |
resp = requests.post( | |
self.base_url, | |
json=request.dict(), | |
headers=self.headers, | |
cookies=self.cookies, | |
timeout=self.timeout, | |
stream=True, | |
) | |
if resp.status_code != 200: | |
raise parse_error(resp.status_code, resp.json()) | |
# Parse ServerSentEvents | |
for byte_payload in resp.iter_lines(): | |
# Skip line | |
if byte_payload == b"\n": | |
continue | |
payload = byte_payload.decode("utf-8") | |
# Event data | |
if payload.startswith("data:"): | |
# Decode payload | |
json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) | |
# Parse payload | |
try: | |
response = StreamResponse(**json_payload) | |
except ValidationError: | |
# If we failed to parse the payload, then it is an error payload | |
raise parse_error(resp.status_code, json_payload) | |
yield response | |


class AsyncClient:
    """Asynchronous Client to make calls to a text-generation-inference instance

    Example:

    ```python
    >>> from text_generation import AsyncClient

    >>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
    >>> response = await client.generate("Why is the sky blue?")
    >>> response.generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> async for response in client.generate_stream("Why is the sky blue?"):
    ...     if not response.token.special:
    ...         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        """
        Args:
            base_url (`str`):
                text-generation-inference instance base url
            headers (`Optional[Dict[str, str]]`):
                Additional headers
            cookies (`Optional[Dict[str, str]]`):
                Cookies to include in the requests
            timeout (`int`):
                Timeout in seconds
        """
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
        # Total request timeout in seconds, matching the docstring above and the synchronous client
        self.timeout = ClientTimeout(timeout)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate `best_of` sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass.
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:
            Response: generated response
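
        Example:
            A minimal usage sketch; the endpoint URL is an illustrative assumption, a
            text-generation-inference server must be reachable there, and the coroutine
            must be awaited from a running event loop:

            ```python
            >>> from text_generation import AsyncClient

            >>> client = AsyncClient("http://127.0.0.1:8080")
            >>> response = await client.generate("Why is the sky blue?", max_new_tokens=32)
            >>> response.generated_text
            ```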
""" | |
# Validate parameters | |
parameters = Parameters( | |
best_of=best_of, | |
details=True, | |
do_sample=do_sample, | |
max_new_tokens=max_new_tokens, | |
repetition_penalty=repetition_penalty, | |
return_full_text=return_full_text, | |
seed=seed, | |
stop=stop_sequences if stop_sequences is not None else [], | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
truncate=truncate, | |
typical_p=typical_p, | |
watermark=watermark, | |
) | |
request = Request(inputs=prompt, stream=False, parameters=parameters) | |
async with ClientSession( | |
headers=self.headers, cookies=self.cookies, timeout=self.timeout | |
) as session: | |
async with session.post(self.base_url, json=request.dict()) as resp: | |
payload = await resp.json() | |
if resp.status != 200: | |
raise parse_error(resp.status, payload) | |
return Response(**payload[0]) | |

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass.
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:
            AsyncIterator[StreamResponse]: stream of generated tokens
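
        Example:
            A minimal usage sketch; the endpoint URL is an illustrative assumption and the
            `async for` loop must run inside a coroutine:

            ```python
            >>> from text_generation import AsyncClient

            >>> client = AsyncClient("http://127.0.0.1:8080")
            >>> text = ""
            >>> async for response in client.generate_stream("Why is the sky blue?", max_new_tokens=32):
            ...     if not response.token.special:
            ...         text += response.token.text
            >>> text
            ```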
""" | |
# Validate parameters | |
parameters = Parameters( | |
best_of=None, | |
details=True, | |
do_sample=do_sample, | |
max_new_tokens=max_new_tokens, | |
repetition_penalty=repetition_penalty, | |
return_full_text=return_full_text, | |
seed=seed, | |
stop=stop_sequences if stop_sequences is not None else [], | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
truncate=truncate, | |
typical_p=typical_p, | |
watermark=watermark, | |
) | |
request = Request(inputs=prompt, stream=True, parameters=parameters) | |
async with ClientSession( | |
headers=self.headers, cookies=self.cookies, timeout=self.timeout | |
) as session: | |
async with session.post(self.base_url, json=request.dict()) as resp: | |
if resp.status != 200: | |
raise parse_error(resp.status, await resp.json()) | |
# Parse ServerSentEvents | |
async for byte_payload in resp.content: | |
# Skip line | |
if byte_payload == b"\n": | |
continue | |
payload = byte_payload.decode("utf-8") | |
# Event data | |
if payload.startswith("data:"): | |
# Decode payload | |
json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) | |
# Parse payload | |
try: | |
response = StreamResponse(**json_payload) | |
except ValidationError: | |
# If we failed to parse the payload, then it is an error payload | |
raise parse_error(resp.status, json_payload) | |
yield response | |
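

# A self-contained usage sketch, not part of the library API: it assumes a
# text-generation-inference server is reachable at http://127.0.0.1:8080 (the URL is an
# illustrative assumption) and streams generated tokens to stdout as they arrive.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = AsyncClient("http://127.0.0.1:8080")
        async for stream_response in client.generate_stream(
            "Why is the sky blue?", max_new_tokens=32
        ):
            if not stream_response.token.special:
                print(stream_response.token.text, end="", flush=True)
        print()

    asyncio.run(_demo())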