"""An example Flight CLI client.""" |
import argparse |
import sys |
import pyarrow |
import pyarrow.flight |
import pyarrow.csv as csv |
def _print_count(label, value):
    """Print ``label value``, or ``label Unknown`` when *value* is negative.

    The server reports a negative total when it does not know the count.
    """
    print(label, value if value >= 0 else "Unknown")


def list_flights(args, client, connection_args=None):
    """Print every flight and every action advertised by the server.

    Args:
        args: Parsed command-line namespace. NOTE(review): the ``-l/--list``
            flag registered by ``main`` is not consulted here; full details
            are always printed.
        client: Connected Flight client whose ``list_flights()`` and
            ``list_actions()`` are enumerated.
        connection_args: Unused; accepted so that every subcommand shares
            the same ``(args, client, connection_args)`` signature.
    """
    print('Flights\n=======')
    for flight in client.list_flights():
        descriptor = flight.descriptor
        kind = descriptor.descriptor_type
        if kind == pyarrow.flight.DescriptorType.PATH:
            print("Path:", descriptor.path)
        elif kind == pyarrow.flight.DescriptorType.CMD:
            print("Command:", descriptor.command)
        else:
            print("Unknown descriptor type")

        _print_count("Total records:", flight.total_records)
        _print_count("Total bytes:", flight.total_bytes)

        print("Number of endpoints:", len(flight.endpoints))
        print("Schema:")
        print(flight.schema)
        print('---')

    print('\nActions\n=======')
    for action in client.list_actions():
        print("Type:", action.type)
        print("Description:", action.description)
        print('---')
def do_action(args, client, connection_args=None):
    """Run the action named by ``args.action_type`` and print each result body.

    Errors raised while calling the action are reported on stdout instead of
    propagating, so the CLI exits normally after printing them.

    Args:
        args: Parsed command-line namespace; ``args.action_type`` names the
            action to run.
        client: Connected Flight client.
        connection_args: Unused; accepted for a uniform subcommand signature.
    """
    # The action carries an empty body, matching the health-check call in main.
    action = pyarrow.flight.Action(args.action_type, b"")
    print('Running action', args.action_type)
    try:
        for result in client.do_action(action):
            print("Got result", result.body.to_pybytes())
    except (pyarrow.lib.ArrowIOError, pyarrow.flight.FlightError) as e:
        # Older pyarrow surfaced Flight failures as ArrowIOError; newer
        # releases raise FlightError subclasses, which are not IOErrors.
        print("Error calling action:", e)
def push_data(args, client, connection_args=None):
    """Read a CSV file and upload it to the server with ``do_put``.

    The upload is keyed by a path descriptor built from ``args.file``. A
    summary (row count and a pandas preview of the first rows) is printed
    before uploading.

    Args:
        args: Parsed command-line namespace; ``args.file`` is the CSV path.
        client: Connected Flight client.
        connection_args: Unused; accepted for a uniform subcommand signature.
    """
    print('File Name:', args.file)
    my_table = csv.read_csv(args.file)
    print('Table rows=', str(len(my_table)))
    df = my_table.to_pandas()
    print(df.head())

    descriptor = pyarrow.flight.FlightDescriptor.for_path(args.file)
    writer, _ = client.do_put(descriptor, my_table.schema)
    try:
        writer.write_table(my_table)
    finally:
        # Close even if the write fails so the stream is not left open.
        writer.close()
