vpcom committed on
Commit 823beb2
1 Parent(s): 318cc53

feat: use the cache payload and overwrite the text_generation func
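
This commit subclasses the stock `InferenceClient` and re-implements `text_generation` so that the serialized request carries the Inference API `use_cache` option. A minimal sketch of the payload shape the override ends up posting (the values are illustrative, not taken from the diff; only the `options` entry is what this commit adds):

```python
# Rough shape of the JSON body sent by the overridden text_generation.
# Field values are examples; "options" is the part introduced by this commit.
payload = {
    "inputs": "The huggingface_hub library is ",
    "stream": False,
    "parameters": {"max_new_tokens": 12, "do_sample": False},  # plus the other TGI parameters
    "options": {"use_cache": False},  # ask the Inference API to recompute instead of serving a cached result
}
```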

Files changed (1)
1. app.py +273 -0
app.py CHANGED
@@ -257,6 +257,279 @@ class InferenceClientUS(InferenceClient):
   continue
   raise
 
+    def text_generation(
+        self,
+        prompt: str,
+        *,
+        details: bool = False,
+        stream: bool = False,
+        model: Optional[str] = None,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop_sequences: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: bool = False,
+        decoder_input_details: bool = False,
+    ) -> Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]:
+        """
+        Given a prompt, generate the following text.
+
+        It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allows
+        early failures.
+
+        The API endpoint is expected to run with the `text-generation-inference` backend (TGI). This backend is the
+        go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
+        default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
+        not exactly the same. This method is compatible with both approaches but some parameters are only available for
+        `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
+        continues correctly.
+
+        To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
+
+        Args:
+            prompt (`str`):
+                Input text.
+            details (`bool`, *optional*):
+                By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens,
+                probabilities, seed, finish reason, etc.). Only available for models running on the
+                `text-generation-inference` backend.
+            stream (`bool`, *optional*):
+                By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of
+                tokens to be returned. Only available for models running on the `text-generation-inference`
+                backend.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            do_sample (`bool`):
+                Activate logits sampling.
+            max_new_tokens (`int`):
+                Maximum number of generated tokens.
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs.
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            return_full_text (`bool`):
+                Whether to prepend the prompt to the generated text.
+            seed (`int`):
+                Random sampling seed.
+            stop_sequences (`List[str]`):
+                Stop generating tokens if a member of `stop_sequences` is generated.
+            temperature (`float`):
+                The value used to modulate the logits distribution.
+            top_k (`int`):
+                The number of highest probability vocabulary tokens to keep for top-k-filtering.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size.
+            typical_p (`float`):
+                Typical Decoding mass.
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
+            watermark (`bool`):
+                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
+                into account. Defaults to `False`.
+
+        Returns:
+            `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
+            Generated text returned from the server:
+            - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
+            - if `stream=True` and `details=False`, the generated text is returned token by token as an `Iterable[str]`
+            - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
+            - if `details=True` and `stream=True`, the generated text is returned token by token as an iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
+
+        Raises:
+            `ValidationError`:
+                If input values are not valid. No HTTP call is made to the server.
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+
+        # Case 1: generate text
+        >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
+        '100% open source and built to be easy to use.'
+
+        # Case 2: iterate over the generated tokens. Useful for large generation.
+        >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
+        ...     print(token)
+        100
+        %
+        open
+        source
+        and
+        built
+        to
+        be
+        easy
+        to
+        use
+        .
+
+        # Case 3: get more details about the generation process.
+        >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
+        TextGenerationResponse(
+            generated_text='100% open source and built to be easy to use.',
+            details=Details(
+                finish_reason=<FinishReason.Length: 'length'>,
+                generated_tokens=12,
+                seed=None,
+                prefill=[
+                    InputToken(id=487, text='The', logprob=None),
+                    InputToken(id=53789, text=' hugging', logprob=-13.171875),
+                    (...)
+                    InputToken(id=204, text=' ', logprob=-7.0390625)
+                ],
+                tokens=[
+                    Token(id=1425, text='100', logprob=-1.0175781, special=False),
+                    Token(id=16, text='%', logprob=-0.0463562, special=False),
+                    (...)
+                    Token(id=25, text='.', logprob=-0.5703125, special=False)
+                ],
+                best_of_sequences=None
+            )
+        )
+
+        # Case 4: iterate over the generated tokens with more details.
+        # Last object is more complete, containing the full generated text and the finish reason.
+        >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
+        ...     print(details)
+        ...
+        TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(
+            id=25,
+            text='.',
+            logprob=-0.5703125,
+            special=False),
+            generated_text='100% open source and built to be easy to use.',
+            details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
+        )
+        ```
+        """
+        # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
+        # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
+
+        if decoder_input_details and not details:
+            warnings.warn(
+                "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
+                " the output from the server will be truncated."
+            )
+            decoder_input_details = False
+
+        # Validate parameters
+        parameters = TextGenerationParameters(
+            best_of=best_of,
+            details=details,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop_sequences if stop_sequences is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
+            watermark=watermark,
+            decoder_input_details=decoder_input_details,
+        )
+        request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
+        payload = asdict(request)
+
+        # Add the use_cache option so the Inference API recomputes instead of serving a cached generation.
+        # `asdict(request)` does not produce an "options" key, so create it before setting the flag.
+        print(f"payload:{payload}")
+        payload.setdefault("options", {})["use_cache"] = False
+
+        # Remove some parameters if not a TGI server
+        if not _is_tgi_server(model):
+            ignored_parameters = []
+            for key in "watermark", "stop", "details", "decoder_input_details":
+                if payload["parameters"][key] is not None:
+                    ignored_parameters.append(key)
+                    del payload["parameters"][key]
+            if len(ignored_parameters) > 0:
+                warnings.warn(
+                    "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
+                    f" {ignored_parameters}.",
+                    UserWarning,
+                )
+            if details:
+                warnings.warn(
+                    "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will"
+                    " be ignored meaning only the generated text will be returned.",
+                    UserWarning,
+                )
+                details = False
+            if stream:
+                raise ValueError(
+                    "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream."
+                    " Please pass `stream=False` as input."
+                )
+
+        # Handle errors separately for more precise error messages
+        try:
+            bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream)  # type: ignore
+        except HTTPError as e:
+            if isinstance(e, BadRequestError) and "The following `model_kwargs` are not used by the model" in str(e):
+                _set_as_non_tgi(model)
+                return self.text_generation(  # type: ignore
+                    prompt=prompt,
+                    details=details,
+                    stream=stream,
+                    model=model,
+                    do_sample=do_sample,
+                    max_new_tokens=max_new_tokens,
+                    best_of=best_of,
+                    repetition_penalty=repetition_penalty,
+                    return_full_text=return_full_text,
+                    seed=seed,
+                    stop_sequences=stop_sequences,
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    truncate=truncate,
+                    typical_p=typical_p,
+                    watermark=watermark,
+                    decoder_input_details=decoder_input_details,
+                )
+            raise_text_generation_error(e)
+
+        # Parse output
+        if stream:
+            return _stream_text_generation_response(bytes_output, details)  # type: ignore
+
+        data = _bytes_to_dict(bytes_output)[0]
+        return TextGenerationResponse(**data) if details else data["generated_text"]
+
+
 client = InferenceClientUS(
     API_URL,
     headers={"Authorization": f"Bearer {HF_TOKEN}"},
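
For context, a short usage sketch of the patched client. It assumes `API_URL` and `HF_TOKEN` are defined earlier in app.py, as the surrounding context lines suggest; the call signature is unchanged from the upstream huggingface_hub API, only the caching behaviour differs:

```python
# Usage sketch (assumes API_URL and HF_TOKEN exist earlier in app.py).
client = InferenceClientUS(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

# Because the override injects options.use_cache=False into every request payload,
# repeated calls with the same prompt are re-generated by the backend rather than
# answered from the Inference API cache.
text = client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
print(text)
```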