diff --git a/.gitattributes b/.gitattributes index 108356b2c92c3c34c0f6b0a67c9642ed5dd14a40..e07fff0b70a54b43be13dd6257d698c52fea4884 100644 --- a/.gitattributes +++ b/.gitattributes @@ -249,3 +249,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/pycountry/locales/tr/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/pycountry/locales/kn/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/pycparser/ply/__pycache__/yacc.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torchvision.libs/libpng16.7f72a3c5.so.16 filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/attr/__init__.py b/.venv/lib/python3.11/site-packages/attr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6e0650bc4bf53806420d7ef5f881ecd2bd77ea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/attr/__init__.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: MIT + +""" +Classes Without Boilerplate +""" + +from functools import partial +from typing import Callable, Literal, Protocol + +from . import converters, exceptions, filters, setters, validators +from ._cmp import cmp_using +from ._config import get_run_validators, set_run_validators +from ._funcs import asdict, assoc, astuple, has, resolve_types +from ._make import ( + NOTHING, + Attribute, + Converter, + Factory, + _Nothing, + attrib, + attrs, + evolve, + fields, + fields_dict, + make_class, + validate, +) +from ._next_gen import define, field, frozen, mutable +from ._version_info import VersionInfo + + +s = attributes = attrs +ib = attr = attrib +dataclass = partial(attrs, auto_attribs=True) # happy Easter ;) + + +class AttrsInstance(Protocol): + pass + + +NothingType = Literal[_Nothing.NOTHING] + +__all__ = [ + "NOTHING", + "Attribute", + "AttrsInstance", + "Converter", + "Factory", + "NothingType", + "asdict", + "assoc", + "astuple", + "attr", + "attrib", + "attributes", + "attrs", + "cmp_using", + "converters", + "define", + "evolve", + "exceptions", + "field", + "fields", + "fields_dict", + "filters", + "frozen", + "get_run_validators", + "has", + "ib", + "make_class", + "mutable", + "resolve_types", + "s", + "set_run_validators", + "setters", + "validate", + "validators", +] + + +def _make_getattr(mod_name: str) -> Callable: + """ + Create a metadata proxy for packaging information that uses *mod_name* in + its warnings and errors. 
+ """ + + def __getattr__(name: str) -> str: + if name not in ("__version__", "__version_info__"): + msg = f"module {mod_name} has no attribute {name}" + raise AttributeError(msg) + + from importlib.metadata import metadata + + meta = metadata("attrs") + + if name == "__version_info__": + return VersionInfo._from_version_string(meta["version"]) + + return meta["version"] + + return __getattr__ + + +__getattr__ = _make_getattr(__name__) diff --git a/.venv/lib/python3.11/site-packages/attr/__pycache__/_funcs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/attr/__pycache__/_funcs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4874e79260330b4bff1899e0ed74c25a054524cd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/attr/__pycache__/_funcs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/attr/__pycache__/_version_info.cpython-311.pyc b/.venv/lib/python3.11/site-packages/attr/__pycache__/_version_info.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e26f4a0ced971c58c43fe13c6098844915f682 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/attr/__pycache__/_version_info.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/attr/__pycache__/converters.cpython-311.pyc b/.venv/lib/python3.11/site-packages/attr/__pycache__/converters.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2457564025ec88eb1fbf943819813638d073ba89 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/attr/__pycache__/converters.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/attr/__pycache__/setters.cpython-311.pyc b/.venv/lib/python3.11/site-packages/attr/__pycache__/setters.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b54912f5daf573140e3e074d9b0ed6748a9d1967 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/attr/__pycache__/setters.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/attr/_compat.py b/.venv/lib/python3.11/site-packages/attr/_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..22fcd78387b7b36f005ec5eee3fbf784ba87a93d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/attr/_compat.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: MIT + +import inspect +import platform +import sys +import threading + +from collections.abc import Mapping, Sequence # noqa: F401 +from typing import _GenericAlias + + +PYPY = platform.python_implementation() == "PyPy" +PY_3_9_PLUS = sys.version_info[:2] >= (3, 9) +PY_3_10_PLUS = sys.version_info[:2] >= (3, 10) +PY_3_11_PLUS = sys.version_info[:2] >= (3, 11) +PY_3_12_PLUS = sys.version_info[:2] >= (3, 12) +PY_3_13_PLUS = sys.version_info[:2] >= (3, 13) +PY_3_14_PLUS = sys.version_info[:2] >= (3, 14) + + +if PY_3_14_PLUS: # pragma: no cover + import annotationlib + + _get_annotations = annotationlib.get_annotations + +else: + + def _get_annotations(cls): + """ + Get annotations for *cls*. + """ + return cls.__dict__.get("__annotations__", {}) + + +class _AnnotationExtractor: + """ + Extract type annotations from a callable, returning None whenever there + is none. + """ + + __slots__ = ["sig"] + + def __init__(self, callable): + try: + self.sig = inspect.signature(callable) + except (ValueError, TypeError): # inspect failed + self.sig = None + + def get_first_param_type(self): + """ + Return the type annotation of the first argument if it's not empty. 
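For orientation, the aliases exported in attr/__init__.py above keep the classic API (`attr.s`, `attr.ib`) usable next to the modern names, and `__version__` is resolved lazily through the module-level `__getattr__`. A minimal usage sketch, assuming the vendored attrs package is importable:

    import attr

    @attr.s  # alias for attr.attrs
    class Point:
        x = attr.ib(default=0)  # alias for attr.attrib
        y = attr.ib(default=0)

    print(attr.asdict(Point(1, 2)))  # {'x': 1, 'y': 2}
    print(attr.__version__)          # fetched on demand via importlib.metadata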
+ """ + if not self.sig: + return None + + params = list(self.sig.parameters.values()) + if params and params[0].annotation is not inspect.Parameter.empty: + return params[0].annotation + + return None + + def get_return_type(self): + """ + Return the return type if it's not empty. + """ + if ( + self.sig + and self.sig.return_annotation is not inspect.Signature.empty + ): + return self.sig.return_annotation + + return None + + +# Thread-local global to track attrs instances which are already being repr'd. +# This is needed because there is no other (thread-safe) way to pass info +# about the instances that are already being repr'd through the call stack +# in order to ensure we don't perform infinite recursion. +# +# For instance, if an instance contains a dict which contains that instance, +# we need to know that we're already repr'ing the outside instance from within +# the dict's repr() call. +# +# This lives here rather than in _make.py so that the functions in _make.py +# don't have a direct reference to the thread-local in their globals dict. +# If they have such a reference, it breaks cloudpickle. +repr_context = threading.local() + + +def get_generic_base(cl): + """If this is a generic class (A[str]), return the generic base for it.""" + if cl.__class__ is _GenericAlias: + return cl.__origin__ + return None diff --git a/.venv/lib/python3.11/site-packages/attr/_config.py b/.venv/lib/python3.11/site-packages/attr/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..4b257726fb1e8b95583ecc3eee8d153336dc4089 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/attr/_config.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: MIT + +__all__ = ["get_run_validators", "set_run_validators"] + +_run_validators = True + + +def set_run_validators(run): + """ + Set whether or not validators are run. By default, they are run. + + .. deprecated:: 21.3.0 It will not be removed, but it also will not be + moved to new ``attrs`` namespace. Use `attrs.validators.set_disabled()` + instead. + """ + if not isinstance(run, bool): + msg = "'run' must be bool." + raise TypeError(msg) + global _run_validators + _run_validators = run + + +def get_run_validators(): + """ + Return whether or not validators are run. + + .. deprecated:: 21.3.0 It will not be removed, but it also will not be + moved to new ``attrs`` namespace. Use `attrs.validators.get_disabled()` + instead. + """ + return _run_validators diff --git a/.venv/lib/python3.11/site-packages/attr/_funcs.py b/.venv/lib/python3.11/site-packages/attr/_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..c39fb8aa5a9426c18157253aad4b0168084eeb1a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/attr/_funcs.py @@ -0,0 +1,468 @@ +# SPDX-License-Identifier: MIT + + +import copy + +from ._compat import PY_3_9_PLUS, get_generic_base +from ._make import _OBJ_SETATTR, NOTHING, fields +from .exceptions import AttrsAttributeNotFoundError + + +def asdict( + inst, + recurse=True, + filter=None, + dict_factory=dict, + retain_collection_types=False, + value_serializer=None, +): + """ + Return the *attrs* attribute values of *inst* as a dict. + + Optionally recurse into other *attrs*-decorated classes. + + Args: + inst: Instance of an *attrs*-decorated class. + + recurse (bool): Recurse into classes that are also *attrs*-decorated. + + filter (~typing.Callable): + A callable whose return code determines whether an attribute or + element is included (`True`) or dropped (`False`). 
Is called with + the `attrs.Attribute` as the first argument and the value as the + second argument. + + dict_factory (~typing.Callable): + A callable to produce dictionaries from. For example, to produce + ordered dictionaries instead of normal Python dictionaries, pass in + ``collections.OrderedDict``. + + retain_collection_types (bool): + Do not convert to `list` when encountering an attribute whose type + is `tuple` or `set`. Only meaningful if *recurse* is `True`. + + value_serializer (typing.Callable | None): + A hook that is called for every attribute or dict key/value. It + receives the current instance, field and value and must return the + (updated) value. The hook is run *after* the optional *filter* has + been applied. + + Returns: + Return type of *dict_factory*. + + Raises: + attrs.exceptions.NotAnAttrsClassError: + If *cls* is not an *attrs* class. + + .. versionadded:: 16.0.0 *dict_factory* + .. versionadded:: 16.1.0 *retain_collection_types* + .. versionadded:: 20.3.0 *value_serializer* + .. versionadded:: 21.3.0 + If a dict has a collection for a key, it is serialized as a tuple. + """ + attrs = fields(inst.__class__) + rv = dict_factory() + for a in attrs: + v = getattr(inst, a.name) + if filter is not None and not filter(a, v): + continue + + if value_serializer is not None: + v = value_serializer(inst, a, v) + + if recurse is True: + if has(v.__class__): + rv[a.name] = asdict( + v, + recurse=True, + filter=filter, + dict_factory=dict_factory, + retain_collection_types=retain_collection_types, + value_serializer=value_serializer, + ) + elif isinstance(v, (tuple, list, set, frozenset)): + cf = v.__class__ if retain_collection_types is True else list + items = [ + _asdict_anything( + i, + is_key=False, + filter=filter, + dict_factory=dict_factory, + retain_collection_types=retain_collection_types, + value_serializer=value_serializer, + ) + for i in v + ] + try: + rv[a.name] = cf(items) + except TypeError: + if not issubclass(cf, tuple): + raise + # Workaround for TypeError: cf.__new__() missing 1 required + # positional argument (which appears, for a namedturle) + rv[a.name] = cf(*items) + elif isinstance(v, dict): + df = dict_factory + rv[a.name] = df( + ( + _asdict_anything( + kk, + is_key=True, + filter=filter, + dict_factory=df, + retain_collection_types=retain_collection_types, + value_serializer=value_serializer, + ), + _asdict_anything( + vv, + is_key=False, + filter=filter, + dict_factory=df, + retain_collection_types=retain_collection_types, + value_serializer=value_serializer, + ), + ) + for kk, vv in v.items() + ) + else: + rv[a.name] = v + else: + rv[a.name] = v + return rv + + +def _asdict_anything( + val, + is_key, + filter, + dict_factory, + retain_collection_types, + value_serializer, +): + """ + ``asdict`` only works on attrs instances, this works on anything. + """ + if getattr(val.__class__, "__attrs_attrs__", None) is not None: + # Attrs class. 
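To make the *filter* and *value_serializer* hooks documented above concrete, a short sketch (the `User` class is hypothetical):

    import attr

    @attr.s
    class User:
        name = attr.ib()
        password = attr.ib()

    u = User("alice", "hunter2")
    attr.asdict(u, filter=lambda a, v: a.name != "password")
    # -> {'name': 'alice'}
    attr.asdict(u, value_serializer=lambda inst, a, v: v.upper())
    # -> {'name': 'ALICE', 'password': 'HUNTER2'}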
+ rv = asdict( + val, + recurse=True, + filter=filter, + dict_factory=dict_factory, + retain_collection_types=retain_collection_types, + value_serializer=value_serializer, + ) + elif isinstance(val, (tuple, list, set, frozenset)): + if retain_collection_types is True: + cf = val.__class__ + elif is_key: + cf = tuple + else: + cf = list + + rv = cf( + [ + _asdict_anything( + i, + is_key=False, + filter=filter, + dict_factory=dict_factory, + retain_collection_types=retain_collection_types, + value_serializer=value_serializer, + ) + for i in val + ] + ) + elif isinstance(val, dict): + df = dict_factory + rv = df( + ( + _asdict_anything( + kk, + is_key=True, + filter=filter, + dict_factory=df, + retain_collection_types=retain_collection_types, + value_serializer=value_serializer, + ), + _asdict_anything( + vv, + is_key=False, + filter=filter, + dict_factory=df, + retain_collection_types=retain_collection_types, + value_serializer=value_serializer, + ), + ) + for kk, vv in val.items() + ) + else: + rv = val + if value_serializer is not None: + rv = value_serializer(None, None, rv) + + return rv + + +def astuple( + inst, + recurse=True, + filter=None, + tuple_factory=tuple, + retain_collection_types=False, +): + """ + Return the *attrs* attribute values of *inst* as a tuple. + + Optionally recurse into other *attrs*-decorated classes. + + Args: + inst: Instance of an *attrs*-decorated class. + + recurse (bool): + Recurse into classes that are also *attrs*-decorated. + + filter (~typing.Callable): + A callable whose return code determines whether an attribute or + element is included (`True`) or dropped (`False`). Is called with + the `attrs.Attribute` as the first argument and the value as the + second argument. + + tuple_factory (~typing.Callable): + A callable to produce tuples from. For example, to produce lists + instead of tuples. + + retain_collection_types (bool): + Do not convert to `list` or `dict` when encountering an attribute + which type is `tuple`, `dict` or `set`. Only meaningful if + *recurse* is `True`. + + Returns: + Return type of *tuple_factory* + + Raises: + attrs.exceptions.NotAnAttrsClassError: + If *cls* is not an *attrs* class. + + .. versionadded:: 16.2.0 + """ + attrs = fields(inst.__class__) + rv = [] + retain = retain_collection_types # Very long. 
:/
+    for a in attrs:
+        v = getattr(inst, a.name)
+        if filter is not None and not filter(a, v):
+            continue
+        if recurse is True:
+            if has(v.__class__):
+                rv.append(
+                    astuple(
+                        v,
+                        recurse=True,
+                        filter=filter,
+                        tuple_factory=tuple_factory,
+                        retain_collection_types=retain,
+                    )
+                )
+            elif isinstance(v, (tuple, list, set, frozenset)):
+                cf = v.__class__ if retain is True else list
+                items = [
+                    (
+                        astuple(
+                            j,
+                            recurse=True,
+                            filter=filter,
+                            tuple_factory=tuple_factory,
+                            retain_collection_types=retain,
+                        )
+                        if has(j.__class__)
+                        else j
+                    )
+                    for j in v
+                ]
+                try:
+                    rv.append(cf(items))
+                except TypeError:
+                    if not issubclass(cf, tuple):
+                        raise
+                    # Workaround for TypeError: cf.__new__() missing 1 required
+                    # positional argument (which appears, for a namedtuple)
+                    rv.append(cf(*items))
+            elif isinstance(v, dict):
+                df = v.__class__ if retain is True else dict
+                rv.append(
+                    df(
+                        (
+                            (
+                                astuple(
+                                    kk,
+                                    tuple_factory=tuple_factory,
+                                    retain_collection_types=retain,
+                                )
+                                if has(kk.__class__)
+                                else kk
+                            ),
+                            (
+                                astuple(
+                                    vv,
+                                    tuple_factory=tuple_factory,
+                                    retain_collection_types=retain,
+                                )
+                                if has(vv.__class__)
+                                else vv
+                            ),
+                        )
+                        for kk, vv in v.items()
+                    )
+                )
+            else:
+                rv.append(v)
+        else:
+            rv.append(v)
+
+    return rv if tuple_factory is list else tuple_factory(rv)
+
+
+def has(cls):
+    """
+    Check whether *cls* is a class with *attrs* attributes.
+
+    Args:
+        cls (type): Class to introspect.
+
+    Raises:
+        TypeError: If *cls* is not a class.
+
+    Returns:
+        bool:
+    """
+    attrs = getattr(cls, "__attrs_attrs__", None)
+    if attrs is not None:
+        return True
+
+    # No attrs, maybe it's a specialized generic (A[str])?
+    generic_base = get_generic_base(cls)
+    if generic_base is not None:
+        generic_attrs = getattr(generic_base, "__attrs_attrs__", None)
+        if generic_attrs is not None:
+            # Stick it on here for speed next time.
+            cls.__attrs_attrs__ = generic_attrs
+        return generic_attrs is not None
+    return False
+
+
+def assoc(inst, **changes):
+    """
+    Copy *inst* and apply *changes*.
+
+    This is different from `evolve`, which applies the changes to the
+    arguments that create the new instance.
+
+    `evolve`'s behavior is preferable, but there are `edge cases`_ where it
+    doesn't work. Therefore `assoc` is deprecated, but will not be removed.
+
+    .. _`edge cases`: https://github.com/python-attrs/attrs/issues/251
+
+    Args:
+        inst: Instance of a class with *attrs* attributes.
+
+        changes: Keyword changes in the new copy.
+
+    Returns:
+        A copy of inst with *changes* incorporated.
+
+    Raises:
+        attrs.exceptions.AttrsAttributeNotFoundError:
+            If *attr_name* couldn't be found on *cls*.
+
+        attrs.exceptions.NotAnAttrsClassError:
+            If *cls* is not an *attrs* class.
+
+    .. deprecated:: 17.1.0
+        Use `attrs.evolve` instead if you can. This function will not be
+        removed due to the slightly different approach compared to
+        `attrs.evolve`, though.
+    """
+    new = copy.copy(inst)
+    attrs = fields(inst.__class__)
+    for k, v in changes.items():
+        a = getattr(attrs, k, NOTHING)
+        if a is NOTHING:
+            msg = f"{k} is not an attrs attribute on {new.__class__}."
+            raise AttrsAttributeNotFoundError(msg)
+        _OBJ_SETATTR(new, k, v)
+    return new
+
+
+def resolve_types(
+    cls, globalns=None, localns=None, attribs=None, include_extras=True
+):
+    """
+    Resolve any strings and forward annotations in type annotations.
+
+    This is only required if you need concrete types in :class:`Attribute`'s
+    *type* field. In other words, you don't need to resolve your types if you
+    only use them for static type checking.
+ + With no arguments, names will be looked up in the module in which the class + was created. If this is not what you want, for example, if the name only + exists inside a method, you may pass *globalns* or *localns* to specify + other dictionaries in which to look up these names. See the docs of + `typing.get_type_hints` for more details. + + Args: + cls (type): Class to resolve. + + globalns (dict | None): Dictionary containing global variables. + + localns (dict | None): Dictionary containing local variables. + + attribs (list | None): + List of attribs for the given class. This is necessary when calling + from inside a ``field_transformer`` since *cls* is not an *attrs* + class yet. + + include_extras (bool): + Resolve more accurately, if possible. Pass ``include_extras`` to + ``typing.get_hints``, if supported by the typing module. On + supported Python versions (3.9+), this resolves the types more + accurately. + + Raises: + TypeError: If *cls* is not a class. + + attrs.exceptions.NotAnAttrsClassError: + If *cls* is not an *attrs* class and you didn't pass any attribs. + + NameError: If types cannot be resolved because of missing variables. + + Returns: + *cls* so you can use this function also as a class decorator. Please + note that you have to apply it **after** `attrs.define`. That means the + decorator has to come in the line **before** `attrs.define`. + + .. versionadded:: 20.1.0 + .. versionadded:: 21.1.0 *attribs* + .. versionadded:: 23.1.0 *include_extras* + """ + # Since calling get_type_hints is expensive we cache whether we've + # done it already. + if getattr(cls, "__attrs_types_resolved__", None) != cls: + import typing + + kwargs = {"globalns": globalns, "localns": localns} + + if PY_3_9_PLUS: + kwargs["include_extras"] = include_extras + + hints = typing.get_type_hints(cls, **kwargs) + for field in fields(cls) if attribs is None else attribs: + if field.name in hints: + # Since fields have been frozen we must work around it. + _OBJ_SETATTR(field, "type", hints[field.name]) + # We store the class we resolved so that subclasses know they haven't + # been resolved. + cls.__attrs_types_resolved__ = cls + + # Return the class so you can use it as a decorator too. + return cls diff --git a/.venv/lib/python3.11/site-packages/attr/_version_info.pyi b/.venv/lib/python3.11/site-packages/attr/_version_info.pyi new file mode 100644 index 0000000000000000000000000000000000000000..45ced086337783c4b73b26cd17d2c1c260e24029 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/attr/_version_info.pyi @@ -0,0 +1,9 @@ +class VersionInfo: + @property + def year(self) -> int: ... + @property + def minor(self) -> int: ... + @property + def micro(self) -> int: ... + @property + def releaselevel(self) -> str: ... diff --git a/.venv/lib/python3.11/site-packages/attr/exceptions.pyi b/.venv/lib/python3.11/site-packages/attr/exceptions.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f2680118b404db8f5227d04d27e8439331341c4d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/attr/exceptions.pyi @@ -0,0 +1,17 @@ +from typing import Any + +class FrozenError(AttributeError): + msg: str = ... + +class FrozenInstanceError(FrozenError): ... +class FrozenAttributeError(FrozenError): ... +class AttrsAttributeNotFoundError(ValueError): ... +class NotAnAttrsClassError(ValueError): ... +class DefaultAlreadySetError(RuntimeError): ... +class UnannotatedAttributeError(RuntimeError): ... +class PythonTooOldError(RuntimeError): ... 
+ +class NotCallableError(TypeError): + msg: str = ... + value: Any = ... + def __init__(self, msg: str, value: Any) -> None: ... diff --git a/.venv/lib/python3.11/site-packages/attr/validators.pyi b/.venv/lib/python3.11/site-packages/attr/validators.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a0fdda7c8773f791103938fca0d4b448859aff1f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/attr/validators.pyi @@ -0,0 +1,86 @@ +from types import UnionType +from typing import ( + Any, + AnyStr, + Callable, + Container, + ContextManager, + Iterable, + Mapping, + Match, + Pattern, + TypeVar, + overload, +) + +from attrs import _ValidatorType +from attrs import _ValidatorArgType + +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_T3 = TypeVar("_T3") +_I = TypeVar("_I", bound=Iterable) +_K = TypeVar("_K") +_V = TypeVar("_V") +_M = TypeVar("_M", bound=Mapping) + +def set_disabled(run: bool) -> None: ... +def get_disabled() -> bool: ... +def disabled() -> ContextManager[None]: ... + +# To be more precise on instance_of use some overloads. +# If there are more than 3 items in the tuple then we fall back to Any +@overload +def instance_of(type: type[_T]) -> _ValidatorType[_T]: ... +@overload +def instance_of(type: tuple[type[_T]]) -> _ValidatorType[_T]: ... +@overload +def instance_of( + type: tuple[type[_T1], type[_T2]], +) -> _ValidatorType[_T1 | _T2]: ... +@overload +def instance_of( + type: tuple[type[_T1], type[_T2], type[_T3]], +) -> _ValidatorType[_T1 | _T2 | _T3]: ... +@overload +def instance_of(type: tuple[type, ...]) -> _ValidatorType[Any]: ... +@overload +def instance_of(type: UnionType) -> _ValidatorType[Any]: ... +def optional( + validator: ( + _ValidatorType[_T] + | list[_ValidatorType[_T]] + | tuple[_ValidatorType[_T]] + ), +) -> _ValidatorType[_T | None]: ... +def in_(options: Container[_T]) -> _ValidatorType[_T]: ... +def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ... +def matches_re( + regex: Pattern[AnyStr] | AnyStr, + flags: int = ..., + func: Callable[[AnyStr, AnyStr, int], Match[AnyStr] | None] | None = ..., +) -> _ValidatorType[AnyStr]: ... +def deep_iterable( + member_validator: _ValidatorArgType[_T], + iterable_validator: _ValidatorType[_I] | None = ..., +) -> _ValidatorType[_I]: ... +def deep_mapping( + key_validator: _ValidatorType[_K], + value_validator: _ValidatorType[_V], + mapping_validator: _ValidatorType[_M] | None = ..., +) -> _ValidatorType[_M]: ... +def is_callable() -> _ValidatorType[_T]: ... +def lt(val: _T) -> _ValidatorType[_T]: ... +def le(val: _T) -> _ValidatorType[_T]: ... +def ge(val: _T) -> _ValidatorType[_T]: ... +def gt(val: _T) -> _ValidatorType[_T]: ... +def max_len(length: int) -> _ValidatorType[_T]: ... +def min_len(length: int) -> _ValidatorType[_T]: ... +def not_( + validator: _ValidatorType[_T], + *, + msg: str | None = None, + exc_types: type[Exception] | Iterable[type[Exception]] = ..., +) -> _ValidatorType[_T]: ... +def or_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ... 
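The validator stubs above correspond to runtime checkers in attr.validators; a short usage sketch combining `instance_of` with the bound validators:

    import attr
    from attr import validators

    @attr.s
    class Server:
        port = attr.ib(
            validator=[
                validators.instance_of(int),  # a list is wrapped into and_()
                validators.ge(1),
                validators.le(65535),
            ]
        )

    Server(port=8080)    # passes
    # Server(port="80") would raise TypeError from instance_of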
diff --git a/.venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so b/.venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..10719ef7577ffc291a96fde9ac32315ab9b8c994 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a6211b1e1e47f505c8b79cb8b191ba1169b99f917cd874de883cffd11aa9883 +size 406024 diff --git a/.venv/lib/python3.11/site-packages/outlines/__init__.py b/.venv/lib/python3.11/site-packages/outlines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..307d2ba6f484f7b4c416189374462ed762789d25 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/__init__.py @@ -0,0 +1,20 @@ +"""Outlines is a Generative Model Programming Framework.""" +import outlines.generate +import outlines.grammars +import outlines.models +import outlines.processors +import outlines.types +from outlines.base import vectorize +from outlines.caching import clear_cache, disable_cache, get_cache +from outlines.function import Function +from outlines.prompts import prompt + +__all__ = [ + "clear_cache", + "disable_cache", + "get_cache", + "Function", + "prompt", + "vectorize", + "grammars", +] diff --git a/.venv/lib/python3.11/site-packages/outlines/_version.py b/.venv/lib/python3.11/site-packages/outlines/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..72c1ac0e4780433a3be4a9bc80c43962a048f6e5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/_version.py @@ -0,0 +1,16 @@ +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] +else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '0.1.11' +__version_tuple__ = version_tuple = (0, 1, 11) diff --git a/.venv/lib/python3.11/site-packages/outlines/base.py b/.venv/lib/python3.11/site-packages/outlines/base.py new file mode 100644 index 0000000000000000000000000000000000000000..29d42c54c2835570cef97b94520bc06e69fcc80c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/base.py @@ -0,0 +1,299 @@ +import asyncio +import builtins +import functools +import inspect +from typing import Callable, Optional + +import numpy as np + +# Import required functions based on NumPy version +np_major_version = int(np.__version__.split(".")[0]) +if np_major_version >= 2: + from numpy.lib._function_base_impl import ( + _calculate_shapes, + _parse_gufunc_signature, + _parse_input_dimensions, + _update_dim_sizes, + ) +else: + from numpy.lib.function_base import ( + _calculate_shapes, + _parse_gufunc_signature, + _parse_input_dimensions, + _update_dim_sizes, + ) + +# Allow nested loops for running in notebook. We don't enable it globally as it +# may interfere with other libraries that use asyncio. +if hasattr(builtins, "__IPYTHON__"): + try: + import nest_asyncio + + nest_asyncio.apply() + except ImportError: + print( + "Couldn't patch nest_asyncio because it's not installed. Running in the notebook might be have issues" + ) + + +class vectorize: + """Returns an object that acts like a function but takes arrays as an input. 
+ + The vectorized function evaluates `func` over successive tuples of the input + chararrays and returns a single NumPy chararrays or a tuple of NumPy chararrays. + + Its behavior is similar to NumPy's `vectorize` for Python functions: the function + being vectorized is executed in a `for` loop. Coroutines, however, are executed + concurrently. + + Part of the code was adapted from `numpy.lib.function_base`. + + """ + + def __init__(self, func: Callable, signature: Optional[str] = None): + self.func = func + self.signature = signature + self.is_coroutine_fn = inspect.iscoroutinefunction(func) + + functools.update_wrapper(self, func) + + if signature is not None: + # Parse the signature string into a Python data structure. + # For instance "(m),(s)->(s,m)" becomes `([(m,),(s,)],[(s,m)])`. + self._in_and_out_core_dimensions = _parse_gufunc_signature(signature) + else: + self._in_and_out_core_dimensions = None + + def __call__(self, *args, **kwargs): + """Call the vectorized function.""" + if not args and not kwargs: + return self.call_thunk() + elif self.signature is not None: + return self.call_with_signature(*args, **kwargs) + else: + return self.call_no_signature(*args, **kwargs) + + def call_thunk(self): + """Call a vectorized thunk. + + Thunks have no arguments and can thus be called directly. + + """ + if self.is_coroutine_fn: + loop = asyncio.new_event_loop() + try: + outputs = loop.run_until_complete(self.func()) + finally: + loop.close() + else: + outputs = self.func() + + return outputs + + def call_no_signature(self, *args, **kwargs): + """Call functions and coroutines when no signature is specified. + + When no signature is specified we assume that all of the function's + inputs and outputs are scalars (core dimension of zero). We first + broadcast the input arrays, then iteratively apply the function over the + elements of the broadcasted arrays and finally reshape the results to + match the input shape. + + Functions are executed in a for loop, coroutines are executed + concurrently. + + """ + # Convert args and kwargs to arrays + args = [np.array(arg) for arg in args] + kwargs = {key: np.array(value) for key, value in kwargs.items()} + + # Broadcast args and kwargs + broadcast_shape = np.broadcast(*args, *list(kwargs.values())).shape + args = [np.broadcast_to(arg, broadcast_shape) for arg in args] + kwargs = { + key: np.broadcast_to(value, broadcast_shape) + for key, value in kwargs.items() + } + + # Execute functions in a loop, and coroutines concurrently + if self.is_coroutine_fn: + outputs = self.vectorize_call_coroutine(broadcast_shape, args, kwargs) + else: + outputs = self.vectorize_call(broadcast_shape, args, kwargs) + + # `outputs` is a flat array or a tuple of flat arrays. We reshape the arrays + # to match the input shape. + outputs = [ + results if isinstance(results, tuple) else (results,) for results in outputs + ] + outputs = tuple( + [np.asarray(x).reshape(broadcast_shape).squeeze() for x in zip(*outputs)] + ) + outputs = tuple([x.item() if np.ndim(x) == 0 else x for x in outputs]) + + n_results = len(list(outputs)) + + return outputs[0] if n_results == 1 else outputs + + def call_with_signature(self, *args, **kwargs): + """Call functions and coroutines when a signature is specified.""" + input_core_dims, output_core_dims = self._in_and_out_core_dimensions + + # Make sure that the numbers of arguments passed is compatible with + # the signature. 
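Given `call_no_signature` above, scalar (no-signature) vectorization behaves roughly as follows, a sketch assuming the vendored outlines package is importable:

    from outlines.base import vectorize

    @vectorize
    def shout(text):
        return text.upper()

    shout(["ab", "cd"])  # -> array(['AB', 'CD'], dtype='<U2')
    shout("ab")          # -> 'AB' (0-d results are unwrapped via .item())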
+ num_args = len(args) + len(kwargs) + if num_args != len(input_core_dims): + raise TypeError( + "wrong number of positional arguments: " + "expected %r, got %r" % (len(input_core_dims), len(args)) + ) + + # Convert args and kwargs to arrays + args = [np.asarray(arg) for arg in args] + kwargs = {key: np.array(value) for key, value in kwargs.items()} + + # Find the arguments' broadcast shape, and map placeholder + # variables in the signature to the number of dimensions + # they correspond to given the arguments. + broadcast_shape, dim_sizes = _parse_input_dimensions( + args + list(kwargs.values()), input_core_dims + ) + + # Calculate the shape to which each of the arguments should be broadcasted + # and reshape them accordingly. + input_shapes = _calculate_shapes(broadcast_shape, dim_sizes, input_core_dims) + args = [ + np.broadcast_to(arg, shape, subok=True) + for arg, shape in zip(args, input_shapes) + ] + kwargs = { + key: np.broadcast_to(value, broadcast_shape) + for key, value in kwargs.items() + } + + n_out = len(output_core_dims) + + if self.is_coroutine_fn: + outputs = self.vectorize_call_coroutine(broadcast_shape, args, kwargs) + else: + outputs = self.vectorize_call(broadcast_shape, args, kwargs) + + outputs = [ + results if isinstance(results, tuple) else (results,) for results in outputs + ] + + flat_outputs = list(zip(*outputs)) + n_results = len(flat_outputs) + + if n_out != n_results: + raise ValueError( + f"wrong number of outputs from the function, expected {n_out}, got {n_results}" + ) + + # The number of dimensions of the outputs are not necessarily known in + # advance. The following iterates over the results and updates the + # number of dimensions of the outputs accordingly. + for results, core_dims in zip(flat_outputs, output_core_dims): + for result in results: + _update_dim_sizes(dim_sizes, result, core_dims) + + # Calculate the shape to which each of the outputs should be broadcasted + # and reshape them. + shapes = _calculate_shapes(broadcast_shape, dim_sizes, output_core_dims) + outputs = tuple( + [ + np.hstack(results).reshape(shape).squeeze() + for shape, results in zip(shapes, zip(*outputs)) + ] + ) + outputs = tuple([x.item() if np.ndim(x) == 0 else x for x in outputs]) + + return outputs[0] if n_results == 1 else outputs + + def vectorize_call(self, broadcast_shape, args, kwargs): + """Run the function in a for loop. + + A possible extension would be to parallelize the calls. + + Parameters + ---------- + broadcast_shape + The brodcast shape of the input arrays. + args + The function's broadcasted arguments. + kwargs + The function's broadcasted keyword arguments. + + """ + outputs = [] + for index in np.ndindex(*broadcast_shape): + current_args = tuple(arg[index] for arg in args) + current_kwargs = {key: value[index] for key, value in kwargs.items()} + outputs.append(self.func(*current_args, **current_kwargs)) + + return outputs + + def vectorize_call_coroutine(self, broadcast_shape, args, kwargs): + """Run coroutines concurrently. + + Creates as many tasks as needed and executes them in a new event + loop. + + Parameters + ---------- + broadcast_shape + The brodcast shape of the input arrays. + args + The function's broadcasted arguments. + kwargs + The function's broadcasted keyword arguments. 
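And for the signature-based path handled by `call_with_signature`, each core dimension named in the gufunc signature is consumed per call; a sketch with a hypothetical reducer:

    import numpy as np
    from outlines.base import vectorize

    def mean(v):  # consumes one core dimension, "(m)"
        return float(np.mean(v))

    vmean = vectorize(mean, "(m)->()")
    vmean([[1.0, 2.0], [3.0, 4.0]])  # -> array([1.5, 3.5])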
+ + """ + + async def create_and_gather_tasks(): + tasks = [] + for index in np.ndindex(*broadcast_shape): + current_args = tuple(arg[index] for arg in args) + current_kwargs = {key: value[index] for key, value in kwargs.items()} + tasks.append(self.func(*current_args, **current_kwargs)) + + outputs = await asyncio.gather(*tasks) + + return outputs + + loop = asyncio.new_event_loop() + try: + outputs = loop.run_until_complete(create_and_gather_tasks()) + finally: + loop.close() + + return outputs + + +def _update_arrays_type(arrays, results): + """Update the dtype of arrays. + + String arrays contain strings of fixed length. Here they are initialized with + the type of the first results, so that if the next results contain longer + strings they will be truncated when added to the output arrays. Here we + update the type if the current results contain longer strings than in the + current output array. + + Parameters + ---------- + arrays + Arrays that contain the vectorized function's results. + results + The current output of the function being vectorized. + + """ + + updated_arrays = [] + for array, result in zip(arrays, results): + if array.dtype.type == np.str_: + if array.dtype < np.array(result).dtype: + array = array.astype(np.array(result).dtype) + + updated_arrays.append(array) + + return tuple(updated_arrays) diff --git a/.venv/lib/python3.11/site-packages/outlines/caching.py b/.venv/lib/python3.11/site-packages/outlines/caching.py new file mode 100644 index 0000000000000000000000000000000000000000..6fdda6214b06bbe996f23b5e64a642f5d6aceecb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/caching.py @@ -0,0 +1,179 @@ +import asyncio +import contextlib +import functools +import os +from typing import Callable, Optional + +import cloudpickle +from diskcache import Cache, Disk +from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name + +_caching_enabled = True + + +class CloudpickleDisk(Disk): + def __init__(self, directory, compress_level=1, **kwargs): + self.compress_level = compress_level + super().__init__(directory, **kwargs) + + def put(self, key): + data = cloudpickle.dumps(key) + return super().put(data) + + def get(self, key, raw): + data = super().get(key, raw) + return cloudpickle.loads(data) + + def store(self, value, read, key=UNKNOWN): + if not read: + value = cloudpickle.dumps(value) + return super().store(value, read, key=key) + + def fetch(self, mode, filename, value, read): + data = super().fetch(mode, filename, value, read) + if not read: + data = cloudpickle.loads(data) + return data + + +@functools.lru_cache(1) +def get_cache(): + """Get the context object that contains previously-computed return values. + + The cache is used to avoid unnecessary computations and API calls, which can + be long and expensive for large models. + + The cache directory defaults to `HOMEDIR/.cache/outlines`, but this choice + can be overridden by the user by setting the value of the `OUTLINES_CACHE_DIR` + environment variable. 
+ + """ + from outlines._version import __version__ as outlines_version # type: ignore + + home_dir = os.path.expanduser("~") + cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines") + memory = Cache( + cache_dir, + eviction_policy="none", + cull_limit=0, + disk=CloudpickleDisk, + ) + + # ensure if version upgrade occurs, old cache is pruned + if outlines_version != memory.get("__version__"): + memory.clear() + memory["__version__"] = outlines_version + + return memory + + +def cache(expire: Optional[float] = None, typed=False, ignore=()): + """Caching decorator for memoizing function calls. + + The cache key is created based on the values returned by the key_function callable + if provided or based on the arguments of the decorated function directly otherwise + + This is based on `diskcache`'s `memoize`. + + Parameters + ---------- + expire + Seconds until arguments expire. + typed + Cache different types separately. + ignore + Positional or keyword arguments to ignore. + + Returns + ------- + A decorator function that can be applied to other functions. + """ + + def decorator(cached_function: Callable): + memory = get_cache() + + base = (full_name(cached_function),) + + if asyncio.iscoroutinefunction(cached_function): + + async def wrapper(*args, **kwargs): + if not _caching_enabled: + return await cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = await cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + + else: + + def wrapper(*args, **kwargs): + if not _caching_enabled: + return cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + + def __cache_key__(*args, **kwargs): + """Make key for cache given function arguments.""" + return args_to_key(base, args, kwargs, typed, ignore) + + wrapper.__cache_key__ = __cache_key__ # type: ignore + wrapper.__memory__ = memory # type: ignore + wrapper.__wrapped__ = cached_function # type: ignore + + return wrapper + + return decorator + + +def disable_cache(): + """Disable the cache for this session. + + Generative models output different results each time they are called when + sampling. This can be a desirable property for some workflows, in which case + one can call `outlines.call.disable` to disable the cache for the session. + + This function does not delete the cache, call `outlines.cache.clear` + instead. It also does not overwrite the cache with the values returned + during the session. 
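Pulling the pieces of this module together, a hedged usage sketch (`expensive_call` is a hypothetical stand-in for a slow model or API call):

    from outlines.caching import cache, cache_disabled, clear_cache

    @cache(expire=3600)  # entries expire after an hour
    def expensive_call(prompt: str) -> str:
        return prompt[::-1]

    expensive_call("hello")   # computed, then stored via diskcache
    expensive_call("hello")   # served from the on-disk cache

    with cache_disabled():    # bypass temporarily, without deleting anything
        expensive_call("hello")

    clear_cache()             # erase the cache entirely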
+ + Example + ------- + + `outlines.cache.disable` should be called right after importing outlines: + + >>> import outlines.caching as cache + >>> cache.disable_cache() + + """ + global _caching_enabled + _caching_enabled = False + + +def clear_cache(): + """Erase the cache completely.""" + memory = get_cache() + memory.clear() + + +@contextlib.contextmanager +def cache_disabled(): + # outlines.caching._caching_enabled + global _caching_enabled + original_state = _caching_enabled + _caching_enabled = False + try: + yield + finally: + _caching_enabled = original_state diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/__init__.py b/.venv/lib/python3.11/site-packages/outlines/fsm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e38739a0df3e54e2d9ac8c8e1223a2c7ca89f802 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/guide.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/guide.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6340265f57b8aefb3fa4fd73409576b0dfa829e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/guide.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/json_schema.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/json_schema.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e09d8bd22ea8d6ad9030657068aae5eb053c5f20 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/json_schema.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/parsing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/parsing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb0f11f5562aa64376171fc936f4bafe98112bbf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/parsing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/types.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87d8392996894154f42d20198d7d4bc4b4d08c5e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/types.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/guide.py b/.venv/lib/python3.11/site-packages/outlines/fsm/guide.py new file mode 100644 index 0000000000000000000000000000000000000000..6b97d7729ddf5a7b035afec0ef19cdbca50afdac --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/fsm/guide.py @@ -0,0 +1,276 @@ +import collections +import copy +import warnings +from typing import TYPE_CHECKING, Any, Generator, Union + +import torch +from lark.indenter import DedentError +from lark.lexer import UnexpectedCharacters, UnexpectedToken +from outlines_core.fsm.guide import Generate +from outlines_core.fsm.guide import Guide as CoreGuide +from 
outlines_core.fsm.guide import RegexGuide as CoreRegexGuide +from outlines_core.fsm.guide import Write +from outlines_core.fsm.guide import ( + create_states_mapping as uncached_create_states_mapping, +) + +from outlines import grammars +from outlines.fsm.parsing import PartialLark, PartialParserState + +if TYPE_CHECKING: + from outlines.models.tokenizer import Tokenizer + + +Instruction = Union[Write, Generate] + + +class Guide(CoreGuide): + """Base definition of a generation guide. + + A generation guide defines the behavior of a finite-state machine that guides + a text generation procedure. Unlike the DFAs built from regular expressions + guides can also emit a `Write` instructions which tells the model that it can + append a sequence of tokens (or token word) instead of generating it. + + """ + + initial_state: Any + + +class StopAtEOSGuide(Guide): + """Guide to generate tokens until the EOS token has been generated.""" + + final_state = 1 + start_state = 0 # TODO: remove start_state, use only initial_state + initial_state = 0 + + def __init__(self, tokenizer: "Tokenizer"): + """Initialize the generation guide. + + model + The logit generator used to generate the next token. + + """ + self.eos_token_id = tokenizer.eos_token_id + self.vocabulary = tokenizer.vocabulary.values() + + def get_next_instruction(self, state: int) -> Instruction: + if self.is_final_state(state): + return Write([self.eos_token_id]) + return Generate(None) + + def get_next_state(self, state: int, token_id: int) -> int: + if token_id == self.eos_token_id or state == self.final_state: + return self.final_state + + return self.initial_state + + def is_final_state(self, state: int): + return state == self.final_state + + def copy(self): + return self + + +def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs): + return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs) + + +class RegexGuide(CoreRegexGuide): + """ + Guide to generate text in the language of a regular expression. + CoreRegexGuide with outlines cache + """ + + @classmethod + def from_regex( + cls, + regex_string: str, + tokenizer, + **kwargs, + ): + return super().from_regex( + regex_string, + tokenizer, + _create_states_mapping=cached_create_states_mapping, + **kwargs, + ) + + +CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"]) + + +class CFGGuide(Guide): + """Guide to generate text that is in the language of a context-free Lark grammar.""" + + def __init__(self, cfg_string: str, tokenizer): + """ + Construct the PartialLark parser and set the empty initial_state (PartialParserState) + """ + warnings.warn( + "Outlines' public *community-contributed* CFG structured generation is experimental. " + "Please review https://dottxt-ai.github.io/outlines/latest/reference/generation/cfg#disclaimer" + ) + + self.cfg_string = cfg_string + self.tokenizer = tokenizer + self.eos_token_id = self.tokenizer.eos_token_id + self.parser = PartialLark( + cfg_string, + parser="lalr", + import_paths=[grammars.GRAMMAR_PATH], + ) + self.initial_state = CFGState( + parser_state=self.parser.parse(""), prev_token=None + ) + + def get_next_instruction(self, state: CFGState) -> Instruction: + """Return the next instruction for guided generation. + + Current lazy approach: + - For each token in the vocabulary + - create a copy of the parsers state + - add the tokens to the parsers input text + - if valid, add token to returned tokens + + Further refinements are necessary for performant text processing. 
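All of the Guide subclasses above drive the same generation loop; schematically, with `sample_from` standing in for model-specific logit sampling (hypothetical):

    state = guide.initial_state
    while not guide.is_final_state(state):
        instruction = guide.get_next_instruction(state)  # Write or Generate
        token_id = sample_from(instruction.tokens)       # tokens=None: unrestricted
        state = guide.get_next_state(state, token_id)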
+ + Parameters + ---------- + state + The guides current PartialParserState, or None if complete + + Returns + ------- + A `Generate` instance that contains the model and the allowed token ids. + + """ + + if state.parser_state is None: + return Write(torch.tensor([self.eos_token_id])) + + valid_tokens = list( + self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values()) + ) + if len(valid_tokens) == 1: + return Write(torch.tensor(valid_tokens)) + return Generate(torch.tensor(valid_tokens)) + + def iter_valid_token_ids( + self, state: CFGState, candidate_token_ids: list + ) -> Generator[int, None, None]: + """ + Iterate over the given token_ids and yield those that are valid for the current parser state. + + Parameters + ---------- + parser_state + The current state of the parser, or None if complete. + token_ids + The list of token ids to check for validity. + + Yields + ------ + int + Valid token ids. + """ + if state.parser_state is None: + yield self.eos_token_id + return + + for token_id in candidate_token_ids: + if token_id == self.eos_token_id: + if self.can_terminate_state(state): + yield token_id + else: + try: + self._get_parser_state_token_applied(state, int(token_id)) + yield token_id + except ( + ValueError, + EOFError, + UnexpectedToken, + UnexpectedCharacters, + DedentError, + ): + pass + + def get_next_state(self, state: CFGState, token_id: int) -> CFGState: + """ + Update the state of the guide. + Decode the token_id, and calculate the new parser_state with the token applied. + + Parameters + ---------- + state + The guides current PartialParserState, or None if complete + token_id + The id of the token that was just generated. + + Returns + ------- + The guides new PartialParserState + + """ + if state.parser_state is None or token_id == self.eos_token_id: + parser_state = None + else: + parser_state = self._get_parser_state_token_applied(state, int(token_id)) + return CFGState(parser_state=parser_state, prev_token=token_id) + + def _get_parser_state_token_applied( + self, state: CFGState, token_id: int + ) -> PartialParserState: + """ + Don't mutate `parser_state`, copy to protect + + Get the token string + - if first token in generation: tokenizer.decode (no leading whitespace) + - else: normalized (with possibly leading whitespace) + + Don't allow empty ("") tokens, raise ValueError + """ + parser_state = copy.copy(state.parser_state) # prevent side effects + + # normalize + if state.prev_token is None: + new_token_str = self.tokenizer.decode([token_id])[0] + else: + prev_token_str = self.tokenizer.decode([[state.prev_token]])[0] + combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[ + 0 + ] + new_token_str = combined_token_str[len(prev_token_str) :] + + if new_token_str == "": + raise ValueError("empty next token") + + # update parser with new token + parser_state.lexer.state.text += new_token_str + self.parser.parse_from_state(parser_state, is_end=False) + + return parser_state + + def is_final_state(self, state: CFGState) -> bool: + # TODO: remove this method, use can_terminate_state and must_terminate_state + # here and in RegexGuide per https://github.com/dottxt-ai/outlines/issues/885 + return self.can_terminate_state(state) + + def can_terminate_state(self, state: CFGState) -> bool: + """Generation is allowed to terminate""" + if state.parser_state is not None: + try: + copy.copy(state.parser_state).feed_eof() + except UnexpectedToken: + return False + return True + + def must_terminate_state(self, state: CFGState) -> bool: + 
"""Generation must terminate, no legal continuations""" + return state.parser_state is None or set(state.parser_state.accepts()).issubset( + {"$END"} + ) + + def copy(self) -> "CFGGuide": + """Create a copy of the Guide.""" + return CFGGuide(self.cfg_string, self.tokenizer) diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/json_schema.py b/.venv/lib/python3.11/site-packages/outlines/fsm/json_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..578ee762661b6861c37de07c3cad0f42e4f5ff92 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/fsm/json_schema.py @@ -0,0 +1,83 @@ +import inspect +import json +import warnings +from enum import Enum +from typing import Callable, Type, Union + +from pydantic import BaseModel, create_model + + +def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: + """Convert a JSON schema to a string. + + Parameters + ---------- + json_schema + The JSON schema. + + Returns + ------- + str + The JSON schema converted to a string. + + Raises + ------ + ValueError + If the schema is not a dictionary, a string or a Pydantic class. + """ + if isinstance(json_schema, dict): + schema_str = json.dumps(json_schema) + elif isinstance(json_schema, str): + schema_str = json_schema + elif issubclass(json_schema, BaseModel): + schema_str = json.dumps(json_schema.model_json_schema()) + else: + raise ValueError( + f"Cannot parse schema {json_schema}. The schema must be either " + + "a Pydantic class, a dictionary or a string that contains the JSON " + + "schema specification" + ) + return schema_str + + +def get_schema_from_signature(fn: Callable) -> dict: + """Turn a function signature into a JSON schema. + + Every JSON object valid to the output JSON Schema can be passed + to `fn` using the ** unpacking syntax. + + """ + signature = inspect.signature(fn) + arguments = {} + for name, arg in signature.parameters.items(): + if arg.annotation == inspect._empty: + raise ValueError("Each argument must have a type annotation") + else: + arguments[name] = (arg.annotation, ...) + + try: + fn_name = fn.__name__ + except Exception as e: + fn_name = "Arguments" + warnings.warn( + f"The function name could not be determined. Using default name 'Arguments' instead. For debugging, here is exact error:\n{e}", + category=UserWarning, + ) + model = create_model(fn_name, **arguments) + + return model.model_json_schema() + + +def get_schema_from_enum(myenum: type[Enum]) -> dict: + if len(myenum) == 0: + raise ValueError( + f"Your enum class {myenum.__name__} has 0 members. 
If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)" + ) + choices = [ + get_schema_from_signature(elt.value.func) + if callable(elt.value) + else {"const": elt.value} + for elt in myenum + ] + schema = {"title": myenum.__name__, "oneOf": choices} + return schema diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/parsing.py b/.venv/lib/python3.11/site-packages/outlines/fsm/parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..e48fb69e49f130562904880f7913353535788b3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/fsm/parsing.py @@ -0,0 +1,1127 @@ +from copy import copy, deepcopy +from dataclasses import dataclass +from functools import lru_cache +from typing import ( + Any, + Dict, + FrozenSet, + Generator, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import interegular +from interegular.fsm import FSM, Alphabet, OblivionError +from interegular.patterns import Unsupported +from lark import Lark, Token +from lark.common import LexerConf, ParserConf +from lark.exceptions import LexError, UnexpectedInput +from lark.indenter import Indenter +from lark.lexer import ( + BasicLexer, + ContextualLexer, + LexerState, + LexerThread, + Scanner, + UnexpectedCharacters, + UnexpectedToken, + _create_unless, +) +from lark.parser_frontends import ( + ParsingFrontend, + PostLexConnector, + _validate_frontend_args, +) +from lark.parsers.lalr_analysis import ( + Action, + IntParseTable, + LALR_Analyzer, + ParseTable, + Shift, +) +from lark.parsers.lalr_interactive_parser import InteractiveParser +from lark.parsers.lalr_parser import LALR_Parser, ParseConf, ParserState, _Parser +from outlines_core.fsm.regex import ( + BetterFSM, + get_token_transition_keys, + make_deterministic_fsm, +) + +PartialParseState = Tuple[str, int] +ParseStateType = Union[int, FrozenSet] + + +@dataclass +class PartialTerminalInfo: + priority: int + terminal_name: str + can_transition: bool + is_final: bool + + +@dataclass +class PartialTokensInfo: + fsm_state_seq: Tuple[int, ...] + is_not_finished: bool + terminals_and_info: Tuple[PartialTerminalInfo, ...] + final_terminals_and_info: Tuple[PartialTerminalInfo, ...] + + +class PartialParserConf(ParserConf): + __serialize_fields__ = ( + "rules", + "start", + "parser_type", + "deterministic", + "use_value_stack", + ) + + def __init__(self, rules, callbacks, start, deterministic, use_value_stack): + super().__init__(rules, callbacks, start) + self.deterministic = deterministic + self.use_value_stack = use_value_stack + + +class PartialLark(Lark): + __serialize_fields__ = ( + "parser", + "rules", + "options", + "deterministic", + "use_value_stack", + ) + + def __init__(self, grammar, **options): + # TODO: Could've extended `LarkOptions`, but all these extensions are + # already way too much (and brittle). This library really needs a + # complete refactoring. 
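For the schema helpers defined earlier, `get_schema_from_signature` builds a Pydantic model from the annotations, so the result is a plain JSON schema whose title is the function name; a quick sketch:

    from outlines.fsm.json_schema import get_schema_from_signature

    def add(a: int, b: int) -> int:
        return a + b

    get_schema_from_signature(add)
    # -> {'properties': {'a': {'title': 'A', 'type': 'integer'},
    #                    'b': {'title': 'B', 'type': 'integer'}},
    #     'required': ['a', 'b'], 'title': 'add', 'type': 'object'}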
+ self.deterministic = options.pop("deterministic", False) + self.use_value_stack = options.pop("use_value_stack", False) + options["regex"] = True + super().__init__(grammar, **options) + assert self.options.parser == "lalr" + + def _build_lexer(self, dont_ignore: bool = False) -> "PartialBasicLexer": + lexer_conf = self.lexer_conf + if dont_ignore: + from copy import copy + + lexer_conf = copy(lexer_conf) + lexer_conf.ignore = () + + return PartialBasicLexer(lexer_conf) + + def _build_parser(self) -> "PartialParsingFrontend": + self._prepare_callbacks() + _validate_frontend_args(self.options.parser, self.options.lexer) + parser_conf = PartialParserConf( + self.rules, + self._callbacks, + self.options.start, + self.deterministic, + self.use_value_stack, + ) + + # This is `_construct_parsing_frontend` expanded/inlined + parser_type = self.options.parser + lexer_type = self.options.lexer + lexer_conf = self.lexer_conf + + assert isinstance(lexer_conf, LexerConf) + assert isinstance(parser_conf, ParserConf) + parser_conf.parser_type = parser_type + self.lexer_conf.lexer_type = lexer_type + return PartialParsingFrontend(lexer_conf, parser_conf, self.options) + + def __repr__(self): + return "{}(open({!r}), parser={!r}, lexer={!r}, ...)".format( + type(self).__name__, + self.source_path, + self.options.parser, + self.options.lexer, + ) + + def parse_from_state(self, parse_state: "PartialParseState", is_end=False): + return self.parser.parser.parser.parse_from_state(parse_state, is_end=is_end) + + +class PartialLexerThread(LexerThread): + def __copy__(self): + return type(self)(copy(self.lexer), copy(self.state)) + + def __repr__(self): + return f"{type(self).__name__}(lexer={self.lexer!r}, state={self.state!r})" + + +class PartialPostLexConnector(PostLexConnector): + def __copy__(self): + return type(self)(self.lexer, copy(self.postlexer)) + + def __repr__(self): + return ( + f"{type(self).__name__}(lexer={self.lexer!r}, postlexer={self.postlexer!r})" + ) + + +class PartialParsingFrontend(ParsingFrontend): + def __init__(self, lexer_conf, parser_conf, options, parser=None): + assert parser_conf.parser_type == "lalr" + + options._plugins["LALR_Parser"] = PartialLALRParser + options._plugins["BasicLexer"] = PartialBasicLexer + options._plugins["ContextualLexer"] = PartialContextualLexer + options._plugins["LexerThread"] = PartialLexerThread + + super().__init__(lexer_conf, parser_conf, options, parser=parser) + + if lexer_conf.postlex: + self.lexer = PartialPostLexConnector(self.lexer.lexer, lexer_conf.postlex) + + self._termset_fsm_info = None + self._symbols_to_states: Optional[ + Dict[str, Set[Tuple[ParseStateType, Action]]] + ] = None + self._reverse_shifts: Optional[ + Dict[ParseStateType, Dict[str, Set[ParseStateType]]] + ] = None + # self._state_transition_map: Optional[ + # Dict[Tuple[ParseStateType, str], Set[ParseStateType]] + # ] = None + + def _compute_maps( + self, + ): + """Compute state transition and symbols-to-states maps.""" + self._reverse_shifts = {} + self._symbols_to_states = {} + + parse_table = self.parser.parser.parse_table + + for from_state, symbols_to_ops in parse_table.states.items(): + for symbol, op in symbols_to_ops.items(): + if op[0] == Shift: + symbols_to_from_states = self._reverse_shifts.setdefault(op[1], {}) + symbols_to_from_states.setdefault(symbol, set()).add(from_state) + self._symbols_to_states.setdefault(symbol, set()).add((from_state, op)) + + # # TODO: This approach is very wasteful. 
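+ # # The block below is kept commented out for reference: it would precompute
+ # # a full (from_state, terminal) -> next-states map instead of the two
+ # # lighter maps built above.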
+ # context_lexer = get_contextual_lexer(self) + # self._state_transition_map = {} + # + # for from_state, transitions in parse_table.states.items(): + # for symbol, action in transitions.items(): + # # TODO: Filter non-terminals + # if symbol not in context_lexer.root_lexer.terminals_by_name: + # continue + # + # if action[0] is Shift: + # self._state_transition_map.setdefault( + # (from_state, symbol), set() + # ).add(action[1]) + # continue + # + # antecedent_state_seqs = parse_to_terminal(self, [(from_state,)], symbol) + # + # for antecedent_state_seq in antecedent_state_seqs: + # antecedent_state = antecedent_state_seq[-1] + # self._state_transition_map.setdefault( + # (from_state, symbol), set() + # ).add(antecedent_state) + + def _compute_termset_fsm_info(self): + """Collect and return information about terminal symbol sets and their FSMs. + + Terminal symbol sets (or "termsets") are ordered sequences of terminal + symbols that are used by each parser state. Associated with each is a + collection of FSMs for each terminal and a single parse state FSM that is + the union of each terminal's FSM. + + This constructs a list of tuples containing the termset, the set of + parse states that use the termsets, parse state FSMs, and information + mapping the components of the parse state FSMs to their terminal symbol + FSMs. + + """ + context_lexer = get_contextual_lexer(self) + termsets_to_fsms = {} + termsets_to_parse_states: Dict[Tuple[str, ...], Set[ParseStateType]] = {} + for parse_state, lexer in context_lexer.lexers.items(): + scanner = lexer.scanner + key = tuple(term.name for term in scanner.terminals) + termsets_to_fsms[key] = (scanner.fsm, scanner.fsms_to_trans_finals) + termsets_to_parse_states.setdefault(key, set()).add(parse_state) + + self._termset_fsm_info = [ + ( + termset, + frozenset(termsets_to_parse_states[termset]), + fsm, + fsms_to_trans_finals, + ) + for termset, (fsm, fsms_to_trans_finals) in termsets_to_fsms.items() + ] + + @property + def termset_fsm_info(self): + if self._termset_fsm_info is None: + self._compute_termset_fsm_info() + return self._termset_fsm_info + + @property + def symbols_to_states(self): + if self._symbols_to_states is None: + self._compute_maps() + return self._symbols_to_states + + @property + def reverse_shifts(self): + if self._reverse_shifts is None: + self._compute_maps() + return self._reverse_shifts + + # @property + # def state_transition_map(self): + # if self._state_transition_map is None: + # self._compute_maps() + # return self._state_transition_map + + +class PartialLALRParser(LALR_Parser): + def __init__(self, parser_conf, debug=False, strict=False): + analysis = LALR_Analyzer( + parser_conf, debug=debug if not parser_conf.deterministic else True + ) + analysis.compute_lalr() + callbacks = parser_conf.callbacks + + self.parser_conf = parser_conf + self._parse_table = analysis.parse_table + + if parser_conf.deterministic: + old_to_new = {} + + def to_tuple(v): + new = old_to_new.get(v) + if new is None: + new = tuple(sorted(v, key=lambda y: str(y))) + old_to_new[v] = new + return new + + enum = sorted( + self._parse_table.states.keys(), + key=lambda x: str(sorted(x, key=lambda y: str(y))), + ) + + new_states = {} + for s in enum: + transitions = { + term: op if op[0] is not Shift else (op[0], to_tuple(op[1])) + for term, op in self._parse_table.states[s].items() + } + new_states[to_tuple(s)] = transitions + + self._parse_table = type(self._parse_table)( + new_states, + {k: to_tuple(v) for k, v in 
self._parse_table.start_states.items()}, + {k: to_tuple(v) for k, v in self._parse_table.end_states.items()}, + ) + + if not debug: + self._parse_table = IntParseTable.from_ParseTable(self._parse_table) + self.states_to_rulesets = dict( + zip(self._parse_table.states.keys(), new_states.keys()) + ) + + self.parser = PartialParser( + self._parse_table, + callbacks, + debug, + use_value_stack=parser_conf.use_value_stack, + ) + + @classmethod + def deserialize(cls, data, memo, callbacks, debug=False): + inst = cls.__new__(cls) + inst._parse_table = ParseTable.deserialize(data, memo) + inst.parser = PartialParser(inst._parse_table, callbacks, debug) + return inst + + +class PartialParserState(ParserState): + __slots__ = "use_value_stack" + + def __init__( + self, + parse_conf, + lexer, + state_stack=None, + value_stack=None, + use_value_stack=False, + ): + super().__init__( + parse_conf, lexer, state_stack=state_stack, value_stack=value_stack + ) + self.use_value_stack = use_value_stack + + def feed_token(self, token, is_end=False): + if token.type == "partial": + # If none of the potential terminals can transition, we need to know now + current_state = self.state_stack[-1] + current_lexer = get_contextual_lexer(self.lexer).lexers[current_state] + + # We have to feed the token and determine whether or not at least + # one terminal is consistent with the stack; otherwise, we'll miss + # invalid REDUCE cases. + # TODO: We should track separate parses conditional on possible + # token/symbol types, then we can coherently reuse the following + # results instead of recomputing it later. + can_transition = False + for terminal_info in token.value.terminals_and_info: + if terminal_info.terminal_name not in current_lexer.ignore_types: + test_token = Token.new_borrow_pos( + terminal_info.terminal_name, "", token + ) + + stack = copy(self.state_stack) + try: + self.feed_token_no_stack(test_token, is_end=is_end) + can_transition = True + break + except UnexpectedToken: + continue + finally: + self.state_stack = stack + else: + can_transition = True + + if not can_transition: + expected = { + s + for s in self.parse_conf.states[current_state].keys() + if s.isupper() + } + raise UnexpectedToken( + token, expected, state=self, interactive_parser=None + ) + + elif self.use_value_stack: + super().feed_token(token, is_end=is_end) + else: + self.feed_token_no_stack(token, is_end=is_end) + + def feed_token_no_stack(self, token, is_end=False): + """ + This is a copy of `ParserState.feed_token` with all the value stack + steps removed. Since we're not exactly parsing in order to obtain a + CST or anything similar, we can avoid the growing expense of tracking + the parse tree. 
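+ Only the state stack is updated: a SHIFT pushes the new state, and each
+ REDUCE pops `len(rule.expansion)` states before shifting on the rule's
+ origin symbol.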
+ """ + state_stack = self.state_stack + states = self.parse_conf.states + end_state = self.parse_conf.end_state + + while True: + state = state_stack[-1] + try: + action, arg = states[state][token.type] + except KeyError: + expected = {s for s in states[state].keys() if s.isupper()} + raise UnexpectedToken( + token, expected, state=self, interactive_parser=None + ) + + assert arg != end_state + + if action is Shift: + # shift once and return + assert not is_end + state_stack.append(arg) + return + else: + # reduce+shift as many times as necessary + rule = arg + size = len(rule.expansion) + if size: + del state_stack[-size:] + + _action, new_state = states[state_stack[-1]][rule.origin.name] + assert _action is Shift + state_stack.append(new_state) + + if is_end and state_stack[-1] == end_state: + return + + def feed_eof(self): + last_token = self.lexer.state.last_token + + if last_token is None: + eof_token = self.lexer._Token("$END", "", 0, 1, 1) + else: + eof_token = Token.new_borrow_pos("$END", "", last_token) + + new_token_is_legal = ( + last_token is None + or last_token.type != "partial" + or any(ti.is_final for ti in last_token.value.terminals_and_info) + ) + if new_token_is_legal: + self.feed_token(eof_token, is_end=True) + else: + raise UnexpectedToken(eof_token, [], state=self, interactive_parser=None) + + def choices(self): + return self.parse_conf.parse_table.states[self.position] + + def accepts(self): + """ + Adapted from https://github.com/lark-parser/lark/blob/be542c2ff6d968817df019b8bf03f37b3111c08c/lark/parsers/lalr_interactive_parser.py#L95 + Returns the set of possible tokens that will advance the parser into a new valid state. + """ + accepts = set() + conf_no_callbacks = copy(self.parse_conf) + # We don't want to call callbacks here since those might have arbitrary side effects + # and are unnecessarily slow. + conf_no_callbacks.callbacks = {} + for t in self.choices(): + if t.isupper(): # is terminal? 
+ new_state = copy(self) + new_state.parse_conf = conf_no_callbacks + try: + new_state.feed_token(new_state.lexer._Token(t, "")) + except UnexpectedToken: + pass + else: + accepts.add(t) + return accepts + + def __copy__(self): + return type(self)( + self.parse_conf, + copy(self.lexer), + copy(self.state_stack), + deepcopy(self.value_stack), + use_value_stack=self.use_value_stack, + ) + + def __repr__(self): + return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})" + + +class PartialParser(_Parser): + def __init__(self, parse_table, callbacks, debug=False, use_value_stack=False): + super().__init__(parse_table, callbacks, debug=debug) + self.use_value_stack = use_value_stack + + def parse( + self, lexer, start, value_stack=None, state_stack=None, start_interactive=False + ): + parse_conf = ParseConf(self.parse_table, self.callbacks, start) + parser_state = PartialParserState( + parse_conf, copy(lexer), state_stack, value_stack, self.use_value_stack + ) + if start_interactive: + return InteractiveParser(self, parser_state, parser_state.lexer) + return self.parse_from_state(parser_state) + + def parse_from_state(self, state, last_token=None, is_end=False): + try: + token = last_token + for token in state.lexer.lex(state): + state.feed_token(token) + + if is_end and (not token or token.type != "partial"): + state.feed_eof() + + return state + except UnexpectedInput as e: + try: + e.interactive_parser = InteractiveParser(self, state, state.lexer) + except NameError: + pass + raise e + except Exception: + if self.debug: + print("") + print("STATE STACK DUMP") + print("----------------") + for i, s in enumerate(state.state_stack): + print("%d)" % i, s) + print("") + + raise + + +class PartialScanner(Scanner): + @classmethod + @lru_cache + def construct_terminal_fsm(cls, terminal): + # TODO: This should really be done at the lexer/parser level so that + # the lifetime of these objects is tied to the parser itself. + regex_str = terminal.pattern.to_regexp() + pattern = interegular.parse_pattern(regex_str) + fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) + return fsm, pattern.prefix_postfix + + def __init__(self, terminals, g_regex_flags, re_, use_bytes, match_whole=False): + self.terminals = terminals + self.g_regex_flags = g_regex_flags + self.use_bytes = use_bytes + self.match_whole = match_whole + self.allowed_types = {t.name for t in self.terminals} + self._mres = None + + fsms = [] + for t in self.terminals: + fsm, prefix_postfix = self.construct_terminal_fsm(t) + + # TODO FIXME: We don't support this right now. + assert prefix_postfix == (0, 0) + + fsms.append(fsm) + + self.fsm, self.fsms_to_trans_finals = fsm_union(fsms) + + def get_terminals_info( + self, fsm_state_seq + ) -> Tuple[Tuple[PartialTerminalInfo, ...], Tuple[PartialTerminalInfo, ...]]: + """Get the possible terminal symbols for an FSM state sequence.""" + terminals_and_info: Tuple[PartialTerminalInfo, ...] = () + final_terminals_and_info: Tuple[PartialTerminalInfo, ...] 
= () + for i, (fsm_id, fsm_reads_more, in_final) in enumerate( + get_sub_fsms_from_seq(fsm_state_seq, self.fsms_to_trans_finals) + ): + terminal_name = self.terminals[fsm_id].name + info = PartialTerminalInfo(i, terminal_name, fsm_reads_more, in_final) + terminals_and_info += (info,) + if in_final: + final_terminals_and_info += (info,) + + return terminals_and_info, final_terminals_and_info + + def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None): + """Determine an FSM match over `text` starting at `pos` and continuing `last_fsm_state_seq`.""" + + start_pos = pos + + if last_fsm_state_seq: + assert len(last_fsm_state_seq) > 1 + start_pos += len(last_fsm_state_seq) - 1 + start_state = last_fsm_state_seq[-1] + else: + start_state = self.fsm.initial + + text_part = text[start_pos:] + + text_transitions = get_token_transition_keys( + self.fsm.fsm_info.alphabet_symbol_mapping, + self.fsm.fsm_info.alphabet_anything_value, + text_part, + ) + + state_seq = walk_fsm( + self.fsm, + text_transitions, + start_state, + full_match=self.match_whole, + ) + + if not state_seq: + return None + + if last_fsm_state_seq: + res = last_fsm_state_seq + tuple(state_seq) + else: + res = (start_state,) + tuple(state_seq) + + return res + + +class PartialContextualLexer(ContextualLexer): + def __init__(self, conf: "LexerConf", states, always_accept=()): + terminals = list(conf.terminals) + terminals_by_name = conf.terminals_by_name + + trad_conf = copy(conf) + trad_conf.terminals = terminals + + lexer_by_symbols: Dict = {} + self.lexers = {} + for state, accepts in states.items(): + key = frozenset(accepts) + try: + lexer = lexer_by_symbols[key] + except KeyError: + accepts = set(accepts) | set(conf.ignore) | set(always_accept) + lexer_conf = copy(trad_conf) + lexer_conf.terminals = [ + terminals_by_name[n] for n in accepts if n in terminals_by_name + ] + if not lexer_conf.terminals: + continue + lexer = PartialBasicLexer(lexer_conf) + lexer_by_symbols[key] = lexer + + self.lexers[state] = lexer + + assert trad_conf.terminals is terminals + self.root_lexer = PartialBasicLexer(trad_conf) + + def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]: + try: + while True: + lexer = self.lexers[parser_state.position] + next_tok = lexer.next_token(lexer_state, parser_state) + yield next_tok + except EOFError: + pass + except KeyError: + if len(lexer_state.text) > lexer_state.line_ctr.char_pos: + raise UnexpectedCharacters( + lexer_state.text, + lexer_state.line_ctr.char_pos, + lexer_state.line_ctr.line, + lexer_state.line_ctr.column, + allowed=False, + token_history=lexer_state.last_token and [lexer_state.last_token], + state=parser_state, + terminals_by_name=self.root_lexer.terminals, + ) + + +class PartialBasicLexer(BasicLexer): + def __init__(self, conf: "LexerConf"): + super().__init__(conf) + # Eagerly construct the scanner + self._build_scanner() + + def _build_scanner(self): + # This seems incredibly convoluted: `lark` creates callback-triggered + # nested scanners for regex-defined terminals that overlap with + # string-defined terminals when both types of terminals have the same + # priority. Unless I'm missing something important, why not simply + # reorder the terminals so that the string-defined ones come before the + # regex-defined ones? + terminals, self.callback = _create_unless( + self.terminals, self.g_regex_flags, self.re, self.use_bytes + ) + + # We can't let people arbitrarily mess with the scanning process. 
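+ # User-defined lexer callbacks would bypass the partial-scanning logic,
+ # so their absence is asserted instead; `lark`'s original callback chain
+ # is kept below, commented out, for reference.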
+ assert not self.user_callbacks + # for type_, f in self.user_callbacks.items(): + # if type_ in self.callback: + # # Already a callback there, probably UnlessCallback + # self.callback[type_] = CallChain( + # self.callback[type_], f, lambda t: t.type == type_ + # ) + # else: + # self.callback[type_] = f + + # We used the "callback" results to reorder the terminals (see the + # comments above). + for terminal_name, callback in self.callback.items(): + terminal = self.terminals_by_name[terminal_name] + for sub_terminal in callback.scanner.terminals: + self.terminals.remove(sub_terminal) + idx = self.terminals.index(terminal) + self.terminals.insert(idx, sub_terminal) + + self._scanner = PartialScanner( + self.terminals, self.g_regex_flags, self.re, self.use_bytes + ) + + def match(self, text, pos, last_fsm_state_seq=None): + return self.scanner.match(text, pos, last_fsm_state_seq) + + def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token: + last_token = lex_state.last_token + + last_fsm_state_seq = None + if last_token and last_token.type == "partial": + # Continue from last partial lexer state + last_fsm_state_seq = last_token.value.fsm_state_seq + + line_ctr = lex_state.line_ctr + end_pos = line_ctr.char_pos + ( + len(last_fsm_state_seq) - 1 if last_fsm_state_seq else 0 + ) + while end_pos < len(lex_state.text): + res = self.match(lex_state.text, line_ctr.char_pos, last_fsm_state_seq) + + if not res: + if ( + not last_fsm_state_seq + or last_fsm_state_seq[-1] not in self.scanner.fsm.finals + ): + allowed = self.scanner.allowed_types - self.ignore_types + if not allowed: + allowed = {""} + raise UnexpectedCharacters( + lex_state.text, + line_ctr.char_pos, + line_ctr.line, + line_ctr.column, + allowed=allowed, + token_history=lex_state.last_token and [lex_state.last_token], + state=parser_state, + terminals_by_name=self.terminals_by_name, + ) + + # The partial match might be complete now + fsm_state_seq = last_token.value.fsm_state_seq + terminals_and_info = last_token.value.terminals_and_info + final_terminals_and_info = last_token.value.final_terminals_and_info + else: + fsm_state_seq = res + ( + terminals_and_info, + final_terminals_and_info, + ) = self.scanner.get_terminals_info(fsm_state_seq) + + priority_terminal_info = ( + final_terminals_and_info[0] + if final_terminals_and_info + else terminals_and_info[0] + ) + + is_not_finished = ( + not priority_terminal_info.is_final + or priority_terminal_info.can_transition + or len(terminals_and_info) > 1 + ) + + start_pos = line_ctr.char_pos + end_pos = start_pos + len(fsm_state_seq) - 1 + + if end_pos >= len(lex_state.text) and is_not_finished: + type_name = "partial" + token_value = PartialTokensInfo( + fsm_state_seq, + is_not_finished, + terminals_and_info, + final_terminals_and_info, + ) + # Don't update the line counter states until we've finished + value = "" + else: + type_name = priority_terminal_info.terminal_name + # The token value should contain all partial scan parts in this + # case + value = token_value = lex_state.text[start_pos:end_pos] + + assert isinstance(self.callback, Dict) + + if type_name not in self.ignore_types: + t = Token( + type_name, + token_value, + line_ctr.char_pos, + line_ctr.line, + line_ctr.column, + ) + + line_ctr.feed(value, type_name in self.newline_types) + + t.end_line = line_ctr.line + t.end_column = line_ctr.column + t.end_pos = line_ctr.char_pos + if t.type in self.callback: + t = self.callback[t.type](t) + if not isinstance(t, Token): + raise LexError( + "Callbacks must 
return a token (returned %r)" % t + ) + lex_state.last_token = t + return t + + if type_name in self.callback: + t2 = Token( + type_name, value, line_ctr.char_pos, line_ctr.line, line_ctr.column + ) + self.callback[type_name](t2) + + line_ctr.feed(value, type_name in self.newline_types) + + last_fsm_state_seq = None + + raise EOFError(self) + + +class PartialIndenter(Indenter): + """An `Indenter` that doesn't reset its state every time `process` is called.""" + + def process(self, stream): + return self._process(stream) + + def _process(self, stream): + for token in stream: + # These were previously *after* the `yield`, but that makes the + # state tracking unnecessarily convoluted. + if token.type in self.OPEN_PAREN_types: + self.paren_level += 1 + elif token.type in self.CLOSE_PAREN_types: + self.paren_level -= 1 + if self.paren_level < 0: + raise UnexpectedToken(token, []) + + if token.type == self.NL_type: + yield from self.handle_NL(token) + else: + yield token + + # TODO: What do we want to do here? + # while len(self.indent_level) > 1: + # self.indent_level.pop() + # yield Token(self.DEDENT_type, "") + + def accepts_token_type(self, token_type): + if token_type in self.CLOSE_PAREN_types and self.paren_level - 1 < 0: + return False + + # TODO: + # if token_type == self.NL_type and self.paren_level == 0: + # ... + # return False + + return True + + def __copy__(self): + res = type(self)() + res.paren_level = self.paren_level + res.indent_level = copy(self.indent_level) + return res + + def __repr__(self): + return f"{type(self).__name__}(paren_level={self.paren_level!r}, indent_level={self.indent_level!r})" + + +class PartialPythonIndenter(PartialIndenter): + NL_type = "_NEWLINE" + OPEN_PAREN_types = ["LPAR", "LSQB", "LBRACE"] + CLOSE_PAREN_types = ["RPAR", "RSQB", "RBRACE"] + INDENT_type = "_INDENT" + DEDENT_type = "_DEDENT" + tab_len = 8 + + +def get_contextual_lexer(x: Union[PartialLexerThread, PartialParsingFrontend]): + if isinstance(x.lexer, ContextualLexer): + return x.lexer + else: + return x.lexer.lexer + + +def terminals_to_fsms(lp: PartialLark) -> Dict[str, FSM]: + """Construct a ``dict`` mapping terminal symbol names to their finite state machines.""" + + symbol_names_and_fsms = {} + for terminal in lp.terminals: + pattern = interegular.parse_pattern(terminal.pattern.to_regexp()) + # TODO: Use `pyparser.terminals[0].pattern.flags`? + try: + fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) + except Unsupported: + fsm = None + + symbol_names_and_fsms[terminal.name] = fsm + + return symbol_names_and_fsms + + +def fsm_union( + fsms: Sequence[FSM], +) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: + """Construct an FSM representing the union of the FSMs in `fsms`. + + This is an updated version of `interegular.fsm.FSM.union` made to return an + extra map of component FSMs to the sets of state transitions that + correspond to them in the new FSM. 
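+ Returns
+ -------
+ A tuple `(fsm, fsms_to_trans_finals)`, where the second element maps each
+ component FSM's index to its set of transitions within the union FSM, its
+ set of final states, and a map from its old states to the new aggregate
+ states.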
+ + """ + + alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms]) + + indexed_fsms = tuple(enumerate(fsms)) + + initial = {i: fsm.initial for (i, fsm) in indexed_fsms} + + # Dedicated function accepting a "superset" and returning the next + # "superset" obtained by following this transition in the new FSM + def follow(current_state, new_transition: int): + next = {} + for i, f in indexed_fsms: + old_transition = new_to_old[i][new_transition] + if ( + i in current_state + and current_state[i] in f.map + and old_transition in f.map[current_state[i]] + ): + next[i] = f.map[current_state[i]][old_transition] + if not next: + raise OblivionError + return next + + states = [initial] + finals: Set[int] = set() + map: Dict[int, Dict[int, int]] = {} + + # Map component FSMs to their new state-to-state transitions, finals, and a + # map translating component FSM states to aggregate FSM states + fsms_to_trans_finals: Dict[ + int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] + ] = {} + + i = 0 + while i < len(states): + state = states[i] + + # Add to the finals of the aggregate FSM whenever we hit a final in a + # component FSM + if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms): + finals.add(i) + + # Compute the map for this state + map[i] = {} + for transition in alphabet.by_transition: + try: + next = follow(state, transition) + except OblivionError: + # Reached an oblivion state; don't list it + continue + else: + try: + # TODO: Seems like this could--and should--be avoided + j = states.index(next) + except ValueError: + j = len(states) + states.append(next) + + map[i][transition] = j + + for fsm_id, fsm_state in next.items(): + ( + fsm_transitions, + fsm_finals, + fsm_old_to_new, + ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {})) + old_from = state[fsm_id] + old_to = fsm_state + fsm_old_to_new.setdefault(old_from, set()).add(i) + fsm_old_to_new.setdefault(old_to, set()).add(j) + fsm_transitions.add((i, j)) + if fsm_state in fsms[fsm_id].finals: + fsm_finals.add(j) + + i += 1 + + fsm = FSM( + alphabet=alphabet, + states=range(len(states)), + initial=0, + finals=finals, + map=map, + __no_validation__=True, + ) + + fsm, old_to_new_states = make_deterministic_fsm(fsm) + _fsms_to_trans_finals = { + fsm_id: ( + {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions}, + {old_to_new_states[s] for s in finals}, + { + old_state: {old_to_new_states[new_state] for new_state in new_states} + for old_state, new_states in old_to_new.items() + }, + ) + for fsm_id, (transitions, finals, old_to_new) in sorted( + fsms_to_trans_finals.items(), key=lambda x: x[0] + ) + } + + return ( + fsm, + _fsms_to_trans_finals, + ) + + +def get_sub_fsms_from_seq( + state_seq: Sequence[int], + fsms_to_trans_finals: Dict[ + int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] + ], +) -> Generator[Tuple[int, bool, bool], None, None]: + """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`. + + Parameters + ---------- + state_seq + A state sequence. + fsms_to_trans_finals + A map from FSM indices to tuples containing sets of their state transitions + and sets of the final/accept states. + + Returns + ------- + A generator returning tuples containing each sub-FSM index (in the order + they were union-ed to construct `fsm`) and booleans indicating whether or + not there is another valid transition from the last state in the sequence + for the associated sub-FSM (i.e. 
if the FSM can continue + accepting/matching) and whether or not the sequence ends in a final state + of the sub-FSM. + """ + state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:])) + last_fsm_state = state_seq[-1] + yield from ( + ( + # The sub-FMS index + fsm_idx, + # Is there another possible transition in this sub-FSM? + any(last_fsm_state == from_s for (from_s, to_s) in transitions), + # Is this sub-FSM in a final state? + state_seq[-1] in finals, + ) + for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items() + if state_seq_transitions.issubset(transitions) + ) + + +def walk_fsm( + fsm: BetterFSM, + token_transition_keys: Sequence[int], + start_state: int, + full_match: bool = True, +) -> List[int]: + fsm_finals = fsm.finals + + state = start_state + accepted_states: List[int] = [] + last_final_idx: int = 0 + + fsm_transitions = fsm.flat_transition_map + + # Iterate over token transition key sequence. The transition key + # sequence represents the FSM traversal rules of the tokens symbols. + for i, trans_key in enumerate(token_transition_keys): + new_state = fsm_transitions.get((state, trans_key)) + + if new_state is None: + if not full_match and last_final_idx > 0: + return accepted_states[:last_final_idx] + + return [] + + state = new_state + + if state in fsm_finals: + last_final_idx = i + 1 + + accepted_states.append(state) + + if full_match and last_final_idx - 1 != i: + return [] + + return accepted_states diff --git a/.venv/lib/python3.11/site-packages/outlines/fsm/types.py b/.venv/lib/python3.11/site-packages/outlines/fsm/types.py new file mode 100644 index 0000000000000000000000000000000000000000..5695dee0733946f1a6334a1242d575df65458c05 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/fsm/types.py @@ -0,0 +1,81 @@ +import datetime +from enum import EnumMeta +from typing import Any, Protocol, Tuple, Type + +from typing_extensions import _AnnotatedAlias, get_args + +INTEGER = r"[+-]?(0|[1-9][0-9]*)" +BOOLEAN = "(True|False)" +FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?" +DATE = r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])" +TIME = r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])" +DATETIME = rf"({DATE})(\s)({TIME})" + + +class FormatFunction(Protocol): + def __call__(self, sequence: str) -> Any: + ... 
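+ # Illustrative sketch (not part of the original module): `python_types_to_regex`
+ # below pairs each supported Python type with a regex and a `FormatFunction`,
+ # e.g.
+ #
+ #     regex_str, format_fn = python_types_to_regex(int)
+ #     assert regex_str == INTEGER
+ #     assert format_fn("42") == 42
+ #
+ # so callers can constrain generation with `regex_str` and parse the raw text
+ # back into the type with `format_fn`.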
+ + +def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]: + # If it is a custom type + if isinstance(python_type, _AnnotatedAlias): + json_schema = get_args(python_type)[1].json_schema + type_class = get_args(python_type)[0] + + custom_regex_str = json_schema["pattern"] + + def custom_format_fn(sequence: str) -> Any: + return type_class(sequence) + + return custom_regex_str, custom_format_fn + + if isinstance(python_type, EnumMeta): + values = python_type.__members__.keys() + enum_regex_str: str = "(" + "|".join(values) + ")" + + def enum_format_fn(sequence: str) -> str: + return str(sequence) + + return enum_regex_str, enum_format_fn + + if python_type == float: + + def float_format_fn(sequence: str) -> float: + return float(sequence) + + return FLOAT, float_format_fn + elif python_type == int: + + def int_format_fn(sequence: str) -> int: + return int(sequence) + + return INTEGER, int_format_fn + elif python_type == bool: + + def bool_format_fn(sequence: str) -> bool: + return bool(sequence) + + return BOOLEAN, bool_format_fn + elif python_type == datetime.date: + + def date_format_fn(sequence: str) -> datetime.date: + return datetime.datetime.strptime(sequence, "%Y-%m-%d").date() + + return DATE, date_format_fn + elif python_type == datetime.time: + + def time_format_fn(sequence: str) -> datetime.time: + return datetime.datetime.strptime(sequence, "%H:%M:%S").time() + + return TIME, time_format_fn + elif python_type == datetime.datetime: + + def datetime_format_fn(sequence: str) -> datetime.datetime: + return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S") + + return DATETIME, datetime_format_fn + else: + raise NotImplementedError( + f"The Python type {python_type} is not supported. Please open an issue." + ) diff --git a/.venv/lib/python3.11/site-packages/outlines/function.py b/.venv/lib/python3.11/site-packages/outlines/function.py new file mode 100644 index 0000000000000000000000000000000000000000..48577be8f7d4db1050545ac4036c8063bc206bb1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/function.py @@ -0,0 +1,117 @@ +import importlib.util +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union + +import requests + +from outlines import generate, models + +if TYPE_CHECKING: + from outlines.generate.api import SequenceGenerator + from outlines.prompts import Prompt + + +@dataclass +class Function: + """Represents an Outlines function. + + Functions are a convenient way to encapsulate a prompt template, a language + model and a Pydantic model that define the output structure. Once defined, + the function can be called with arguments that will be used to render the + prompt template. + + """ + + prompt_template: "Prompt" + schema: Union[str, Callable, object] + model_name: str + generator: Optional["SequenceGenerator"] = None + + @classmethod + def from_github(cls, program_path: str, function_name: str = "fn"): + """Load a function stored on GitHub""" + program_content = download_from_github(program_path) + function = extract_function_from_file(program_content, function_name) + + return function + + def init_generator(self): + """Load the model and initialize the generator.""" + model = models.transformers(self.model_name) + self.generator = generate.json(model, self.schema) + + def __call__(self, *args, **kwargs): + """Call the function. + + .. warning:: + + This currently does not support batching. 
+ + Parameters + ---------- + args + Values to pass to the prompt template as positional arguments. + kwargs + Values to pass to the prompt template as keyword arguments. + + """ + if self.generator is None: + self.init_generator() + + prompt = self.prompt_template(*args, **kwargs) + return self.generator(prompt) + + +def download_from_github(short_path: str): + """Download the file in which the function is stored on GitHub.""" + GITHUB_BASE_URL = "https://raw.githubusercontent.com" + BRANCH = "main" + + path = short_path.split("/") + if len(path) < 3: + raise ValueError( + "Please provide a valid path in the form {USERNAME}/{REPO_NAME}/{PATH_TO_FILE}." + ) + elif short_path[-3:] == ".py": + raise ValueError("Do not append the `.py` extension to the program name.") + + username = path[0] + repo = path[1] + path_to_file = path[2:] + + url = "/".join([GITHUB_BASE_URL, username, repo, BRANCH] + path_to_file) + ".py" + result = requests.get(url) + + if result.status_code == 200: + return result.text + elif result.status_code == 404: + raise ValueError( + f"Program could not be found at {url}. Please make sure you entered the GitHub username, repository name and path to the program correctly." + ) + else: + result.raise_for_status() + + +def extract_function_from_file(content: str, function_name: str) -> Tuple[Callable]: + """Extract a function object from a downloaded file.""" + + spec = importlib.util.spec_from_loader( + "outlines_function", loader=None, origin="github" + ) + if spec is not None: + module = importlib.util.module_from_spec(spec) + exec(content, module.__dict__) + + try: + fn = getattr(module, function_name) + except AttributeError: + raise AttributeError( + "Could not find an `outlines.Function` instance in the remote file. Make sure that the path you specified is correct." 
+ ) + + if not isinstance(fn, module.outlines.Function): + raise TypeError( + f"The `{function_name}` variable in the program must be an instance of `outlines.Function`" + ) + + return fn diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/__init__.py b/.venv/lib/python3.11/site-packages/outlines/generate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f28cbd80d5ff947205256d2d5e740e212935fa83 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/__init__.py @@ -0,0 +1,8 @@ +from .api import SequenceGenerator +from .cfg import cfg +from .choice import choice +from .format import format +from .fsm import fsm +from .json import json +from .regex import regex +from .text import text diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/choice.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/choice.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8941b1cb9bacefaafeb9a872877165222db18de7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/choice.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/generator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/generator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9775a06337190664b1e419b038aa2d1ee820837 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/generator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/json.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/json.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e951ca2f21d039d4db5737936dd1591a1b620fb8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/json.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/api.py b/.venv/lib/python3.11/site-packages/outlines/generate/api.py new file mode 100644 index 0000000000000000000000000000000000000000..4919f20904e2d3d6ddb7a9609ddb8b82503ecdfa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/api.py @@ -0,0 +1,623 @@ +import datetime +from copy import copy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union + +from outlines.generate.generator import sequence_generator +from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler + +if TYPE_CHECKING: + import torch + +FormattedOutput = Union[ + str, int, float, bool, datetime.date, datetime.time, datetime.datetime +] + + +class SequenceGenerator: + def __init__( + self, + fsm, + model, + sampler, + device, + ): + self.fsm = fsm + self.model = model + self.sampler = sampler + self.tokenizer = model.tokenizer + self.device = device + self.num_samples = sampler.samples + + def get_generated_token_ids( + self, + prompt_token_ids: "torch.Tensor", + token_ids: "torch.Tensor", + ) -> List["torch.Tensor"]: + """Get the tokens generated so far. + + Parameters + ---------- + prompt_token_ids + Tensor that contains the token ids of the sequences' prompts. + token_ids + The generated token ids. + + Returns + ------- + A tensor that contains the token ids that have been generated so far. 
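+ (Each entry is simply `token_ids[i][len(prompt_token_ids[i]):]`.)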
+ + """ + prompt_lengths = [len(prompt) for prompt in prompt_token_ids] + token_ids = [ + cur_token_ids[length:] + for cur_token_ids, length in zip(token_ids, prompt_lengths) + ] + + return token_ids + + def is_stop_sequence_found( + self, generated_sequences: List[str], stop_sequences: List[str] + ) -> bool: + """Determine whether one of the stop sequences has been generated. + + Parameters + ---------- + generated_sequences + The list of sequences generated so far. + stop_sequences + The list that contains the sequence which stop the generation when + found. + + Returns + ------- + True if at least one of the stop sequences has been found in each generated + sequence. + + """ + return all( + [ + any([seq in generated for seq in stop_sequences]) + for generated in generated_sequences + ] + ) + + def strip_stop_sequences( + self, sequence: str, stop_sequences: Optional[List[str]] + ) -> str: + """Remove the stop sequences from the generated sequences. + + Parameters + ---------- + sequence + One of the generated sequences. + stop_sequences + The list that contains the sequence which stop the generation when + found. + + """ + if stop_sequences: + match_indexes = [sequence.find(seq) for seq in stop_sequences] + if any([index != -1 for index in match_indexes]): + # select the stop_sequence that is found first in the sequence + min_match_index_value = min([i for i in match_indexes if i != -1]) + min_match_index_pos = match_indexes.index(min_match_index_value) + sequence = sequence[ + : match_indexes[min_match_index_pos] + + len(stop_sequences[min_match_index_pos]) + ] + + return sequence + + def format_sequence(self, sequence: str) -> FormattedOutput: + """Translate the generated sequence to another type. + + This method is for instance overridden when generating JSON to either + return a dictionnary or a Pydantic model. + + Parameters + ---------- + sequence + A generated sequences. + + Returns + ------- + The formatted sequence. + + """ + return sequence + + def __call__( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + rng: Optional["torch.Generator"] = None, + ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]: + """Generate the full text sequence. + + Since `SequenceGenerator.stream` calls the tokenizer at every step this + method loops over the generator returned by `sequence_generator` itself + so the tokenizer is called only once after all token ids have been + generated. + + Parameters + ---------- + prompts + A string or list of strings that are passed to the model before + generating the first token. + max_tokens + An integer representing maximum number of tokens that will be generated + (per prompt) + stop_at + A string or list of strings at which the text generated will stop + rng + The random number generator. Defaults to a non-seeded `torch.Generator` + instance. + + Returns + ------- + The generation(s), potentially cast to another type. + """ + import torch + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(stop_at, str): + stop_at = [stop_at] + + stop_sequences = stop_at + num_samples = self.num_samples + + if rng is None: + rng = torch.Generator(device=self.device) + rng.seed() + + prompt_token_ids, attention_masks = self.tokenizer.encode(prompts) + prompt_token_ids = prompt_token_ids.to(self.device) + attention_masks = attention_masks.to(self.device) + + # To draw multiple samples we repeat the prompt as many times + # as there are samples. 
We copy the FSMs and initialize the + # FSM states. + num_samples = self.num_samples + batch_size = len(prompts) + + prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0) + attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) + fsm_states = [0 for _ in range(batch_size * num_samples)] + fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)] + weights = torch.zeros( + (batch_size * num_samples), dtype=torch.float, device=self.device + ) + + states = sequence_generator( + self.model, + self.sampler, + fsms, + prompt_token_ids, + weights, + attention_masks, + fsm_states, + rng=rng, + ) + + while True: + try: + last_state = next(states) + if max_tokens or stop_sequences: + token_ids = last_state.token_ids + generated_token_ids = self.get_generated_token_ids( + prompt_token_ids, token_ids + ) + if max_tokens and len(generated_token_ids[0]) >= max_tokens: + break + if stop_sequences and self.is_stop_sequence_found( + self.tokenizer.decode(generated_token_ids), stop_sequences + ): + break + except StopIteration: + break + + token_ids = last_state.token_ids + generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids) + + generated = self.tokenizer.decode(generated_token_ids) + stripped = [ + self.strip_stop_sequences(sequence, stop_sequences) + for sequence in generated + ] + formatted = [self.format_sequence(sequence) for sequence in stripped] + + # We reshape the output to (batch_size, sample_size) + output: List[List[FormattedOutput]] = list() + for i in range(0, batch_size * num_samples, num_samples): + output.append(formatted[i : i + num_samples]) + + # We remove leading dimensions for the output + if batch_size == 1 and num_samples == 1: + return output[0][0] + elif batch_size == 1: + return output[0] + elif num_samples == 1: + return [samples[0] for samples in output] + else: + return output + + def stream( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + rng: Optional["torch.Generator"] = None, + ) -> Iterator[Union[List[str], str, List[List[str]]]]: + """Generate the text sequence one token at a time. + + Since `Tokenizer.decode` strips the whitespaces from the tokens we have no + choice but to decode the generated token ids at each step and compare the + current decoded strings to the previously decoded strings. + + Parameters + ---------- + prompts + A string or list of strings that are passed to the model before + generating the first token. + max_tokens + An integer representing maximum number of tokens that will be generated + (per prompt) + stop_at + A string or list of strings at which the text generated will stop + rng + The random number generator. Defaults to a non-seeded `torch.Generator` + instance. + + Returns + ------- + A string or list of strings that contain the generated text. + + """ + import torch + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(stop_at, str): + stop_at = [stop_at] + + stop_sequences = stop_at + num_samples = self.num_samples + + prompt_token_ids, attention_masks = self.tokenizer.encode(prompts) + prompt_token_ids = prompt_token_ids.to(self.device) + attention_masks = attention_masks.to(prompt_token_ids.device) + + # To draw multiple samples we repeat the prompt as many times + # as there are samples. We copy the FSMs and initialize the + # FSM states. 
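+ # E.g. with 2 prompts and 3 samples, `repeat_interleave` yields rows
+ # [p0, p0, p0, p1, p1, p1]; each row gets its own FSM copy so that parser
+ # state never leaks between samples.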
+ num_samples = self.num_samples + batch_size = len(prompts) + + prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0) + attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) + fsm_states = [0 for _ in range(batch_size * num_samples)] + fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)] + weights = torch.zeros( + (batch_size * num_samples), + dtype=torch.float, + device=prompt_token_ids.device, + ) + + if rng is None: + rng = torch.Generator(device=prompt_token_ids.device) + rng.seed() + + states = sequence_generator( + self.model, + self.sampler, + fsms, + prompt_token_ids, + weights, + attention_masks, + fsm_states, + rng=rng, + ) + + def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: + previously_generated_sequences = [ + "" for _ in range(batch_size) + ] * num_samples + num_generated = 0 + is_stop_at_reached = [False for _ in range(batch_size)] * num_samples + while True: + if (max_tokens and num_generated >= max_tokens) or all( + is_stop_at_reached + ): + return + try: + sequence = next(states) + num_generated += 1 + except StopIteration: + return + generated_token_ids = sequence.token_ids[:, -num_generated:] + generated_sequences = self.tokenizer.decode(generated_token_ids) + if stop_sequences: + is_stop_at_reached = [ + stop + or self.is_stop_sequence_found( + [generated_sequence], stop_sequences + ) + for generated_sequence, stop in zip( + generated_sequences, is_stop_at_reached + ) + ] + + generated_sequences = [ + self.format_sequence( + self.strip_stop_sequences(sequence, stop_sequences) + ) + if stop + else sequence + for sequence, stop in zip( + generated_sequences, is_stop_at_reached + ) + ] + next_tokens = [ + token[len(sequence) :] + for token, sequence, stop in zip( + generated_sequences, + previously_generated_sequences, + is_stop_at_reached, + ) + ] + previously_generated_sequences = generated_sequences + # We reshape the output to (batch_size, sample_size) + output: List[List[str]] = list() + for i in range(0, batch_size * num_samples, num_samples): + output.append(next_tokens[i : i + num_samples]) + + # We remove leading dimensions for the output + if batch_size == 1 and num_samples == 1: + yield output[0][0] + elif batch_size == 1: + yield output[0] + elif num_samples == 1: + yield [samples[0] for samples in output] + else: + yield output + + return token_generator() + + +@dataclass(frozen=True) +class GenerationParameters: + """Generation parameters used in Outlines' public API.""" + + max_tokens: Optional[int] + stop_at: Optional[Union[str, List[str]]] + seed: Optional[int] + + +@dataclass(frozen=True) +class SamplingParameters: + """Sampling parameters available in Outlines.""" + + sampler: str + num_samples: int = 1 + top_p: Optional[float] = None + top_k: Optional[int] = None + temperature: Optional[float] = None + + +class SequenceGeneratorAdapter: + """Class used to unify the interface to the model providers' + generation functions. + + Attributes + ---------- + model + The wrapped model. + logits_processor + The logits processor to use to generate text. + sampler + The sampler to use to generate text. 
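+ sampling_params
+ The `SamplingParameters` instance derived from the sampler.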
+ + """ + + def __init__(self, model, logits_processor, sampler): + self.model = model + self.logits_processor = logits_processor + + if isinstance(sampler, MultinomialSampler): + self.sampling_params = SamplingParameters( + "multinomial", + sampler.samples, + sampler.top_p, + sampler.top_k, + sampler.temperature, + ) + elif isinstance(sampler, GreedySampler): + self.sampling_params = SamplingParameters( + "greedy", sampler.samples, None, None, 0.0 + ) + elif isinstance(sampler, BeamSearchSampler): + self.sampling_params = SamplingParameters( + "beam_search", sampler.samples, None, None, 1.0 + ) + + def prepare_generation_parameters( + self, + max_tokens: Optional[int], + stop_at: Optional[Union[str, List[str]]], + seed: Optional[int], + ): + if isinstance(stop_at, str): + stop_at = [stop_at] + + generation_params = GenerationParameters( + max_tokens, + stop_at, + seed, + ) + + return generation_params + + def format_sequence(self, sequence: str) -> FormattedOutput: + """Translate the generated sequence to another type. + + This method is for instance overridden when generating JSON to either + return a dictionnary or a Pydantic model. + + Parameters + ---------- + sequence + A generated sequences. + + Returns + ------- + The formatted sequence. + + """ + return sequence + + def _format(self, sequences): + """Apply formatting to every string in a completion.""" + if isinstance(sequences, list): + return [self._format(sequence) for sequence in sequences] + else: + return self.format_sequence(sequences) + + def __call__( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + seed: Optional[int] = None, + **model_specific_params, + ): + """Generate text from a prompt of list of prompts.""" + + generation_params = self.prepare_generation_parameters( + max_tokens, stop_at, seed + ) + + completions = self.model.generate( + prompts, + generation_params, + copy(self.logits_processor), + self.sampling_params, + **model_specific_params, + ) + + return self._format(completions) + + def stream( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + seed: Optional[int] = None, + **model_specific_params, + ): + """Return a text generator from a prompt or a list of prompts.""" + generation_params = self.prepare_generation_parameters( + max_tokens, stop_at, seed + ) + return self.model.stream( + prompts, + generation_params, + copy(self.logits_processor), + self.sampling_params, + **model_specific_params, + ) + + +class VisionSequenceGeneratorAdapter(SequenceGeneratorAdapter): + def __call__( # type: ignore + self, + prompts: Union[str, List[str]], + media: Union[str, Any], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + seed: Optional[int] = None, + **model_specific_params, + ): + """ + Generate text from a prompt of list of prompts. + + Media: A URI to construct media or media object itself. Used as AutoProcessor argument. 
+ """ + prompts, media = self._validate_prompt_media_types(prompts, media) + + generation_params = self.prepare_generation_parameters( + max_tokens, stop_at, seed + ) + + completions = self.model.generate( + prompts, + media, + generation_params, + copy(self.logits_processor), + self.sampling_params, + **model_specific_params, + ) + + return self._format(completions) + + def stream( # type: ignore + self, + prompts: Union[str, List[str]], + media: List[Union[str, Any, List[Union[str, Any]]]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + seed: Optional[int] = None, + **model_specific_params, + ): + """Return a text generator from a prompt or a list of prompts.""" + prompts, media = self._validate_prompt_media_types(prompts, media) + generation_params = self.prepare_generation_parameters( + max_tokens, stop_at, seed + ) + return self.model.stream( + prompts, + media, + generation_params, + copy(self.logits_processor), + self.sampling_params, + **model_specific_params, + ) + + @classmethod + def _validate_prompt_media_types( + cls, + prompts: Union[str, List[str]], + media: Union[str, Any, List[Union[str, Any]]], + ) -> Union[Any, List[Any]]: + """ + Prepare media as PIL.Image and ensure for every prompt str there is one List[PIL.Image] + """ + + def valid_types(prompts, media): + from PIL import Image # type: ignore + + if isinstance(prompts, list): + if not isinstance(media, list) or len(prompts) != len(media): + return False + for subprompt, submedia in zip(prompts, media): + if not isinstance(subprompt, str) or not all( + isinstance(m, Image.Image) for m in submedia + ): + return False + elif isinstance(prompts, str): + if not all(isinstance(m, Image.Image) for m in media): + return False + return True + + if not valid_types(prompts, media): + raise TypeError( + "Expected (prompts, media) to be of type " + "(str, List[Image])), or (List[str], List[List[Image]]) " + f"instead got prompts={prompts}, media={media}" + ) + + return prompts, media diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/cfg.py b/.venv/lib/python3.11/site-packages/outlines/generate/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..b677040d5837e08c0627cc24b10d21f0b447f8a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/cfg.py @@ -0,0 +1,54 @@ +from functools import singledispatch + +from outlines.generate.api import ( + SequenceGeneratorAdapter, + VisionSequenceGeneratorAdapter, +) +from outlines.models import LlamaCpp, OpenAI, TransformersVision +from outlines.samplers import Sampler, multinomial + + +@singledispatch +def cfg( + model, cfg_str: str, sampler: Sampler = multinomial() +) -> SequenceGeneratorAdapter: + """Generate text in the language of a Context-Free Grammar + + Arguments + --------- + model: + An `outlines.model` instance. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGeneratorAdapter` instance that generates text. 
+ + """ + from outlines.processors import CFGLogitsProcessor + + logits_processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer) + return SequenceGeneratorAdapter(model, logits_processor, sampler) + + +@cfg.register(TransformersVision) +def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()): + from outlines.processors import CFGLogitsProcessor + + logits_processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer) + return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) + + +@cfg.register(LlamaCpp) +def cfg_llamacpp(model, cfg_str: str, sampler: Sampler = multinomial()): + raise NotImplementedError("Not yet available due to bug in llama_cpp tokenizer") + + +@cfg.register(OpenAI) +def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()): + raise NotImplementedError( + "Cannot use grammar-structured generation with an OpenAI model" + + "due to the limitations of the OpenAI API." + ) diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/choice.py b/.venv/lib/python3.11/site-packages/outlines/generate/choice.py new file mode 100644 index 0000000000000000000000000000000000000000..afb998f52abae6bf53fc3efbfd9b7122d741a4f6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/choice.py @@ -0,0 +1,59 @@ +import json as pyjson +import re +from enum import Enum +from functools import singledispatch +from typing import Callable, List, Union + +from outlines_core.fsm.json_schema import build_regex_from_schema + +from outlines.fsm.json_schema import get_schema_from_enum +from outlines.generate.api import SequenceGeneratorAdapter +from outlines.models import OpenAI +from outlines.samplers import Sampler, multinomial + +from .json import json +from .regex import regex + + +@singledispatch +def choice( + model, choices: Union[List[str], type[Enum]], sampler: Sampler = multinomial() +) -> SequenceGeneratorAdapter: + if isinstance(choices, type(Enum)): + regex_str = build_regex_from_schema(pyjson.dumps(get_schema_from_enum(choices))) + else: + choices = [re.escape(choice) for choice in choices] # type: ignore + regex_str = r"(" + r"|".join(choices) + r")" + + generator = regex(model, regex_str, sampler) + if isinstance(choices, type(Enum)): + generator.format_sequence = lambda x: pyjson.loads(x) + else: + generator.format_sequence = lambda x: x + + return generator + + +@choice.register(OpenAI) +def choice_openai( + model: OpenAI, choices: List[str], sampler: Sampler = multinomial() +) -> Callable: + """ + Call OpenAI API with response_format of a dict: + {"result": } + """ + + choices_schema = pyjson.dumps( + { + "type": "object", + "properties": {"result": {"type": "string", "enum": choices}}, + "additionalProperties": False, + "required": ["result"], + } + ) + generator = json(model, choices_schema, sampler) + + def generate_choice(*args, **kwargs): + return generator(*args, **kwargs)["result"] + + return generate_choice diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/format.py b/.venv/lib/python3.11/site-packages/outlines/generate/format.py new file mode 100644 index 0000000000000000000000000000000000000000..88acec75f414d073f098ca36584973f829c09d4a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/format.py @@ -0,0 +1,47 @@ +from functools import singledispatch + +from outlines.fsm.types import python_types_to_regex +from outlines.generate.api import SequenceGeneratorAdapter +from outlines.models import OpenAI +from outlines.samplers import Sampler, multinomial + +from 
.regex import regex + + +@singledispatch +def format( + model, python_type, sampler: Sampler = multinomial() +) -> SequenceGeneratorAdapter: + """Generate structured data that can be parsed as a Python type. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + python_type: + A Python type. The output of the generator must be parseable into + this type. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGenerator` instance that generates text constrained by the Python type + and translates this text into the corresponding type. + + """ + regex_str, format_fn = python_types_to_regex(python_type) + generator = regex(model, regex_str, sampler) + generator.format_sequence = format_fn + + return generator + + +@format.register(OpenAI) +def format_openai(model, python_type, sampler: Sampler = multinomial()): + raise NotImplementedError( + "Cannot use Python type-structured generation with an OpenAI model" + + " due to the limitations of the OpenAI API." + ) diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/fsm.py b/.venv/lib/python3.11/site-packages/outlines/generate/fsm.py new file mode 100644 index 0000000000000000000000000000000000000000..1950812d231f53aae6bfd5d80ed4d830ba52a092 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/fsm.py @@ -0,0 +1,31 @@ +from functools import singledispatch + +import interegular + +from outlines.fsm.guide import RegexGuide +from outlines.generate.api import ( + SequenceGeneratorAdapter, + VisionSequenceGeneratorAdapter, +) +from outlines.models import TransformersVision +from outlines.samplers import Sampler, multinomial + + +@singledispatch +def fsm( + model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial() +) -> SequenceGeneratorAdapter: + from outlines.processors import GuideLogitsProcessor + + guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) + logits_processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide) + return SequenceGeneratorAdapter(model, logits_processor, sampler) + + +@fsm.register(TransformersVision) +def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()): + from outlines.processors import GuideLogitsProcessor + + guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) + logits_processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide) + return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/generator.py b/.venv/lib/python3.11/site-packages/outlines/generate/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..e506aa035ea3f5ced5ec9cfa22b320428813bcaf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/generator.py @@ -0,0 +1,312 @@ +import dataclasses +import math +from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Tuple + +if TYPE_CHECKING: + import torch + + from outlines.fsm.guide import Guide + + +class ContextLengthExceededError(Exception): + pass + + +@dataclasses.dataclass(frozen=True) +class GenerationState: + token_ids: "torch.Tensor" + kv_cache: "torch.Tensor" + logits: "torch.Tensor" + weights: "torch.Tensor" + fsm_states: List[int] + + +def sequence_generator( + model: Callable, + sampler: Callable, + fsms: List["Guide"], + token_ids: "torch.Tensor", + sequence_weights: "torch.Tensor", + 
attention_masks: "torch.Tensor", + fsm_states: List[int], + rng: "torch.Generator", +) -> Iterator[GenerationState]: + """Generates sequences of tokens. + + Parameters + ---------- + model + A callable that generates a probability distribution over the + vocabulary when passed a tensor of token ids. + sampler + A callable that returns the next token ids, their ancestor sequence and + the updated sequence weights when passed a distribution over the + vocabulary. + token_ids + A tensor of token ids on which the sequence distribution is conditioned, of + shape ``(n_seqs, n_prompt_tokens)`` + sequence_weights + A tensor that contains the initial weights of the sequences, of shape + ``(n_seqs,)`` + attention_masks + A tensor of tensors that represent the tokens considered at the attention + layer, of shape ``(n_seqs, n_prompt_tokens)``. + fsms + List of finite-state machines that drive the text generation, + one for each sequence in the batch. + fsm_states + The initial states of the finite-state machine for each sequence in the batch. + + Yields + ------ + A new sequence. + + """ + import torch + + if rng is None: + rng = torch.Generator() + + kv_cache = None + + while True: + try: + logits, kv_cache = model(token_ids, attention_masks, kv_cache) + except IndexError: # Exceeding the context length + raise ContextLengthExceededError( + "The input length exceeds the context length of the model." + ) + + allowed_tokens = get_allowed_tokens(fsms, fsm_states) + biased_logits = bias_logits(logits, allowed_tokens) + next_token_ids, ancestors, sequence_weights = sampler( + biased_logits, sequence_weights, rng + ) + + token_ids = update_token_ids(token_ids, next_token_ids, ancestors) + attention_masks = update_attention_masks(attention_masks, ancestors) + kv_cache = reorder_kv_cache(kv_cache, ancestors) + if len(ancestors) > 1: + fsms = reorder_fsms(fsms, ancestors) + fsm_states = reorder_fsm_states(fsm_states, ancestors) + + fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids) + is_finished = is_generation_finished(fsms, fsm_states) + + if is_finished: + yield GenerationState( + token_ids, + kv_cache, + logits, + sequence_weights, + fsm_states, + ) + return + + yield GenerationState( + token_ids, + kv_cache, + logits, + sequence_weights, + fsm_states, + ) + + +def get_next_fsm_states( + fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor" +) -> List[int]: + """ + + Parameters + ---------- + fsm + The finite-state machine used to monitor this batch. + next_token_ids + The tokens that were just generated. + + Returns + ------- + A `torch.Tensor` object that represents the next logit mask. + + """ + return [ + fsm.get_next_state(fsm_state, int(token_id[0])) + for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids) + ] + + +def get_allowed_tokens( + fsms: List["Guide"], fsm_states: List[int] +) -> List[Optional[Iterable[int]]]: + """Get the new instructions for each sequence from the finite-state machine. + + Parameters + ---------- + fsm + The finite-state machine used to monitor this batch. + fsm_states + The FSM states corresponding to each sequence in the batch. + + Returns + ------- + A nested list that contains the ids of the logits to keep. + + """ + return [ + fsm.get_next_instruction(state).tokens for fsm, state in zip(fsms, fsm_states) + ] + + +def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool: + """Determine if the generation is finished. 
+ + A generation is considered finished if the FSM of every sequence in the + batch is in a final state. + + A better solution is to return finished sequences as soon as their FSM + is in a final state. + + Parameters + ---------- + fsm + The finite-state machine used to monitor this batch. + fsm_states + The FSM states corresponding to each sequence in the batch. + + Returns + ------- + Whether all sequences are finished sampling. + + """ + return all([fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states)]) + + +def update_token_ids( + token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor" +) -> "torch.Tensor": + """Append the sampled tokens to the running sequence of tokens. + + Parameters + ---------- + token_ids + The current token sequences + next_token_ids + The tokens that were just generated and that we need to append + to the existing sequences. + ancestors + The sequences to which the token ids need to be added. + + Returns + ------- + A new sequence of token ids that contains the tokens that were + just generated. + + """ + import torch + + token_ids = torch.index_select(token_ids, 0, ancestors) + return torch.concatenate([token_ids, next_token_ids], dim=-1) + + +def update_attention_masks( + attention_masks: "torch.Tensor", ancestors: "torch.Tensor" +) -> "torch.Tensor": + """Expand the attention masks. + + Parameters + ---------- + attention_masks + The attention masks for each sequence in the batch. + ancestors + The sequences to which the token ids need to be added. + + Returns + ------- + The attention masks padded with 1s. + + """ + import torch + + attention_masks = torch.index_select(attention_masks, 0, ancestors) + return torch.concatenate( + [ + attention_masks, + torch.ones( + attention_masks.shape[:-1] + (1,), device=attention_masks.device + ), + ], + axis=-1, + ) + + +def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]: + reordered_fsms = [] + for ancestor in ancestors: + reordered_fsms.append(fsms[ancestor].copy()) + + return reordered_fsms + + +def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]: + reordered_states = [] + for ancestor in ancestors: + reordered_states.append(fsm_states[ancestor]) + + return reordered_states + + +def reorder_kv_cache( + kv_cache: Optional[Tuple], ancestors: "torch.Tensor" +) -> Optional[Tuple]: + """Re-order the KV-cache based on the ancestors. + + In transformers, the object that stores the KV-cache is a tuple who elements + are the key cache and the value cache. Each of these caches are tuples where + each element correpond to a layer. To each layer corresponds a tensor whose + first dimension is the batch size. + + """ + import torch + + if kv_cache is None: + return None + + new_kv_cache: Tuple = tuple() + for cache_item in kv_cache: + new_cache_item: Tuple = tuple() + for layer in cache_item: + layer = torch.index_select(layer, 0, ancestors.to(layer.device)) + new_cache_item += (layer,) + new_kv_cache += (new_cache_item,) + + return new_kv_cache + + +def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor": + """Mask the logits. + + The function iterates over a nested list where each list corresponds to the + indices that need to be masked for each row in the array. + + Parameters + ---------- + logits + Two dimensional tensor that contains the next-token probability + distribution. + allowed_token_ids + A list that contains the tokens that can be generated by the model. 
+ + Returns + ------- + A view of the original logits tensor where some values are masked. + + """ + import torch + + biased_logits = torch.full_like(logits, -math.inf, device=logits.device) + for i, ids in enumerate(allowed_token_ids): + if ids is not None: + biased_logits[i, ids] = logits[i, ids] + else: + biased_logits[i] = logits[i] + return biased_logits diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/json.py b/.venv/lib/python3.11/site-packages/outlines/generate/json.py new file mode 100644 index 0000000000000000000000000000000000000000..d098d920d515b3e6c55534ce72aebf6c2de9bffc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/json.py @@ -0,0 +1,115 @@ +import json as pyjson +from enum import Enum +from functools import singledispatch +from typing import Callable, Optional, Union + +from outlines_core.fsm.json_schema import build_regex_from_schema +from pydantic import BaseModel + +from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature +from outlines.generate.api import SequenceGeneratorAdapter +from outlines.models import OpenAI +from outlines.samplers import Sampler, multinomial + +from .regex import regex + + +@singledispatch +def json( + model, + schema_object: Union[str, object, Callable], + sampler: Sampler = multinomial(), + whitespace_pattern: Optional[str] = None, +) -> SequenceGeneratorAdapter: + """ + Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + schema_object: + The JSON Schema to generate data for. Can be a JSON string, a Pydantic model, or a callable + that returns a JSON schema. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string literals) + Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + + Returns + ------- + A `SequenceGenerator` instance that generates text constrained by the schema_object and + transforms the result if BaseModel is used. + + """ + if isinstance(schema_object, type(BaseModel)): + schema = pyjson.dumps(schema_object.model_json_schema()) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: schema_object.parse_raw(x) + elif isinstance(schema_object, type(Enum)): + schema = pyjson.dumps(get_schema_from_enum(schema_object)) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: pyjson.loads(x) + elif callable(schema_object): + schema = pyjson.dumps(get_schema_from_signature(schema_object)) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: pyjson.loads(x) + elif isinstance(schema_object, str): + schema = schema_object + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: pyjson.loads(x) + else: + raise ValueError( + f"Cannot parse schema {schema_object}. 
The schema must be either " + + "a Pydantic object, a function or a string that contains the JSON " + + "Schema specification" + ) + + return generator + + +@json.register(OpenAI) +def json_openai( + model, schema_object: Union[str, object], sampler: Sampler = multinomial() +): + if not isinstance(sampler, multinomial): + raise NotImplementedError( + r"The OpenAI API does not support any other sampling algorithm " + + "than the multinomial sampler." + ) + + if isinstance(schema_object, type(BaseModel)): + schema = pyjson.dumps(schema_object.model_json_schema()) + format_sequence = lambda x: schema_object.parse_raw(x) + elif isinstance(schema_object, str): + schema = schema_object + format_sequence = lambda x: pyjson.loads(x) + else: + raise ValueError( + f"Cannot parse schema {schema_object}. The schema must be either " + + "a Pydantic object, a function or a string that contains the JSON " + + "Schema specification" + ) + + # create copied, patched model with normalized json schema set + generator = model.new_with_replacements( + response_format={ + "type": "json_schema", + "json_schema": { + "name": "default", + "strict": True, + "schema": pyjson.loads(schema), + }, + } + ) + + generator.format_sequence = format_sequence + + return generator diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/regex.py b/.venv/lib/python3.11/site-packages/outlines/generate/regex.py new file mode 100644 index 0000000000000000000000000000000000000000..673880e4986322be6fdbaeed1684757e601a0624 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/regex.py @@ -0,0 +1,59 @@ +from functools import singledispatch + +from outlines.generate.api import ( + SequenceGeneratorAdapter, + VisionSequenceGeneratorAdapter, +) +from outlines.models import OpenAI, TransformersVision +from outlines.samplers import Sampler, multinomial + + +@singledispatch +def regex(model, regex_str: str, sampler: Sampler = multinomial()): + """Generate structured text in the language of a regular expression. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + regex_str: + The regular expression that the output must follow. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGeneratorAdapter` instance that generates text constrained by the + regular expression. + + """ + from outlines.processors import RegexLogitsProcessor + + logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer) + return SequenceGeneratorAdapter(model, logits_processor, sampler) + + +@regex.register(TransformersVision) +def regex_vision( + model, + regex_str: str, + sampler: Sampler = multinomial(), +): + from outlines.processors import RegexLogitsProcessor + + logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer) + return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) + + +@regex.register(OpenAI) +def regex_openai( + model: OpenAI, + regex_str: str, + sampler: Sampler = multinomial(), +): + raise NotImplementedError( + "Cannot use regex-structured generation with an OpenAI model" + + "due to the limitations of the OpenAI API." 
+ ) diff --git a/.venv/lib/python3.11/site-packages/outlines/generate/text.py b/.venv/lib/python3.11/site-packages/outlines/generate/text.py new file mode 100644 index 0000000000000000000000000000000000000000..32530d0c49c693ec499a067ab751ef834536716b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/generate/text.py @@ -0,0 +1,50 @@ +from functools import singledispatch + +from outlines.generate.api import ( + SequenceGeneratorAdapter, + VisionSequenceGeneratorAdapter, +) +from outlines.models import OpenAI, TransformersVision +from outlines.samplers import Sampler, multinomial + + +@singledispatch +def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter: + """Generate text with a `Transformer` model. + + Note + ---- + Python 3.11 allows dispatching on Union types and + this should greatly simplify the code. + + Arguments + --------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGeneratorAdapter` instance that generates text. + + """ + return SequenceGeneratorAdapter(model, None, sampler) + + +@text.register(TransformersVision) +def text_vision(model, sampler: Sampler = multinomial()): + return VisionSequenceGeneratorAdapter(model, None, sampler) + + +@text.register(OpenAI) +def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI: + if not isinstance(sampler, multinomial): + raise NotImplementedError( + r"The OpenAI API does not support any other sampling algorithm " + + "than the multinomial sampler." + ) + + return model diff --git a/.venv/lib/python3.11/site-packages/outlines/grammars.py b/.venv/lib/python3.11/site-packages/outlines/grammars.py new file mode 100644 index 0000000000000000000000000000000000000000..f0c122964786e8cb42cec595cff04a823f2c1958 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/grammars.py @@ -0,0 +1,14 @@ +from pathlib import Path + +GRAMMAR_PATH = Path(__file__).parent / "grammars" + + +def read_grammar(grammar_file_name, base_grammar_path=GRAMMAR_PATH): + """Read grammar file from default grammar path""" + full_path = base_grammar_path / grammar_file_name + with open(full_path) as file: + return file.read() + + +arithmetic = read_grammar("arithmetic.lark") +json = read_grammar("json.lark") diff --git a/.venv/lib/python3.11/site-packages/outlines/grammars/arithmetic.lark b/.venv/lib/python3.11/site-packages/outlines/grammars/arithmetic.lark new file mode 100644 index 0000000000000000000000000000000000000000..2332650c63c02b5f3ded849dc61542170c922038 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/grammars/arithmetic.lark @@ -0,0 +1,18 @@ +?start: sum + +?sum: product +| sum "+" product -> add +| sum "-" product -> sub + +?product: atom +| product "*" atom -> mul +| product "/" atom -> div + +?atom: NUMBER -> number +| "-" atom -> neg +| "(" sum ")" + +%import common.NUMBER +%import common.WS_INLINE + +%ignore WS_INLINE diff --git a/.venv/lib/python3.11/site-packages/outlines/grammars/common.lark b/.venv/lib/python3.11/site-packages/outlines/grammars/common.lark new file mode 100644 index 0000000000000000000000000000000000000000..ee5e00c500093e8c095c83cc7d383ebc82592a6f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/grammars/common.lark @@ -0,0 +1,83 @@ +// Adapted from https://github.com/lark-parser/lark/blob/master/lark/grammars/common.lark + +// Lark License: +// Copyright © 
2017 Erez Shinan +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +// Basic terminals for common use + + +// +// Numbers +// + +DIGIT: "0".."9" +HEXDIGIT: "a".."f"|"A".."F"|DIGIT + +INT: DIGIT+ +SIGNED_INT: ["+"|"-"] INT +DECIMAL: INT "." INT? | "." INT + +// float = /-?\d+(\.\d+)?([eE][+-]?\d+)?/ +_EXP: ("e"|"E") SIGNED_INT +FLOAT: INT _EXP | DECIMAL _EXP? +SIGNED_FLOAT: ["+"|"-"] FLOAT + +NUMBER: FLOAT | INT +SIGNED_NUMBER: ["+"|"-"] NUMBER + +UNESCAPED_STRING: /\"[^"]*\"/ + +// based on `outlines/fsm/json_schema.py` +_NON_CONTROL_CHAR: /([^"\\\x00-\x1F\x7F-\x9F])/ +_ESCAPED_CHAR: /\\/ (_NON_CONTROL_CHAR | /\\/ | /"/) +ESCAPED_STRING_INNER: _NON_CONTROL_CHAR | _ESCAPED_CHAR +ESCAPED_STRING: /"/ ESCAPED_STRING_INNER* /"/ + + + +// +// Names (Variables) +// +LCASE_LETTER: "a".."z" +UCASE_LETTER: "A".."Z" + +LETTER: UCASE_LETTER | LCASE_LETTER +WORD: LETTER+ + +CNAME: ("_"|LETTER) ("_"|LETTER|DIGIT)* + + +// +// Whitespace +// +WS_INLINE: (" "|/\t/)+ +WS: /[ \t\f\r\n]/+ + +CR : /\r/ +LF : /\n/ +NEWLINE: (CR? 
LF)+ + + +// Comments +SH_COMMENT: /#[^\n]*/ +CPP_COMMENT: /\/\/[^\n]*/ +C_COMMENT: "/*" /(.|\n)*?/ "*/" +SQL_COMMENT: /--[^\n]*/ diff --git a/.venv/lib/python3.11/site-packages/outlines/grammars/json.lark b/.venv/lib/python3.11/site-packages/outlines/grammars/json.lark new file mode 100644 index 0000000000000000000000000000000000000000..7429fa5583f92cbbad1d3c239895750dffc460ea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/grammars/json.lark @@ -0,0 +1,19 @@ +?start: value + +?value: object +| array +| ESCAPED_STRING +| SIGNED_NUMBER -> number +| "true" -> true +| "false" -> false +| "null" -> null + +array : "[" [value ("," value)*] "]" +object : "{" [pair ("," pair)*] "}" +pair : ESCAPED_STRING ":" value + +%import common.ESCAPED_STRING +%import common.SIGNED_NUMBER +%import common.WS + +%ignore WS diff --git a/.venv/lib/python3.11/site-packages/outlines/processors/__init__.py b/.venv/lib/python3.11/site-packages/outlines/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f0f829b54b34505115479ee5c1dc8e773d9e28 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/processors/__init__.py @@ -0,0 +1,7 @@ +from .structured import ( + CFGLogitsProcessor, + GuideLogitsProcessor, + JSONLogitsProcessor, + OutlinesLogitsProcessor, + RegexLogitsProcessor, +) diff --git a/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ad4e2d683e43f4bbafc0d83f3aef9601a9e2fb3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/base_logits_processor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/base_logits_processor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4452e1948bb39330f05635d398b606adda7fb13 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/base_logits_processor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/structured.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/structured.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42373546de13622186fe96401bc29726981bd7dd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/structured.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/processors/base_logits_processor.py b/.venv/lib/python3.11/site-packages/outlines/processors/base_logits_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..44b55af2e465adeea7c9ddf26211327c5a5c15fe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/processors/base_logits_processor.py @@ -0,0 +1,159 @@ +from abc import abstractmethod +from typing import TYPE_CHECKING, List, Protocol, Type, Union + +import numpy as np +import torch +from numpy.typing import NDArray + +if TYPE_CHECKING: + import mlx.core as mx + + +Array = Union[NDArray, torch.Tensor, List, "mx.array"] + + +def is_mlx_array_type(array_type): + try: + import mlx.core as mx + except ImportError: + return False + return issubclass(array_type, mx.array) + + +def 
is_jax_array_type(array_type):
+    try:
+        import jaxlib
+    except ImportError:
+        return False
+    return issubclass(array_type, jaxlib.xla_extension.ArrayImpl) or isinstance(
+        array_type, jaxlib.xla_extension.ArrayImpl
+    )
+
+
+class OutlinesLogitsProcessor(Protocol):
+    """
+    Base class for logits processors which normalizes types of logits:
+    - ndarray (used by llama-cpp-python), converted to torch.Tensor
+    - mlx.core.array (used by mlx-lm), converted to torch.Tensor
+    - torch.Tensor (used by everything else)
+
+    Normalization of types and conversion to torch.Tensor
+    doesn't move memory, it just casts the type.
+
+    Normalizing the types allows all logits processors inheriting from this class
+    to implement a single method for all the business logic: `process_logits()`
+    """
+
+    @abstractmethod
+    def process_logits(
+        self, input_ids: List[List[int]], logits: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        input_ids and logits are always 2D tensors for handling a batch of sequences.
+
+        - input_ids -> List[List[tokens]]
+        - logits -> 2D_Tensor[logit floats]
+
+        Important to keep in mind when designing universal logits processors:
+        - logits processors are only used once and never re-applied for a new sequence generator
+        - Some models only pass output_ids, some models such as llamacpp and transformers prefix with input_ids
+        - Some sampling methods, such as beam search, result in unstable sequence ordering in models like vLLM
+        """
+        pass
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        input_ids: Array,
+        logits: Array,
+    ) -> Array:
+        """
+        Apply logits processor
+
+        1) Unify type
+        - convert input_ids: either ndarray, mlx array, List[int], or Tensor -> List[List[int]]
+        - convert logits: either ndarray, mlx array, or Tensor -> 2D float Tensor
+        2) Unify shape, ensure logits and input_ids are 2D
+        3) Call self.process_logits() to perform business logic
+        4) Cast logits back to original array library type
+        """
+        # ensure logits are torch Tensors
+        torch_logits = self._to_torch(logits)
+        input_ids = self._to_torch(input_ids)
+
+        assert torch_logits.shape[:-1] == input_ids.shape[:-1]
+
+        # Guarantee passed as 2D Tensors, then convert back to original (1D or 2D) shape
+        if len(torch_logits.shape) == 2:
+            processed_logits = self.process_logits(input_ids, torch_logits)
+        elif len(torch_logits.shape) == 1:
+            processed_logits = self.process_logits(
+                input_ids.unsqueeze(0), torch_logits.unsqueeze(0)
+            ).squeeze(0)
+
+        # return logits as passed array type
+        return self._from_torch(processed_logits, type(logits))
+
+    @staticmethod
+    def _to_torch(tensor_like: Array) -> torch.Tensor:
+        """Convert various types to torch.Tensor."""
+        if isinstance(tensor_like, torch.Tensor):
+            return tensor_like
+
+        elif isinstance(tensor_like, np.ndarray):
+            return torch.from_numpy(tensor_like)
+
+        elif isinstance(tensor_like, (list, tuple)):
+            return torch.tensor(tensor_like)
+
+        elif is_mlx_array_type(type(tensor_like)):
+            import mlx.core as mx
+
+            # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch
+            return torch.from_dlpack(
+                np.array(tensor_like.astype(mx.float32), copy=False)
+            )
+
+        elif is_jax_array_type(type(tensor_like)):
+            import jax
+
+            torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
+            return torch_tensor
+
+        else:
+            raise TypeError(
+                "LogitsProcessor must be called with either np.NDArray, "
+                "torch.Tensor, list, or mlx.core.array typed logits.
" + f"Logits type: `{type(tensor_like)}`" + ) + + @staticmethod + def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array: + """Convert torch.Tensor to the specified target type.""" + if target_type == torch.Tensor: + return tensor + + elif target_type == np.ndarray: + return tensor.detach().numpy() + + elif target_type == list: + return tensor.detach().tolist() + + elif target_type == tuple: + return tuple(tensor.detach().tolist()) + + elif is_mlx_array_type(target_type): + import mlx.core as mx + + # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch + return mx.array(tensor.float().numpy()) + + elif is_jax_array_type(target_type): + import jax + + return jax.dlpack.from_dlpack(tensor) + + else: + raise TypeError( + f"Failed to convert torch tensors to target_type `{target_type}`" + ) diff --git a/.venv/lib/python3.11/site-packages/outlines/processors/structured.py b/.venv/lib/python3.11/site-packages/outlines/processors/structured.py new file mode 100644 index 0000000000000000000000000000000000000000..64892b73f76f0e36855f2f67a44e11c3515edcd9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/processors/structured.py @@ -0,0 +1,247 @@ +""" + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import math +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union + +import torch +from outlines_core.fsm.json_schema import build_regex_from_schema +from pydantic import BaseModel + +from outlines.fsm.guide import CFGGuide, Guide, RegexGuide +from outlines.fsm.json_schema import convert_json_schema_to_str + +from .base_logits_processor import OutlinesLogitsProcessor + +if TYPE_CHECKING: + from outlines.models.tokenizer import Tokenizer + + +class GuideLogitsProcessor(OutlinesLogitsProcessor): + """Bias generation using a finite + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + guide + The `outlines.fsm.Guide` which is used to bias the logits. + """ + + tokenizer: "Tokenizer" + guide: Guide + _guide_states: Dict[int, Any] + _seq_start_idx: Optional[int] + + def __init__(self, tokenizer: "Tokenizer", guide: Guide): + """A Guide-based logits processor. + + Parameters + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + guide + The `outlines.fsm.Guide. which is used to bias the logits. + """ + self.tokenizer = tokenizer + self.guide = guide + self._guide_states = {hash(tuple([])): self.guide.initial_state} + self._seq_start_idx = None + + def process_logits( + self, input_ids: torch.LongTensor, logits: torch.FloatTensor + ) -> torch.Tensor: + """Use the Guide to bias the logits before sampling the next token. + + Parameters + ---------- + input_ids + The input token ids. + logits + The logits. + + Returns + ------- + torch.Tensor + The biased logits. 
+ """ + if self._seq_start_idx is None: + self._seq_start_idx = len(input_ids[0]) + + sequence_states: List[int] = [] # vector of states corresponding to `input_ids` + + for seq_ids in input_ids: + gen_ids = seq_ids[self._seq_start_idx :] + curr_state_key = hash(tuple(gen_ids.tolist())) + + if curr_state_key not in self._guide_states: + prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))] + curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item()) + self._guide_states[curr_state_key] = curr_state + + sequence_states.append(self._guide_states[curr_state_key]) + + mask = torch.ones_like(logits, dtype=torch.bool) + + allowed_tokens_batch = [] + batch_indices = [] + for i, guide_state in enumerate(sequence_states): + allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to( + mask.device, non_blocking=True + ) + allowed_tokens_batch.append(allowed_tokens) + batch_indices.append( + torch.full_like(allowed_tokens, i) + ) # Store batch index for each allowed token + + allowed_tokens_concat = torch.cat(allowed_tokens_batch) + batch_indices_concat = torch.cat(batch_indices) + + mask[batch_indices_concat, allowed_tokens_concat] = False + logits.masked_fill_(mask, float("-inf")) + + return logits + + def copy(self) -> "GuideLogitsProcessor": + """Return a copy of the logits processor.""" + return GuideLogitsProcessor(tokenizer=self.tokenizer, guide=self.guide.copy()) + + +class RegexLogitsProcessor(GuideLogitsProcessor): + """Bias generation based on a regular expression. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + guide + The `outlines.fsm.RegexGuide. which is used to bias the logits. + """ + + def __init__(self, regex_string: str, tokenizer: "Tokenizer"): + """Compile the RegexGuide that drives the regex-guided generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + An Outlines tokenizer + """ + guide = RegexGuide.from_regex(regex_string, tokenizer) + super().__init__(tokenizer=tokenizer, guide=guide) + + +class JSONLogitsProcessor(RegexLogitsProcessor): + """Bias generation based on a JSON schema. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + guide + The `outlines.fsm.RegexGuide. which is used to bias the logits. + """ + + def __init__( + self, + schema: Union[dict, Type[BaseModel], str], + tokenizer: "Tokenizer", + whitespace_pattern: Optional[str] = None, + ): + """Compile the Guide that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to generate. + tokenizer + The tokenizer used to convert tokens to ids. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string + literals). For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + schema_str = convert_json_schema_to_str(json_schema=schema) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string=regex_string, tokenizer=tokenizer) + + +class CFGLogitsProcessor(GuideLogitsProcessor): + """Bias generation based on a context-free grammar. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + guide + The `outlines.fsm.CFGGuide. which is used to bias the logits. + """ + + guide: CFGGuide + + def __init__(self, cfg_str: str, tokenizer: "Tokenizer"): + """Compile the CFGGuide that drives the CFG-guided generation. 
+ + Parameters + ---------- + cfg_str + A string that represents a grammar + tokenizer + The tokenizer used to convert tokens to ids. + """ + cfg_guide = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer) + super().__init__(tokenizer=tokenizer, guide=cfg_guide) + + def process_logits( + self, input_ids: torch.LongTensor, logits: torch.Tensor + ) -> torch.Tensor: + """Same behavior as GuideLogitsProcessor, but uses rejection sampling""" + if self._seq_start_idx is None: + self._seq_start_idx = len(input_ids[0]) + + sequence_states: List = [] # vector of states corresponding to `input_ids` + + for seq_ids in input_ids: + gen_ids = seq_ids[self._seq_start_idx :] + curr_state_key = hash(tuple(gen_ids.tolist())) + + if curr_state_key not in self._guide_states: + prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))] + curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item()) + self._guide_states[curr_state_key] = curr_state + + sequence_states.append(self._guide_states[curr_state_key]) + + mask = torch.full_like(logits, -math.inf) + for i, guide_state in enumerate(sequence_states): + first_legal_token = next( + self.guide.iter_valid_token_ids( + guide_state, torch.argsort(logits[i], descending=True) + ) + ) + mask[i, [first_legal_token]] = logits[i, [first_legal_token]] + + return mask diff --git a/.venv/lib/python3.11/site-packages/outlines/prompts.py b/.venv/lib/python3.11/site-packages/outlines/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..a7824451a19b9a60cbda59455ea0e3900526dd44 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/prompts.py @@ -0,0 +1,343 @@ +import functools +import inspect +import json +import re +import textwrap +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Type, cast + +from jinja2 import Environment, StrictUndefined +from pydantic import BaseModel + + +@dataclass +class Prompt: + """Represents a prompt function. + + We return a `Prompt` class instead of a simple function so the + template defined in prompt functions can be accessed. + + """ + + template: str + signature: inspect.Signature + + def __post_init__(self): + self.parameters: List[str] = list(self.signature.parameters.keys()) + self.jinja_environment = create_jinja_template(self.template) + + def __call__(self, *args, **kwargs) -> str: + """Render and return the template. + + Returns + ------- + The rendered template as a Python ``str``. + + """ + bound_arguments = self.signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return self.jinja_environment.render(**bound_arguments.arguments) + + def __str__(self): + return self.template + + +def prompt(fn: Callable) -> Prompt: + """Decorate a function that contains a prompt template. + + This allows to define prompts in the docstring of a function and simplify their + manipulation by providing some degree of encapsulation. It uses the `render` + function internally to render templates. + + >>> import outlines + >>> + >>> @outlines.prompt + >>> def build_prompt(question): + ... "I have a ${question}" + ... + >>> prompt = build_prompt("How are you?") + + This API can also be helpful in an "agent" context where parts of the prompt + are set when the agent is initialized and never modified later. In this situation + we can partially apply the prompt function at initialization. + + >>> import outlines + >>> import functools as ft + ... + >>> @outlines.prompt + ... def solve_task(name: str, objective: str, task: str): + ... 
'''Your name is {{name}}.
+    ...     Your overall objective is to {{objective}}.
+    ...     Please solve the following task: {{task}}
+    ...     '''
+    ...
+    >>> hal = ft.partial(solve_task, "HAL", "Travel to Jupiter")
+
+    Returns
+    -------
+    A `Prompt` callable class which will render the template when called.
+
+    """
+
+    signature = inspect.signature(fn)
+
+    # The docstring contains the template that will be rendered to be used
+    # as a prompt to the language model.
+    docstring = fn.__doc__
+    if docstring is None:
+        raise TypeError("Could not find a template in the function's docstring.")
+
+    template = cast(str, docstring)
+
+    return Prompt(template, signature)
+
+
+def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
+    r"""Parse a Jinja2 template and translate it into an Outlines graph.
+
+    This function removes extra whitespaces and linebreaks from templates to
+    allow users to enter prompts more naturally than if they used Python's
+    constructs directly. See the examples for a detailed explanation.
+
+    Examples
+    --------
+
+    Outlines follows Jinja2's syntax
+
+    >>> import outlines
+    >>> outline = outlines.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis")
+    I like tomatoes and tennis
+
+    If the first line of the template is empty, `render` removes it
+
+    >>> from outlines import render
+    >>>
+    >>> tpl = '''
+    ... A new string'''
+    >>> tpl
+    ... '\nA new string'
+    >>> render(tpl)
+    ... 'A new string'
+
+    Similarly, `render` ignores linebreaks introduced by placing the closing quotes
+    underneath the text:
+
+    >>> tpl = '''
+    ... A new string
+    ... '''
+    >>> tpl
+    ... '\nA new string\n'
+    >>> render(tpl)
+    ... 'A new string'
+
+    If you want to insert a linebreak at the end of the rendered template, you will
+    need to leave an empty line at the end of the template:
+
+    >>> tpl = '''
+    ... A new string
+    ...
+    ... '''
+    >>> tpl
+    ... '\nA new string\n\n'
+    >>> render(tpl)
+    ... 'A new string\n'
+
+    `render` removes the indentation in docstrings. This is particularly important
+    when using prompt functions
+
+    >>> tpl = '''
+    ...    a string
+    ...    and another string'''
+    >>> tpl
+    ... '\n a string\n and another string'
+    >>> render(tpl)
+    ... 'a string\nand another string'
+
+    The indentation of the first line is assumed to be the same as the second line's
+
+    >>> tpl = '''a string
+    ...    and another'''
+    >>> tpl
+    ... 'a string\n and another'
+    >>> render(tpl)
+    ... 'a string\nand another'
+
+    To get a different indentation for the first and the second line, we can start the
+    prompt on the string's second line:
+
+    >>> tpl = '''
+    ... First line
+    ...   Second line'''
+    >>> render(tpl)
+    ... 'First line\n  Second line'
+
+    Parameters
+    ----------
+    template
+        A string that contains a template written with the Jinja2 syntax.
+    **values
+        Map from the variables in the template to their value.
+
+    Returns
+    -------
+    A string that contains the rendered template.
+
+    """
+    jinja_template = create_jinja_template(template)
+    return jinja_template.render(**values)
+
+
+def create_jinja_template(template: str):
+    # Dedent, and remove extra linebreak
+    cleaned_template = inspect.cleandoc(template)
+
+    # Add linebreak if there were any extra linebreaks that
+    # `cleandoc` would have removed
+    ends_with_linebreak = template.replace(" ", "").endswith("\n\n")
+    if ends_with_linebreak:
+        cleaned_template += "\n"
+
+    # Remove extra whitespaces, except those that immediately follow a newline symbol.
+ # This is necessary to avoid introducing whitespaces after backslash `\` characters + # used to continue to the next line without linebreak. + cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) + + env = Environment( + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=StrictUndefined, + ) + env.filters["name"] = get_fn_name + env.filters["description"] = get_fn_description + env.filters["source"] = get_fn_source + env.filters["signature"] = get_fn_signature + env.filters["schema"] = get_schema + env.filters["args"] = get_fn_args + + jinja_template = env.from_string(cleaned_template) + return jinja_template + + +def get_fn_name(fn: Callable): + """Returns the name of a callable.""" + if not callable(fn): + raise TypeError("The `name` filter only applies to callables.") + + if not hasattr(fn, "__name__"): + name = type(fn).__name__ + else: + name = fn.__name__ + + return name + + +def get_fn_args(fn: Callable): + """Returns the arguments of a function with annotations and default values if provided.""" + if not callable(fn): + raise TypeError("The `args` filter only applies to callables.") + + arg_str_list = [] + signature = inspect.signature(fn) + arg_str_list = [str(param) for param in signature.parameters.values()] + arg_str = ", ".join(arg_str_list) + return arg_str + + +def get_fn_description(fn: Callable): + """Returns the first line of a callable's docstring.""" + if not callable(fn): + raise TypeError("The `description` filter only applies to callables.") + + docstring = inspect.getdoc(fn) + if docstring is None: + description = "" + else: + description = docstring.split("\n")[0].strip() + + return description + + +def get_fn_source(fn: Callable): + """Return the source code of a callable.""" + if not callable(fn): + raise TypeError("The `source` filter only applies to callables.") + + source = textwrap.dedent(inspect.getsource(fn)) + re_search = re.search(re.compile(r"(\bdef\b.*)", re.DOTALL), source) + if re_search is not None: + source = re_search.group(0) + else: + raise TypeError("Could not read the function's source code") + + return source + + +def get_fn_signature(fn: Callable): + """Return the signature of a callable.""" + if not callable(fn): + raise TypeError("The `source` filter only applies to callables.") + + source = textwrap.dedent(inspect.getsource(fn)) + re_search = re.search(re.compile(r"\(([^)]+)\)"), source) + if re_search is None: + signature = "" + else: + signature = re_search.group(1) + + return signature + + +@functools.singledispatch +def get_schema(model: Any): + raise NotImplementedError( + f"No schema rendering function defined for type {type(model)}." + ) + + +@get_schema.register(dict) +def get_schema_dict(model: Dict): + """Return a pretty-printed dictionary""" + return json.dumps(model, indent=2) + + +@get_schema.register(type(BaseModel)) +def get_schema_pydantic(model: Type[BaseModel]): + """Return the schema of a Pydantic model.""" + if not type(model) == type(BaseModel): + raise TypeError("The `schema` filter only applies to Pydantic models.") + + if hasattr(model, "model_json_schema"): + def_key = "$defs" + raw_schema = model.model_json_schema() + else: # pragma: no cover + def_key = "definitions" + raw_schema = model.schema() + + definitions = raw_schema.get(def_key, None) + schema = parse_pydantic_schema(raw_schema, definitions) + + return json.dumps(schema, indent=2) + + +def parse_pydantic_schema(raw_schema, definitions): + """Parse the output of `Basemodel.[schema|model_json_schema]()`. 
+ + This recursively follows the references to other schemas in case + of nested models. Other schemas are stored under the "definitions" + key in the schema of the top-level model. + + """ + simple_schema = {} + for name, value in raw_schema["properties"].items(): + if "description" in value: + simple_schema[name] = value["description"] + elif "$ref" in value: + refs = value["$ref"].split("/") + simple_schema[name] = parse_pydantic_schema( + definitions[refs[2]], definitions + ) + else: + simple_schema[name] = f"<{name}>" + + return simple_schema diff --git a/.venv/lib/python3.11/site-packages/outlines/py.typed b/.venv/lib/python3.11/site-packages/outlines/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/outlines/samplers.py b/.venv/lib/python3.11/site-packages/outlines/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1421971f6bfc12c5d05652a36d7bd364fb608fe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/samplers.py @@ -0,0 +1,324 @@ +import math +from typing import TYPE_CHECKING, Callable, Optional, Protocol, Tuple + +if TYPE_CHECKING: + import torch + + +class Sampler(Protocol): + samples: int + + def __call__( + self, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + rng: "torch.Generator", + ) -> "torch.DoubleTensor": + ... + + +class GreedySampler: + """Greedy Sampling algorithm. + + Greedy sampling consists in choosing the token with the largest + likelihood at every step. + + We don't allow more than one sample. We could attribute this a meaning, for + instance the k-th sample represents the k-th most likely token. In which + case it would be equivalent to beam search without the sequence weights. + + Attributes + ---------- + samples + The number of samples taken for each input sequence. + + """ + + def __init__(self): + self.samples = 1 + + def __call__( + self, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + _, + ) -> "torch.DoubleTensor": + """Call the greedy sampler. + + Parameters + ---------- + next_token_logits + A tensor of shape ``(n_seqs, vocab_size,)`` that represents the + probability distribution of the next token over the vocabulary. + sequence_weights + A tensor of shape ``(n_seqs,)`` that represents the cumulative + weight of each sequence. + rng + A random number generator. + + Returns + ------- + A tuple with an array that contains the ids of the sampled tokens of + shape ``(n_seqs, 1)``, an array that contains the ancestors of each + sampled id of shape ``(n_seqs,)`` and an array that contains the updated + cumulative weights of each sequence of shape ``(n_seqs,)``. + + """ + import torch + + logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True) + + ancestors = torch.arange( + next_token_logits.shape[0], device=next_token_logits.device + ) + weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze() + + return next_token_ids, ancestors, weights + + +greedy = GreedySampler + + +class MultinomialSampler: + """Multinomial sampling algorithm. + + Multinomial sampling consists in randomly sampling the next token assuming + its distribution is a Categorical distribution parametrized by the + next-token logits. + + + Attributes + ---------- + samples + The number of samples taken for each input sequence. 
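+
+    Examples
+    --------
+    A usage sketch; the model checkpoint is an illustrative assumption:
+
+    >>> import outlines
+    >>> sampler = outlines.samplers.multinomial(3, top_k=10, temperature=0.7)
+    >>> model = outlines.models.transformers("gpt2")  # doctest: +SKIP
+    >>> generator = outlines.generate.text(model, sampler)  # doctest: +SKIP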
+ + """ + + def __init__( + self, + samples: int = 1, + *, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ): + self.samples = samples + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + + self.logits_processors = [] + if top_k is not None: + self.logits_processors.append(keep_top_k_logits(top_k)) + elif top_p is not None: + self.logits_processors.append(keep_top_p_logits(top_p)) + + if temperature is not None: + self.logits_processors.append(rescale_logits(temperature)) + + def __call__( + self, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + rng: "torch.Generator", + ) -> Tuple["torch.DoubleTensor", "torch.DoubleTensor", "torch.DoubleTensor"]: + """Call the multinomial sampler. + + Parameters + ---------- + next_token_logits + A tensor of shape ``(n_seqs, vocab_size,)`` that represents the + probability distribution of the next token over the vocabulary. + sequence_weights + A tensor of shape ``(n_seqs,)`` that represents the cumulative + weight of each sequence. + rng + A random number generator. + + Returns + ------- + A tuple with an array that contains the ids of the sampled tokens of + shape ``(n_seqs, 1)``, an array that contains the ancestors of each + sampled id of shape ``(n_seqs,)`` and an array that contains the updated + cumulative weights of each sequence of shape ``(n_seqs,)``. + + """ + import torch + + altered_next_token_logits = next_token_logits + for logit_processor in self.logits_processors: + altered_next_token_logits = logit_processor(next_token_logits) + + probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1) + next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng) + + logprobs = torch.nn.functional.log_softmax(altered_next_token_logits, dim=-1) + ancestors = torch.arange( + altered_next_token_logits.shape[0], device=next_token_logits.device + ) + weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze() + + return next_token_ids, ancestors, weights + + +multinomial = MultinomialSampler + + +def keep_top_k_logits(k: int) -> Callable[["torch.Tensor"], "torch.Tensor"]: + """Build a function that masks logits values smaller than the top `k` ones. + + Parameters + ---------- + k + The ranking below which logit values are replaced by `-math.inf`. + + """ + import torch + + if not isinstance(k, int) or k < 1: + raise ValueError(f"`k` must be a strictly positive integers, got {k} instead.") + + def logits_processor(logits: torch.Tensor) -> torch.Tensor: + num_to_keep = min(k, logits.size(-1)) + mask_idx = logits < torch.topk(logits, num_to_keep)[0][..., -1, None] + return logits.masked_fill(mask_idx, -math.inf) + + return logits_processor + + +def keep_top_p_logits(p: float) -> Callable[["torch.Tensor"], "torch.Tensor"]: + """Build a function that masks the lowest probability tokens whose + cumulative probability is below a certain threshold. + + Parameters + ---------- + p + The value of the threshold. We keep the highest probability tokens whose + cumulative distribution is greater than or equal to `p` and mask the + others. Its value must be between 0 (excluded) and 1 (included). + + """ + import torch + + if p <= 0.0 or p > 1.0: + raise ValueError( + f"`p` must be a floating point number between 0 (excluded) and 1 (included), got {p} instead." 
+ ) + + def logits_processor(logits: torch.Tensor) -> torch.Tensor: + sorted_logits, sorted_idx = torch.sort(logits, descending=False) + cumulative_probabilties = torch.nn.functional.softmax( + sorted_logits, dim=-1 + ).cumsum(dim=-1) + + sorted_masked_idx = cumulative_probabilties <= (1 - p) + mask_idx = torch.scatter(sorted_masked_idx, 1, sorted_idx, sorted_masked_idx) + return logits.masked_fill(mask_idx, -math.inf) + + return logits_processor + + +def rescale_logits(temperature: float) -> Callable[["torch.Tensor"], "torch.Tensor"]: + """Build a function that rescales the token probabilities exponentially. + + Parameters + ---------- + temperature + The value by which we rescale the logits. + + """ + + if not isinstance(temperature, float) or temperature < 0.0: + raise ValueError( + f"`temperature` must be a strictly positive floating point number, got {temperature} instead." + ) + elif temperature == 0.0: + raise ValueError( + "Please use the greedy sampler instead of setting the temperature to 0." + ) + + def logits_processor(logits: "torch.Tensor") -> "torch.Tensor": + return logits / temperature + + return logits_processor + + +class BeamSearchSampler: + """Beam Search sampling algorithm. + + Attributes + ---------- + samples + The number of samples taken for each input sequence. Equivalent to the + number of beams. + """ + + def __init__(self, beams: int = 1): + self.samples = beams + + def __call__( + self, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + _, + ) -> Tuple["torch.DoubleTensor", "torch.DoubleTensor", "torch.DoubleTensor"]: + """Call the beam search sampler. + + Parameters + ---------- + next_token_logits + A tensor of shape ``(n_seqs, vocab_size,)`` that represents the + probability distribution of the next token over the vocabulary. + sequence_weights + A tensor of shape ``(n_seqs,)`` that represents the cumulative + weight of each sequence. + rng + A random number generator. + + Returns + ------- + A tuple with an array that contains the ids of the sampled tokens of + shape ``(n_seqs, 1)``, an array that contains the ancestors of each + sampled id of shape ``(n_seqs,)`` and an array that contains the updated + cumulative weights of each sequence of shape ``(n_seqs,)``. + + """ + import torch + + logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + weights = logprobs + sequence_weights.unsqueeze(1).expand_as(next_token_logits) + + # Flatten scores to (n_batch, n_samples * vocab_size) + # and find the top-k weights for each batch. + batch_size = next_token_logits.shape[0] // self.samples + vocab_size = next_token_logits.shape[-1] + weights = weights.view(batch_size, self.samples * vocab_size) + + # If the weights are all equal to 0 we are at the beginning of the search + # and thus only need to sample from one set of token logits for each + # batch. 
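+        # For example, with batch_size=1, samples=2 and vocab_size=4, the view
+        # above gives `weights` the shape (1, 8); at the first step the two
+        # halves are identical copies of the same next-token distribution, so
+        # restricting the top-k search to the first `vocab_size` columns avoids
+        # selecting the same token twice as two "different" beams.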
+ if torch.all(sequence_weights == 0): + weights = weights[:, :vocab_size] + + weights, indices = torch.topk( + weights, self.samples, dim=1, largest=True, sorted=True + ) + + ancestors = torch.div(indices, vocab_size, rounding_mode="floor") + next_token_ids = indices % vocab_size + + # Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1) + first_batch_idx = torch.arange( + 0, batch_size * self.samples, self.samples, device=next_token_logits.device + ).unsqueeze(1) + ancestors = ancestors + first_batch_idx + + ancestors = ancestors.view(self.samples * batch_size) + weights = weights.view(self.samples * batch_size) + next_token_ids = next_token_ids.view(self.samples * batch_size, 1) + + return next_token_ids, ancestors, weights + + +beam_search = BeamSearchSampler diff --git a/.venv/lib/python3.11/site-packages/outlines/serve/__init__.py b/.venv/lib/python3.11/site-packages/outlines/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/outlines/serve/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/serve/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77c68d9371e546bcf850464706206f1c633fa576 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/serve/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/serve/__pycache__/serve.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/serve/__pycache__/serve.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f447fde7dc865e4882f5409378f152cfbc41eacf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/serve/__pycache__/serve.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/serve/serve.py b/.venv/lib/python3.11/site-packages/outlines/serve/serve.py new file mode 100644 index 0000000000000000000000000000000000000000..998fbc4594752aa2dd6b91222b5bc0d343eefab7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/serve/serve.py @@ -0,0 +1,139 @@ +# _______________________________ +# / Don't want to self-host? \ +# \ Try .json at http://dottxt.co / +# ------------------------------- +# \ ^__^ +# \ (oo)\_______ +# (__)\ )\/\ +# ||----w | +# || || +# +# +# Copyright 2024- the Outlines developers +# Copyright 2023 the vLLM developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
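+
+# Editor's sketch (not part of the original file): once the server is running,
+# e.g. via `python -m outlines.serve.serve --model <model-name>`, the
+# `/generate` endpoint defined below accepts a prompt plus an optional JSON
+# schema or regex and can be queried with any HTTP client:
+#
+#     curl http://127.0.0.1:8000/generate \
+#         -d '{"prompt": "What is the capital of France?",
+#              "regex": "[A-Z][a-z]+",
+#              "max_tokens": 16}'
+#
+# Remaining JSON fields such as `max_tokens` are forwarded to vLLM's
+# `SamplingParams`.
+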
+import argparse +import json +from typing import AsyncGenerator + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +from outlines.models.vllm import adapt_tokenizer +from outlines.processors import JSONLogitsProcessor, RegexLogitsProcessor + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds. +app = FastAPI() +engine = None +tokenizer = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - schema: the JSON schema to use for the generation (if regex is not provided). + - regex: the regex to use for the generation (if schema is not provided). + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). + """ + assert engine is not None + + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + + json_schema = request_dict.pop("schema", None) + regex_string = request_dict.pop("regex", None) + if json_schema is not None: + logits_processors = [JSONLogitsProcessor(json_schema, tokenizer)] + elif regex_string is not None: + logits_processors = [RegexLogitsProcessor(regex_string, tokenizer)] + else: + logits_processors = [] + + sampling_params = SamplingParams( + **request_dict, logits_processors=logits_processors # type: ignore + ) + request_id = random_uuid() + + results_generator = engine.generate(prompt, sampling_params, request_id) # type: ignore + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + text_outputs = [prompt + output.text for output in request_output.outputs] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + async for request_output in results_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. 
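+            # (A 499 response mirrors nginx's non-standard "client closed
+            # request" status code; editor's note.)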
+ await engine.abort(request_id) # type: ignore + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + prompt = final_output.prompt + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + # Adds the `engine_use_ray`, `disable_log_requests` and `max_log_len` + # arguments + engine_args: AsyncEngineArgs = AsyncEngineArgs.from_cli_args(args) # type: ignore + + # Sets default for the model (`facebook/opt-125m`) + engine = AsyncLLMEngine.from_engine_args(engine_args) + tokenizer = adapt_tokenizer(tokenizer=engine.engine.tokenizer.tokenizer) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__init__.py b/.venv/lib/python3.11/site-packages/outlines/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d2b8cd350c46a3217600adcf86c42f4cf4e7fc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/types/__init__.py @@ -0,0 +1,4 @@ +from . import airports, countries +from .email import Email +from .isbn import ISBN +from .locales import locale diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b704eb478d0a900a6861886d5707ffca85282bd0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/airports.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/airports.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fbbd76b9bd2db43504a14a828ac5996277d1454 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/airports.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/countries.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/countries.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1288656769357e435c704f3066b348176c22195d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/countries.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/email.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/email.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1127ee81c9a9488e5121bfefbb88aec5653d6b51 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/email.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/isbn.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/isbn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..479a91cbf9e8e00b9b4a8caf19e9c0d8dbb82138 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/isbn.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/locales.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/locales.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9bda528f2f1383420d7e1815c48fa776b703b2c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/locales.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/phone_numbers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/phone_numbers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9781af0a593ceb3323959af6f9a7df5c0d1e2389 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/phone_numbers.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/zip_codes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/zip_codes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf6d7e2c2068afea159be0cf030c09aa1e61afa1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/outlines/types/__pycache__/zip_codes.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/airports.py b/.venv/lib/python3.11/site-packages/outlines/types/airports.py new file mode 100644 index 0000000000000000000000000000000000000000..934ae18441c1a2bd5cddb327533784e926494f6c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/types/airports.py @@ -0,0 +1,11 @@
+"""Generate valid airport codes."""
+from enum import Enum
+
+import airportsdata
+
+AIRPORT_IATA_LIST = [
+    (v["iata"], v["iata"]) for v in airportsdata.load().values() if v["iata"]
+]
+
+
+IATA = Enum("Airport", AIRPORT_IATA_LIST)  # type:ignore
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/countries.py b/.venv/lib/python3.11/site-packages/outlines/types/countries.py new file mode 100644 index 0000000000000000000000000000000000000000..bbfc0dde7eaff0ad4fbef99e804703278753f897 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/types/countries.py @@ -0,0 +1,19 @@
+"""Generate valid country codes and names."""
+from enum import Enum
+
+import pycountry
+
+ALPHA_2_CODE = [(country.alpha_2, country.alpha_2) for country in pycountry.countries]
+Alpha2 = Enum("Alpha_2", ALPHA_2_CODE)  # type:ignore
+
+ALPHA_3_CODE = [(country.alpha_3, country.alpha_3) for country in pycountry.countries]
+Alpha3 = Enum("Alpha_3", ALPHA_3_CODE)  # type:ignore
+
+NUMERIC_CODE = [(country.numeric, country.numeric) for country in pycountry.countries]
+Numeric = Enum("Numeric_code", NUMERIC_CODE)  # type:ignore
+
+NAME = [(country.name, country.name) for country in pycountry.countries]
+Name = Enum("Name", NAME)  # type:ignore
+
+FLAG = [(country.flag, country.flag) for country in pycountry.countries]
+Flag = Enum("Flag", FLAG)  # type:ignore
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/email.py b/.venv/lib/python3.11/site-packages/outlines/types/email.py new file mode 100644 index 0000000000000000000000000000000000000000..45f8c4b2cac8d4bdda23e74e450ce302f6ea711f ---
/dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/types/email.py @@ -0,0 +1,11 @@
+"""Email Address types."""
+from pydantic import WithJsonSchema
+from typing_extensions import Annotated
+
+# Taken from StackOverflow
+# https://stackoverflow.com/a/201378/14773537
+EMAIL_REGEX = r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])"""
+Email = Annotated[
+    str,
+    WithJsonSchema({"type": "string", "pattern": EMAIL_REGEX}),
+]
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/isbn.py b/.venv/lib/python3.11/site-packages/outlines/types/isbn.py new file mode 100644 index 0000000000000000000000000000000000000000..5aebb067ec0254a6cb65d825e71e2eba53e5aa23 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/types/isbn.py @@ -0,0 +1,12 @@
+"""ISBN type"""
+from pydantic import WithJsonSchema
+from typing_extensions import Annotated
+
+# Matches any ISBN number. Note that this is not completely correct as not all
+# 10- or 13-digit numbers are valid ISBNs. See https://en.wikipedia.org/wiki/ISBN
+# Taken from O'Reilly's Regular Expression Cookbook:
+# https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch04s13.html
+# TODO: Can this be represented by a grammar or do we need semantic checks?
+ISBN_REGEX = r"(?:ISBN(?:-1[03])?:? )?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]"
+
+ISBN = Annotated[str, WithJsonSchema({"type": "string", "pattern": ISBN_REGEX})]
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/locales.py b/.venv/lib/python3.11/site-packages/outlines/types/locales.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d251bae6dc05a787dd563c170829495b09ab76 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/types/locales.py @@ -0,0 +1,21 @@
+from dataclasses import dataclass
+
+from outlines.types.phone_numbers import USPhoneNumber
+from outlines.types.zip_codes import USZipCode
+
+
+@dataclass
+class US:
+    ZipCode = USZipCode
+    PhoneNumber = USPhoneNumber
+
+
+def locale(locale_str: str):
+    locales = {"us": US}
+
+    if locale_str not in locales:
+        raise NotImplementedError(
+            f"The locale {locale_str} is not supported yet. Please don't hesitate to create custom types for your locale and open a Pull Request."
+        )
+
+    return locales[locale_str]
diff --git a/.venv/lib/python3.11/site-packages/outlines/types/phone_numbers.py b/.venv/lib/python3.11/site-packages/outlines/types/phone_numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..618687e759c0d2a9689630af769042752b6443a3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/types/phone_numbers.py @@ -0,0 +1,16 @@
+"""Phone number types.
+
+We currently only support US phone numbers. We can, however, imagine having
+custom types for each country, for instance leveraging the `phonenumbers`
+library.
+ +""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}" + + +USPhoneNumber = Annotated[ + str, + WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}), +] diff --git a/.venv/lib/python3.11/site-packages/outlines/types/zip_codes.py b/.venv/lib/python3.11/site-packages/outlines/types/zip_codes.py new file mode 100644 index 0000000000000000000000000000000000000000..67d994d5ce217f57193f215e8e69d314b3dc0e32 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/outlines/types/zip_codes.py @@ -0,0 +1,13 @@ +"""Zip code types. + +We currently only support US Zip Codes. + +""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +# This matches Zip and Zip+4 codes +US_ZIP_CODE = r"\d{5}(?:-\d{4})?" + + +USZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] diff --git a/.venv/lib/python3.11/site-packages/torchvision.libs/libpng16.7f72a3c5.so.16 b/.venv/lib/python3.11/site-packages/torchvision.libs/libpng16.7f72a3c5.so.16 new file mode 100644 index 0000000000000000000000000000000000000000..4539125cd1f7670b14d68f69e39b8a2b29b8a7d6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision.libs/libpng16.7f72a3c5.so.16 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0b8f3c80b385da99dea0cf7c8daa954000d23b0164f112abe7b2faff0a0f63d +size 1079081 diff --git a/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/METADATA b/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..d2064b9253f9c1d53d31096c87ea880b433f2e20 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/METADATA @@ -0,0 +1,154 @@ +Metadata-Version: 2.4 +Name: urllib3 +Version: 2.3.0 +Summary: HTTP library with thread-safe connection pooling, file post, and more. 
+Project-URL: Changelog, https://github.com/urllib3/urllib3/blob/main/CHANGES.rst +Project-URL: Documentation, https://urllib3.readthedocs.io +Project-URL: Code, https://github.com/urllib3/urllib3 +Project-URL: Issue tracker, https://github.com/urllib3/urllib3/issues +Author-email: Andrey Petrov +Maintainer-email: Seth Michael Larson , Quentin Pradet , Illia Volochii +License-File: LICENSE.txt +Keywords: filepost,http,httplib,https,pooling,ssl,threadsafe,urllib +Classifier: Environment :: Web Environment +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: Internet :: WWW/HTTP +Classifier: Topic :: Software Development :: Libraries +Requires-Python: >=3.9 +Provides-Extra: brotli +Requires-Dist: brotli>=1.0.9; (platform_python_implementation == 'CPython') and extra == 'brotli' +Requires-Dist: brotlicffi>=0.8.0; (platform_python_implementation != 'CPython') and extra == 'brotli' +Provides-Extra: h2 +Requires-Dist: h2<5,>=4; extra == 'h2' +Provides-Extra: socks +Requires-Dist: pysocks!=1.5.7,<2.0,>=1.5.6; extra == 'socks' +Provides-Extra: zstd +Requires-Dist: zstandard>=0.18.0; extra == 'zstd' +Description-Content-Type: text/markdown + +

+ +![urllib3](https://github.com/urllib3/urllib3/raw/main/docs/_static/banner_github.svg) + +

+ +

+<!-- badge images stripped: PyPI Version, Python Versions, Join our Discord, Coverage Status, Build Status on GitHub, Documentation Status, OpenSSF Scorecard, SLSA 3, CII Best Practices -->

+ +urllib3 is a powerful, *user-friendly* HTTP client for Python. Much of the +Python ecosystem already uses urllib3 and you should too. +urllib3 brings many critical features that are missing from the Python +standard libraries: + +- Thread safety. +- Connection pooling. +- Client-side SSL/TLS verification. +- File uploads with multipart encoding. +- Helpers for retrying requests and dealing with HTTP redirects. +- Support for gzip, deflate, brotli, and zstd encoding. +- Proxy support for HTTP and SOCKS. +- 100% test coverage. + +urllib3 is powerful and easy to use: + +```python3 +>>> import urllib3 +>>> resp = urllib3.request("GET", "http://httpbin.org/robots.txt") +>>> resp.status +200 +>>> resp.data +b"User-agent: *\nDisallow: /deny\n" +``` + +## Installing + +urllib3 can be installed with [pip](https://pip.pypa.io): + +```bash +$ python -m pip install urllib3 +``` + +Alternatively, you can grab the latest source code from [GitHub](https://github.com/urllib3/urllib3): + +```bash +$ git clone https://github.com/urllib3/urllib3.git +$ cd urllib3 +$ pip install . +``` + + +## Documentation + +urllib3 has usage and reference documentation at [urllib3.readthedocs.io](https://urllib3.readthedocs.io). + + +## Community + +urllib3 has a [community Discord channel](https://discord.gg/urllib3) for asking questions and +collaborating with other contributors. Drop by and say hello 👋 + + +## Contributing + +urllib3 happily accepts contributions. Please see our +[contributing documentation](https://urllib3.readthedocs.io/en/latest/contributing.html) +for some tips on getting started. + + +## Security Disclosures + +To report a security vulnerability, please use the +[Tidelift security contact](https://tidelift.com/security). +Tidelift will coordinate the fix and disclosure with maintainers. + + +## Maintainers + +- [@sethmlarson](https://github.com/sethmlarson) (Seth M. Larson) +- [@pquentin](https://github.com/pquentin) (Quentin Pradet) +- [@illia-v](https://github.com/illia-v) (Illia Volochii) +- [@theacodes](https://github.com/theacodes) (Thea Flowers) +- [@haikuginger](https://github.com/haikuginger) (Jess Shapiro) +- [@lukasa](https://github.com/lukasa) (Cory Benfield) +- [@sigmavirus24](https://github.com/sigmavirus24) (Ian Stapleton Cordasco) +- [@shazow](https://github.com/shazow) (Andrey Petrov) + +👋 + + +## Sponsorship + +If your company benefits from this library, please consider [sponsoring its +development](https://urllib3.readthedocs.io/en/latest/sponsors.html). + + +## For Enterprise + +Professional support for urllib3 is available as part of the [Tidelift +Subscription][1]. Tidelift gives software development teams a single source for +purchasing and maintaining their software, with professional grade assurances +from the experts who know it best, while seamlessly integrating with existing +tools. 
+ +[1]: https://tidelift.com/subscription/pkg/pypi-urllib3?utm_source=pypi-urllib3&utm_medium=referral&utm_campaign=readme diff --git a/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..12228d414b6cfed7c39d3781c85c63256a1d7fb5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.27.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/licenses/LICENSE.txt b/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/licenses/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..e6183d0276b26c5b87aecccf8d0d5bcd7b1148d4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/urllib3-2.3.0.dist-info/licenses/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2008-2020 Andrey Petrov and contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.