feat: use the cache payload and overwrite the text_generation func
app.py CHANGED
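In short, this commit subclasses `InferenceClient` and overrides `text_generation` so that every request payload is sent with `options.use_cache = False`, asking the Inference API to recompute the generation instead of replaying a cached response for a repeated prompt. A minimal sketch of the idea, trimmed to the relevant part and assuming the `huggingface_hub` version used by this Space still ships the private `_text_generation` helpers that app.py imports (the full override in the diff below keeps the complete upstream signature, validation, and error handling):

```python
from dataclasses import asdict

from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import (
    TextGenerationParameters,
    TextGenerationRequest,
)


class InferenceClientUS(InferenceClient):
    def text_generation(self, prompt: str, *, max_new_tokens: int = 20, **kwargs):
        # Build the same payload the parent implementation would build...
        parameters = TextGenerationParameters(max_new_tokens=max_new_tokens, details=False)
        payload = asdict(TextGenerationRequest(inputs=prompt, parameters=parameters))
        # ...then disable response caching before posting the request.
        payload.setdefault("options", {})["use_cache"] = False
        raw = self.post(json=payload, task="text-generation")
        ...  # parse `raw` as in the full method below
```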
@@ -257,6 +257,279 @@ class InferenceClientUS(InferenceClient):
                     continue
                 raise

+    def text_generation(
+        self,
+        prompt: str,
+        *,
+        details: bool = False,
+        stream: bool = False,
+        model: Optional[str] = None,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop_sequences: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: bool = False,
+        decoder_input_details: bool = False,
+    ) -> Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]:
+        """
+        Given a prompt, generate the following text.
+
+        It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it
+        allows failing early.
+
+        The API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
+        go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
+        default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
+        not exactly the same. This method is compatible with both approaches but some parameters are only available for
+        `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
+        continues correctly.
+
+        To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
+
+        Args:
+            prompt (`str`):
+                Input text.
+            details (`bool`, *optional*):
+                By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens,
+                probabilities, seed, finish reason, etc.). Only available for models running with the
+                `text-generation-inference` backend.
+            stream (`bool`, *optional*):
+                By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of
+                tokens to be returned. Only available for models running with the `text-generation-inference`
+                backend.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            do_sample (`bool`):
+                Activate logits sampling
+            max_new_tokens (`int`):
+                Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            return_full_text (`bool`):
+                Whether to prepend the prompt to the generated text
+            seed (`int`):
+                Random sampling seed
+            stop_sequences (`List[str]`):
+                Stop generating tokens if a member of `stop_sequences` is generated
+            temperature (`float`):
+                The value used to modulate the logits distribution.
+            top_k (`int`):
+                The number of highest probability vocabulary tokens to keep for top-k-filtering.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
+            watermark (`bool`):
+                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
+                into account. Defaults to `False`.
+
+        Returns:
+            `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
+            Generated text returned from the server:
+            - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
+            - if `stream=True` and `details=False`, the generated text is returned token by token as an `Iterable[str]`
+            - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
+            - if `details=True` and `stream=True`, the generated text is returned token by token as an iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
+
+        Raises:
+            `ValidationError`:
+                If input values are not valid. No HTTP call is made to the server.
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+
+        # Case 1: generate text
+        >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
+        '100% open source and built to be easy to use.'
+
+        # Case 2: iterate over the generated tokens. Useful for large generation.
+        >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
+        ...     print(token)
+        100
+        %
+        open
+        source
+        and
+        built
+        to
+        be
+        easy
+        to
+        use
+        .
+
+        # Case 3: get more details about the generation process.
+        >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
+        TextGenerationResponse(
+            generated_text='100% open source and built to be easy to use.',
+            details=Details(
+                finish_reason=<FinishReason.Length: 'length'>,
+                generated_tokens=12,
+                seed=None,
+                prefill=[
+                    InputToken(id=487, text='The', logprob=None),
+                    InputToken(id=53789, text=' hugging', logprob=-13.171875),
+                    (...)
+                    InputToken(id=204, text=' ', logprob=-7.0390625)
+                ],
+                tokens=[
+                    Token(id=1425, text='100', logprob=-1.0175781, special=False),
+                    Token(id=16, text='%', logprob=-0.0463562, special=False),
+                    (...)
+                    Token(id=25, text='.', logprob=-0.5703125, special=False)
+                ],
+                best_of_sequences=None
+            )
+        )
+
+        # Case 4: iterate over the generated tokens with more details.
+        # Last object is more complete, containing the full generated text and the finish reason.
+        >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
+        ...     print(details)
+        ...
+        TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(
+            id=25,
+            text='.',
+            logprob=-0.5703125,
+            special=False),
+            generated_text='100% open source and built to be easy to use.',
+            details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
+        )
+        ```
+        """
+        # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
+        # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
+
+        if decoder_input_details and not details:
+            warnings.warn(
+                "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
+                " the output from the server will be truncated."
+            )
+            decoder_input_details = False
+
+        # Validate parameters
+        parameters = TextGenerationParameters(
+            best_of=best_of,
+            details=details,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop_sequences if stop_sequences is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
+            watermark=watermark,
+            decoder_input_details=decoder_input_details,
+        )
+        request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
+        payload = asdict(request)
+
+        # Add the use_cache option. The request dataclass has no "options" field,
+        # so create the dict before setting the flag.
+        print(f"payload:{payload}")
+        payload.setdefault("options", {})["use_cache"] = False
+
+        # Remove some parameters if not a TGI server
+        if not _is_tgi_server(model):
+            ignored_parameters = []
+            for key in "watermark", "stop", "details", "decoder_input_details":
+                if payload["parameters"][key] is not None:
+                    ignored_parameters.append(key)
+                del payload["parameters"][key]
+            if len(ignored_parameters) > 0:
+                warnings.warn(
+                    "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
+                    f" {ignored_parameters}.",
+                    UserWarning,
+                )
+            if details:
+                warnings.warn(
+                    "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will"
+                    " be ignored meaning only the generated text will be returned.",
+                    UserWarning,
+                )
+                details = False
+            if stream:
+                raise ValueError(
+                    "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream."
+                    " Please pass `stream=False` as input."
+                )
+
+        # Handle errors separately for more precise error messages
+        try:
+            bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream)  # type: ignore
+        except HTTPError as e:
+            if isinstance(e, BadRequestError) and "The following `model_kwargs` are not used by the model" in str(e):
+                _set_as_non_tgi(model)
+                return self.text_generation(  # type: ignore
+                    prompt=prompt,
+                    details=details,
+                    stream=stream,
+                    model=model,
+                    do_sample=do_sample,
+                    max_new_tokens=max_new_tokens,
+                    best_of=best_of,
+                    repetition_penalty=repetition_penalty,
+                    return_full_text=return_full_text,
+                    seed=seed,
+                    stop_sequences=stop_sequences,
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    truncate=truncate,
+                    typical_p=typical_p,
+                    watermark=watermark,
+                    decoder_input_details=decoder_input_details,
+                )
+            raise_text_generation_error(e)
+
+        # Parse output
+        if stream:
+            return _stream_text_generation_response(bytes_output, details)  # type: ignore
+
+        data = _bytes_to_dict(bytes_output)[0]
+        return TextGenerationResponse(**data) if details else data["generated_text"]
+
+
 client = InferenceClientUS(
     API_URL,
     headers={"Authorization": f"Bearer {HF_TOKEN}"},