def get_flight(args, client, connection_args=None):
    """Fetch a flight by path or command and print each endpoint's data.

    Args:
        args: Parsed command-line namespace. ``args.path`` is a list of path
            segments (from a repeatable ``-p``); otherwise ``args.command``
            is used as a command descriptor.
        client: Connected Flight client used for ``get_flight_info`` and for
            endpoints that list no locations.
        connection_args: Keyword arguments (e.g. TLS material) used when
            connecting to an endpoint's location. Defaults to none.
    """
    if connection_args is None:
        connection_args = {}

    if args.path:
        descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
    else:
        descriptor = pyarrow.flight.FlightDescriptor.for_command(args.command)

    info = client.get_flight_info(descriptor)
    for endpoint in info.endpoints:
        print('Ticket:', endpoint.ticket)
        if endpoint.locations:
            for location in endpoint.locations:
                print(location)
            # Locations are equivalent alternatives for the same ticket, so
            # read the data once from the first rather than once per location.
            data_client = pyarrow.flight.FlightClient(endpoint.locations[0],
                                                      **connection_args)
        else:
            # No locations: the ticket is redeemed on the issuing service.
            data_client = client
        reader = data_client.do_get(endpoint.ticket)
        df = reader.read_pandas()
        print(df)
def _add_common_arguments(parser):
    """Register the connection options shared by every subcommand.

    Adds ``--tls``, ``--tls-roots``, ``--mtls CERTFILE KEYFILE`` and the
    positional ``host`` argument to *parser*.
    """
    parser.add_argument('--tls', action='store_true',
                        help='Enable transport-level security')
    parser.add_argument('--tls-roots', default=None,
                        help='Path to trusted TLS certificate(s)')
    parser.add_argument("--mtls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Client certificate chain and private key "
                             "for mutual TLS")
    parser.add_argument('host', type=str,
                        help="Server address to connect to, as HOST:PORT")
def main(): |
parser = argparse.ArgumentParser() |
subcommands = parser.add_subparsers() |
cmd_list = subcommands.add_parser('list') |
cmd_list.set_defaults(action='list') |
_add_common_arguments(cmd_list) |
cmd_list.add_argument('-l', '--list', action='store_true', |
help="Print more details.") |
cmd_do = subcommands.add_parser('do') |
cmd_do.set_defaults(action='do') |
_add_common_arguments(cmd_do) |
cmd_do.add_argument('action_type', type=str, |
help="The action type to run.") |
cmd_put = subcommands.add_parser('put') |
cmd_put.set_defaults(action='put') |
_add_common_arguments(cmd_put) |
cmd_put.add_argument('file', type=str, |
help="CSV file to upload.") |
cmd_get = subcommands.add_parser('get') |
cmd_get.set_defaults(action='get') |
_add_common_arguments(cmd_get) |
cmd_get_descriptor = cmd_get.add_mutually_exclusive_group(required=True) |
cmd_get_descriptor.add_argument('-p', '--path', type=str, action='append', |
help="The path for the descriptor.") |
cmd_get_descriptor.add_argument('-c', '--command', type=str, |
help="The command for the descriptor.") |
args = parser.parse_args() |
if not hasattr(args, 'action'): |
parser.print_help() |
sys.exit(1) |
commands = { |
'list': list_flights, |
'do': do_action, |
'get': get_flight, |
'put': push_data, |
} |
host, port = args.host.split(':') |
port = int(port) |
scheme = "grpc+tcp" |
connection_args = {} |
if args.tls: |
scheme = "grpc+tls" |
if args.tls_roots: |
with open(args.tls_roots, "rb") as root_certs: |
connection_args["tls_root_certs"] = root_certs.read() |
if args.mtls: |
with open(args.mtls[0], "rb") as cert_file: |
tls_cert_chain = cert_file.read() |
with open(args.mtls[1], "rb") as key_file: |
tls_private_key = key_file.read() |
connection_args["cert_chain"] = tls_cert_chain |
connection_args["private_key"] = tls_private_key |
client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}", |
**connection_args) |
while True: |
try: |
action = pyarrow.flight.Action("healthcheck", b"") |
options = pyarrow.flight.FlightCallOptions(timeout=1) |
list(client.do_action(action, options=options)) |
break |
except pyarrow.ArrowIOError as e: |
if "Deadline" in str(e): |
print("Server is not ready, waiting...") |
commands[args.action](args, client, connection_args) |
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()