File size: 2,335 Bytes
cd607b2 eac37df cd607b2 f5ec828 eac37df cd607b2 7b856a8 69deff6 8200c4e 7b856a8 8200c4e 7b856a8 8200c4e 4e3dc76 8200c4e 7b856a8 69deff6 7b856a8 8200c4e 4e3dc76 7b856a8 5b30d27 7b856a8 4e3dc76 7b856a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
# Markdown blurb for this example; it is handed to `show` below as the
# demo's `description`.
# NOTE(review): the Colab badge links to pal.ipynb -- confirm whether it
# should point at this example's own notebook instead.
desc = """
### Typed Extraction
Information extraction that is automatically generated from a typed specification. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/pal.ipynb)
(Novel to MiniChain)
"""
# $
from minichain import prompt, show, OpenAI, transform
from dataclasses import dataclass, is_dataclass, fields
from typing import List, Type, Dict, Any, get_origin, get_args
from enum import Enum
from jinja2 import select_autoescape, FileSystemLoader, Environment
import json
def enum(x: Type[Enum]) -> Dict[str, int]:
    """Return a dict mapping each member name of enum class `x` to its value.

    Members appear in the order produced by iterating over `x`.
    """
    return {member.name: member.value for member in x}
def walk(x: Any) -> Any:
    """Describe the type `x` as a nested, JSON-like structure.

    The description depends on what `x` is:

    * ``List[T]`` -> ``{"_t_": "list", "t": <description of T>}``
    * an ``Enum`` subclass -> the name-to-value mapping built by ``enum``
    * a dataclass -> a dict from each field name to the description of
      that field's type
    * anything else -> ``x.__name__``
    """
    origin = get_origin(x)
    # A parameterised alias such as List[int] is not itself a class, so the
    # list check is made against its runtime origin when there is one.
    container = x if origin is None else origin
    if issubclass(container, List):
        element_type = get_args(x)[0]
        return {"_t_": "list", "t": walk(element_type)}
    if issubclass(x, Enum):
        return enum(x)
    if is_dataclass(x):
        return {field.name: walk(field.type) for field in fields(x)}
    return x.__name__
def type_to_prompt(out: type) -> str:
    """Render the ``type_prompt.pmpt.tpl`` template for the type `out`.

    The template receives the structure produced by ``walk(out)`` under
    the variable name ``typ``; the rendered text is returned.
    """
    template = env.get_template("type_prompt.pmpt.tpl")
    type_description = walk(out)
    return template.render({"typ": type_description})
# Jinja2 environment used by `type_to_prompt` to fetch its template.
# Templates are looked up relative to the current working directory ("."),
# and the jinja2_highlight extension is registered on the environment.
env = Environment(
loader=FileSystemLoader("."),
autoescape=select_autoescape(),
extensions=["jinja2_highlight.HighlightExtension"],
)
# Data specification
# +
class StatType(Enum):
    """Categories of statistic that a `Stat` entry can carry."""

    POINTS = 1
    REBOUNDS = 2
    ASSISTS = 3
@dataclass
class Stat:
    """One statistic: a numeric value together with its category."""

    # Numeric amount of the statistic.
    value: int
    # Which category of statistic `value` counts.
    stat: StatType
@dataclass
class Player:
    """A player name together with the statistics listed for that player."""

    # Player name as a string.
    player: str
    # Statistics attributed to this player.
    stats: List[Stat]
# -
@prompt(OpenAI(), template_file="stats.pmpt.tpl")
def stats(model, passage):
    """Stream the model's response for `passage`.

    The prompt template is given two variables: ``passage`` (the input
    text) and ``typ`` (the rendered type description of ``Player``).
    """
    template_vars = {
        "passage": passage,
        "typ": type_to_prompt(Player),
    }
    return model.stream(template_vars)
@transform()
def to_data(s: str):
    """Parse the JSON text `s` into a list of `Player` objects.

    `s` is expected to decode to a sequence of JSON objects; each one is
    expanded into keyword arguments for `Player`. Nested values are kept
    exactly as `json.loads` produced them.
    """
    records = json.loads(s)
    return [Player(**record) for record in records]
# $
def _read_text(path):
    """Return the entire contents of the text file at `path`.

    The file is opened in a ``with`` block so its handle is closed as soon
    as the read finishes rather than being left open until it is garbage
    collected.
    """
    with open(path, "r") as handle:
        return handle.read()


# Sample passage offered as the ready-made example input of the demo.
article = _read_text("sixers.txt")

gradio = show(
    # Full pipeline: run the extraction prompt, then parse its JSON output.
    lambda passage: to_data(stats(passage)),
    examples=[article],
    subprompts=[stats],
    out_type="json",
    description=desc,
    # Show only the excerpt of this script that sits between its first two
    # marker comments, with the trailing comment character and surrounding
    # whitespace trimmed off.
    code=_read_text("stats.py").split("$")[1].strip().strip("#").strip(),
)

if __name__ == "__main__":
    gradio.queue().launch()
# ExtractionPrompt().show({"passage": "Harden had 10 rebounds."},
# '[{"player": "Harden", "stats": {"value": 10, "stat": 2}}]')
# # View the run log.
# minichain.show_log("bash.log")
|