import logging

from caffe2.python import control, core, test_util, workspace

# Module-level logger for this test module.
logger = logging.getLogger(__name__)
class TestControl(test_util.TestCase):
    """Tests for the control-flow step builders in ``caffe2.python.control``.

    Each test wraps the nets created in ``setUp`` in a control-flow step,
    runs it in a plan that first executes ``init_net_``, and then fetches
    the blobs published by those nets to check which nets ran and how often.
    """

    def setUp(self):
        """Build the nets shared by every test.

        * ``init_net_``: creates counter ``cnt`` (starting at 0), the INT64
          scalars ``N_`` and 0, and zero-fills the published count blobs.
        * ``cnt_net_``: increments ``cnt`` once and publishes its value.
        * ``cnt_2_net_``: increments ``cnt`` twice and publishes its value.
        * ``cond_net_`` / ``not_cond_net_``: ``cnt_net_``'s published value
          ``< N_`` / ``>= N_``.
        * ``true_cond_net_`` / ``false_cond_net_``: ``0 < N_`` / ``0 > N_``.
        * ``idle_net_``: a ConstantFill that does not touch the counter.
        """
        super(TestControl, self).setUp()
        self.N_ = 10

        # Initialization: counter starting at 0 plus INT64 scalar constants
        # N_ and 0 consumed by the condition nets below.
        self.init_net_ = core.Net("init-net")
        cnt = self.init_net_.CreateCounter([], init_count=0)
        const_n = self.init_net_.ConstantFill(
            [], shape=[], value=self.N_, dtype=core.DataType.INT64)
        const_0 = self.init_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

        # cnt_net_: count up once and publish the current count.
        self.cnt_net_ = core.Net("cnt-net")
        self.cnt_net_.CountUp([cnt])
        curr_cnt = self.cnt_net_.RetrieveCount([cnt])
        # Pre-fill the published count with 0 in init_net_ so it can still
        # be fetched (and compared against 0) when cnt_net_ never runs.
        self.init_net_.ConstantFill(
            [], [curr_cnt], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_net_.AddExternalOutput(curr_cnt)

        # cnt_2_net_: count up twice on the same counter and publish it.
        self.cnt_2_net_ = core.Net("cnt-2-net")
        self.cnt_2_net_.CountUp([cnt])
        self.cnt_2_net_.CountUp([cnt])
        curr_cnt_2 = self.cnt_2_net_.RetrieveCount([cnt])
        self.init_net_.ConstantFill(
            [], [curr_cnt_2], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_2_net_.AddExternalOutput(curr_cnt_2)

        # cond_net_: curr_cnt < N_.
        self.cond_net_ = core.Net("cond-net")
        cond_blob = self.cond_net_.LT([curr_cnt, const_n])
        self.cond_net_.AddExternalOutput(cond_blob)

        # not_cond_net_: curr_cnt >= N_ (the negation of cond_net_).
        self.not_cond_net_ = core.Net("not-cond-net")
        not_cond_blob = self.not_cond_net_.GE([curr_cnt, const_n])
        self.not_cond_net_.AddExternalOutput(not_cond_blob)

        # true_cond_net_: 0 < N_, which holds because N_ is 10.
        self.true_cond_net_ = core.Net("true-cond-net")
        true_blob = self.true_cond_net_.LT([const_0, const_n])
        self.true_cond_net_.AddExternalOutput(true_blob)

        # false_cond_net_: 0 > N_, which does not hold because N_ is 10.
        self.false_cond_net_ = core.Net("false-cond-net")
        false_blob = self.false_cond_net_.GT([const_0, const_n])
        self.false_cond_net_.AddExternalOutput(false_blob)

        # idle_net_: extra work that never touches the counter; placed in
        # loop bodies to check that it does not change the final count.
        self.idle_net_ = core.Net("idle-net")
        self.idle_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

    def CheckNetOutput(self, nets_and_expects):
        """Assert the last external output of each net equals its expectation.

        Args:
            nets_and_expects: iterable of ``(net, expect)`` tuples; the blob
                named by ``net``'s last external output is fetched and
                compared with ``expect``.
        """
        for net, expect in nets_and_expects:
            output = workspace.FetchBlob(
                net.Proto().external_output[-1])
            self.assertEqual(output, expect)

    def CheckNetAllOutput(self, net, expects):
        """Assert every external output of ``net`` matches ``expects``.

        Args:
            net: net whose external outputs are fetched in declaration order.
            expects: expected values, one per external output, compared
                with ``assertEqual``; the counts must match.
        """
        output_names = net.Proto().external_output
        self.assertEqual(len(output_names), len(expects))
        for blob_name, expect in zip(output_names, expects):
            self.assertEqual(workspace.FetchBlob(blob_name), expect)

    def BuildAndRunPlan(self, step):
        """Run ``init_net_`` and then ``step`` in a new plan.

        Asserts that ``workspace.RunPlan`` reports success.
        """
        plan = core.Plan("test")
        plan.AddStep(control.Do('init', self.init_net_))
        plan.AddStep(step)
        self.assertEqual(workspace.RunPlan(plan), True)

    def _RunLoopAndCheck(self, step):
        """Run a loop ``step`` and assert ``cnt_net_`` published ``N_``."""
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def ForLoopTest(self, nets_or_steps):
        """Run ``nets_or_steps`` ``N_`` times via ``control.For``."""
        self._RunLoopAndCheck(
            control.For('myFor', nets_or_steps, self.N_))

    def testForLoopWithNets(self):
        self.ForLoopTest(self.cnt_net_)
        self.ForLoopTest([self.cnt_net_, self.idle_net_])

    def testForLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.ForLoopTest(step)
        self.ForLoopTest([step, self.idle_net_])

    def WhileLoopTest(self, nets_or_steps):
        """Loop ``nets_or_steps`` with ``control.While`` on ``cond_net_``."""
        self._RunLoopAndCheck(
            control.While('myWhile', self.cond_net_, nets_or_steps))

    def testWhileLoopWithNet(self):
        self.WhileLoopTest(self.cnt_net_)
        self.WhileLoopTest([self.cnt_net_, self.idle_net_])

    def testWhileLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.WhileLoopTest(step)
        self.WhileLoopTest([step, self.idle_net_])

    def UntilLoopTest(self, nets_or_steps):
        """Loop ``nets_or_steps`` with ``control.Until`` on ``not_cond_net_``."""
        self._RunLoopAndCheck(
            control.Until('myUntil', self.not_cond_net_, nets_or_steps))

    def testUntilLoopWithNet(self):
        self.UntilLoopTest(self.cnt_net_)
        self.UntilLoopTest([self.cnt_net_, self.idle_net_])

    def testUntilLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.UntilLoopTest(step)
        self.UntilLoopTest([step, self.idle_net_])

    def DoWhileLoopTest(self, nets_or_steps):
        """Loop ``nets_or_steps`` with ``control.DoWhile`` on ``cond_net_``."""
        self._RunLoopAndCheck(
            control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps))

    def testDoWhileLoopWithNet(self):
        self.DoWhileLoopTest(self.cnt_net_)
        self.DoWhileLoopTest([self.idle_net_, self.cnt_net_])

    def testDoWhileLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.DoWhileLoopTest(step)
        self.DoWhileLoopTest([self.idle_net_, step])

    def DoUntilLoopTest(self, nets_or_steps):
        """Loop ``nets_or_steps`` with ``control.DoUntil`` on ``not_cond_net_``."""
        self._RunLoopAndCheck(
            control.DoUntil('myDoUntil', self.not_cond_net_, nets_or_steps))

    def testDoUntilLoopWithNet(self):
        self.DoUntilLoopTest(self.cnt_net_)
        self.DoUntilLoopTest([self.cnt_net_, self.idle_net_])

    def testDoUntilLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.DoUntilLoopTest(step)
        self.DoUntilLoopTest([self.idle_net_, step])

    def IfCondTest(self, cond_net, expect, cond_on_blob):
        """Run ``cnt_net_`` under ``control.If`` and check its output.

        Args:
            cond_net: net whose last external output is the condition.
            expect: value ``cnt_net_``'s output should hold afterwards.
            cond_on_blob: if True, run ``cond_net`` first and pass its
                output blob name to ``control.If``; otherwise pass the net.
        """
        if cond_on_blob:
            step = control.Do(
                'if-all',
                control.Do('count', cond_net),
                control.If('myIf', cond_net.Proto().external_output[-1],
                           self.cnt_net_))
        else:
            step = control.If('myIf', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfCondTrueOnNet(self):
        self.IfCondTest(self.true_cond_net_, 1, False)

    def testIfCondTrueOnBlob(self):
        self.IfCondTest(self.true_cond_net_, 1, True)

    def testIfCondFalseOnNet(self):
        self.IfCondTest(self.false_cond_net_, 0, False)

    def testIfCondFalseOnBlob(self):
        self.IfCondTest(self.false_cond_net_, 0, True)

    def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        """Run ``control.If`` with ``cnt_net_`` / ``cnt_2_net_`` branches.

        Args:
            cond_net: net whose last external output is the condition.
            cond_value: the condition's value; selects which branch net's
                output is checked (``cnt_net_`` if True, else ``cnt_2_net_``).
            expect: value the selected net's output should hold afterwards.
            cond_on_blob: if True, condition on ``cond_net``'s output blob
                name after running it; otherwise pass the net itself.
        """
        if cond_value:
            run_net = self.cnt_net_
        else:
            run_net = self.cnt_2_net_
        if cond_on_blob:
            step = control.Do(
                'if-else-all',
                control.Do('count', cond_net),
                control.If('myIfElse', cond_net.Proto().external_output[-1],
                           self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.If('myIfElse', cond_net,
                              self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfElseCondTrueOnNet(self):
        self.IfElseCondTest(self.true_cond_net_, True, 1, False)

    def testIfElseCondTrueOnBlob(self):
        self.IfElseCondTest(self.true_cond_net_, True, 1, True)

    def testIfElseCondFalseOnNet(self):
        self.IfElseCondTest(self.false_cond_net_, False, 2, False)

    def testIfElseCondFalseOnBlob(self):
        self.IfElseCondTest(self.false_cond_net_, False, 2, True)

    def IfNotCondTest(self, cond_net, expect, cond_on_blob):
        """Run ``cnt_net_`` under ``control.IfNot`` and check its output.

        Args:
            cond_net: net whose last external output is the condition.
            expect: value ``cnt_net_``'s output should hold afterwards.
            cond_on_blob: if True, run ``cond_net`` first and pass its
                output blob name to ``control.IfNot``; otherwise pass the net.
        """
        if cond_on_blob:
            step = control.Do(
                'if-not',
                control.Do('count', cond_net),
                control.IfNot('myIfNot', cond_net.Proto().external_output[-1],
                              self.cnt_net_))
        else:
            step = control.IfNot('myIfNot', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfNotCondTrueOnNet(self):
        self.IfNotCondTest(self.true_cond_net_, 0, False)

    def testIfNotCondTrueOnBlob(self):
        self.IfNotCondTest(self.true_cond_net_, 0, True)

    def testIfNotCondFalseOnNet(self):
        self.IfNotCondTest(self.false_cond_net_, 1, False)

    def testIfNotCondFalseOnBlob(self):
        self.IfNotCondTest(self.false_cond_net_, 1, True)

    def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        """Run ``control.IfNot`` with ``cnt_net_`` / ``cnt_2_net_`` branches.

        Args:
            cond_net: net whose last external output is the condition.
            cond_value: the condition's value; selects which branch net's
                output is checked (``cnt_2_net_`` if True, else ``cnt_net_``).
            expect: value the selected net's output should hold afterwards.
            cond_on_blob: if True, condition on ``cond_net``'s output blob
                name after running it; otherwise pass the net itself.
        """
        if cond_value:
            run_net = self.cnt_2_net_
        else:
            run_net = self.cnt_net_
        if cond_on_blob:
            step = control.Do(
                'if-not-else',
                control.Do('count', cond_net),
                control.IfNot('myIfNotElse',
                              cond_net.Proto().external_output[-1],
                              self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.IfNot('myIfNotElse', cond_net,
                                 self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfNotElseCondTrueOnNet(self):
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, False)

    def testIfNotElseCondTrueOnBlob(self):
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, True)

    def testIfNotElseCondFalseOnNet(self):
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, False)

    def testIfNotElseCondFalseOnBlob(self):
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, True)

    def testSwitch(self):
        step = control.Switch(
            'mySwitch',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        # cnt_net_ never runs, so its output keeps init_net_'s 0.
        self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)])

    def testSwitchNot(self):
        step = control.SwitchNot(
            'mySwitchNot',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        # cnt_2_net_ never runs, so its output keeps init_net_'s 0.
        self.CheckNetOutput([(self.cnt_net_, 1), (self.cnt_2_net_, 0)])

    def testBoolNet(self):
        bool_net = control.BoolNet(('a', True))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True])

        bool_net = control.BoolNet(('a', True), ('b', False))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

        # The (name, value) pairs may also be passed as a single list.
        bool_net = control.BoolNet([('a', True), ('b', False)])
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

    def testCombineConditions(self):
        # The condition nets are run first so their output blobs exist
        # before combine_net reads them.
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, True)])

        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, False)])

    def testMergeConditionNets(self):
        # Unlike CombineConditions, only the merged net is run here.
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, True)])

        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, False)])