t5-parliament-categorisation / my_preprocessors.py
pere's picture
more small changes
0173bf7
import collections
import functools
import math
import re
from typing import Callable, Mapping, Optional, Sequence, Union
import uuid
from absl import logging
import babel
import gin
import seqio
import tensorflow.compat.v2 as tf
import json
import pandas as pd
# We disable no-value-for-parameter since the seqio.map_over_dataset leads to
# a false positive when seeds are provided.
# pylint:disable=no-value-for-parameter
# Short alias for tf.data's autotune sentinel.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Type alias for a single example: feature name -> tensor.
FeatureType = Mapping[str, tf.Tensor]
# Re-export seqio's stock preprocessors under module-level names.
rekey = seqio.preprocessors.rekey
tokenize = seqio.preprocessors.tokenize
@seqio.map_over_dataset
def parse_tsv(line, field_names=None, field_delim='\t'):
    """Splits TSV lines into dict examples mapping field name to string value.

    Args:
      line: an example containing a comma/tab-delimited string.
      field_names: a list of strings, the ordered names of the TSV fields.
        Defaults to "inputs" and "targets".
      field_delim: a string, the delimiter to split on e.g. ',' for csv.
        Defaults to a tab.

    Returns:
      A feature dict mapping field name to string value.
    """
    field_names = field_names or ['inputs', 'targets']
    # One empty-string default per field makes decode_csv return every column
    # as a string. Quote handling is disabled, so quote characters are kept
    # verbatim as part of the field value.
    columns = tf.io.decode_csv(
        line,
        record_defaults=[''] * len(field_names),
        field_delim=field_delim,
        use_quote_delim=False)
    return dict(zip(field_names, columns))
@seqio.map_over_dataset
def parse_json(line, field_delim='\t'):
    """Splits JSON lines into dict examples mapping field name to string value.

    The JSON record is loaded with pandas, re-serialised as a single delimited
    line, and decoded with `tf.io.decode_csv` so every value comes back as a
    string. Field names are taken from the keys of the JSON record.

    Args:
      line: an example containing valid json (JSON-lines format).
      field_delim: a string, the single-character delimiter used for the
        intermediate delimited representation. Defaults to a tab. It should
        not occur inside any field value.

    Returns:
      A feature dict mapping field name to string value.
    """
    # NOTE(review): pd.read_json expects a Python string/file-like object; when
    # this runs under tf.data graph tracing `line` is presumably a symbolic
    # tensor -- confirm this preprocessor is only used where that works.
    frame = pd.read_json(line, lines=True)
    field_names = list(frame.columns)
    # Serialise with the same delimiter that decode_csv splits on below; a
    # hard-coded tab here would break any caller passing another field_delim.
    # NOTE(review): to_csv may quote values containing the delimiter, while
    # decode_csv below treats quotes literally -- verify inputs never need it.
    delimited = frame.to_csv(
        header=False, index=False, sep=field_delim).strip()
    columns = tf.io.decode_csv(
        delimited,
        record_defaults=[''] * len(field_names),
        field_delim=field_delim,
        use_quote_delim=False)
    return dict(zip(field_names, columns))