Compare commits

...

4 Commits

Author        SHA1        Message                                                       Date
Michelle Ark  37dbd118d3  refactor: internalize parallel to RunTask._submit_batch      2024-11-28 17:06:51 -05:00
Michelle Ark  32002ea69f  only run pre_hook on first batch, post_hook on last batch    2024-11-28 16:43:30 -05:00
Michelle Ark  bec5d57114  use Task.get_runner                                          2024-11-28 16:23:05 -05:00
Michelle Ark  0d61609acd  microbatch: split out first and last batch to run in serial  2024-11-28 16:11:55 -05:00
3 changed files with 67 additions and 38 deletions

View File

@@ -602,15 +602,15 @@ class MicrobatchModelRunner(ModelRunner):
         )
         return relation is not None

-    def _should_run_in_parallel(
-        self,
-        relation_exists: bool,
-    ) -> bool:
+    def should_run_in_parallel(self) -> bool:
         if not self.adapter.supports(Capability.MicrobatchConcurrency):
             run_in_parallel = False
-        elif not relation_exists:
+        elif not self.relation_exists:
             # If the relation doesn't exist, we can't run in parallel
             run_in_parallel = False
+        elif self.batch_idx == 0 or self.batch_idx == len(self.batches) - 1:
+            # First and last batch don't run in parallel
+            run_in_parallel = False
         elif self.node.config.concurrent_batches is not None:
             # If the relation exists and the `concurrent_batches` config isn't None, use the config value
             run_in_parallel = self.node.config.concurrent_batches
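
Read as a whole, the new predicate is a cascade over state the runner already carries: adapter capability, relation existence, batch position, and model config, in that order of precedence. The sketch below restates that cascade on a minimal stub so the precedence can be exercised in isolation. The stub's field names, the final `else` branch (inferred from the `_has_this` parametrization in the unit test further down, since the hunk is truncated here), and the example values are illustrative assumptions, not dbt-core's API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class StubBatchRunner:
    """Illustrative stand-in carrying only the state the predicate reads."""

    adapter_supports_concurrency: bool  # stands in for adapter.supports(Capability.MicrobatchConcurrency)
    relation_exists: bool
    batch_idx: int
    num_batches: int
    concurrent_batches: Optional[bool] = None  # model-level config override
    has_this: bool = False  # does the model's SQL reference {{ this }}?

    def should_run_in_parallel(self) -> bool:
        if not self.adapter_supports_concurrency:
            run_in_parallel = False
        elif not self.relation_exists:
            # The first write has to create the relation; nothing to insert into yet
            run_in_parallel = False
        elif self.batch_idx == 0 or self.batch_idx == self.num_batches - 1:
            # First and last batch stay serial (relation creation, pre/post hooks)
            run_in_parallel = False
        elif self.concurrent_batches is not None:
            # An explicit config value wins over the default
            run_in_parallel = self.concurrent_batches
        else:
            # Assumption: a model that reads its own table via `this` can't be
            # written concurrently (inferred from the `_has_this` test params)
            run_in_parallel = not self.has_this
        return run_in_parallel


# Middle batch, relation in place, no overrides: eligible for the pool.
assert StubBatchRunner(True, True, 2, 5).should_run_in_parallel()
# Config override forces it back to sequential.
assert not StubBatchRunner(True, True, 2, 5, concurrent_batches=False).should_run_in_parallel()
# Last batch is always sequential, even when everything else allows parallelism.
assert not StubBatchRunner(True, True, 4, 5).should_run_in_parallel()
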
@@ -703,52 +703,79 @@ class RunTask(CompileTask):
         runner: MicrobatchModelRunner,
         pool: ThreadPool,
     ) -> RunResult:
-        # Initial run computes batch metadata, unless model is skipped
+        # Initial run computes batch metadata
         result = self.call_runner(runner)
+        batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists

         # Return early if model should be skipped, or there are no batches to execute
         if result.status == RunStatus.Skipped:
             return result
         elif len(runner.batches) == 0:
             return result

         batch_results: List[RunResult] = []

-        # Execute batches serially until a relation exists, at which point future batches are run in parallel
-        relation_exists = runner.relation_exists
         batch_idx = 0
-        while batch_idx < len(runner.batches):
-            batch_runner = MicrobatchModelRunner(
-                self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes
+        # Run all batches except last batch, in parallel if possible
+        while batch_idx < len(runner.batches) - 1:
+            relation_exists = self._submit_batch(
+                node, relation_exists, batches, batch_idx, batch_results, pool
             )
-            batch_runner.set_batch_idx(batch_idx)
-            batch_runner.set_relation_exists(relation_exists)
-            batch_runner.set_batches(runner.batches)
-
-            if runner._should_run_in_parallel(relation_exists):
-                fire_event(
-                    MicrobatchExecutionDebug(
-                        msg=f"{batch_runner.describe_batch} is being run concurrently"
-                    )
-                )
-                self._submit(pool, [batch_runner], batch_results.append)
-            else:
-                fire_event(
-                    MicrobatchExecutionDebug(
-                        msg=f"{batch_runner.describe_batch} is being run sequentially"
-                    )
-                )
-                batch_results.append(self.call_runner(batch_runner))
-                relation_exists = batch_runner.relation_exists
-
             batch_idx += 1

-        # Wait until all batches have completed
-        while len(batch_results) != len(runner.batches):
+        # Wait until all submitted batches have completed
+        while len(batch_results) != batch_idx:
             pass

+        # Final batch runs once all others complete to ensure post_hook runs at the end
+        self._submit_batch(node, relation_exists, batches, batch_idx, batch_results, pool)
+
         # Finalize run: merge results, track model run, and print final result line
         runner.merge_batch_results(result, batch_results)
         track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter)
         runner.print_result_line(result)

         return result

+    def _submit_batch(
+        self,
+        node: ModelNode,
+        relation_exists: bool,
+        batches: Dict[int, BatchType],
+        batch_idx: int,
+        batch_results: List[RunResult],
+        pool: ThreadPool,
+    ):
+        node_copy = deepcopy(node)
+        # Only run pre_hook(s) for first batch
+        if batch_idx != 0:
+            node_copy.config.pre_hook = []
+        # Only run post_hook(s) for last batch
+        elif batch_idx != len(batches) - 1:
+            node_copy.config.post_hook = []
+
+        batch_runner = self.get_runner(node_copy)
+        assert isinstance(batch_runner, MicrobatchModelRunner)
+        batch_runner.set_batch_idx(batch_idx)
+        batch_runner.set_relation_exists(relation_exists)
+        batch_runner.set_batches(batches)
+
+        if batch_runner.should_run_in_parallel():
+            fire_event(
+                MicrobatchExecutionDebug(
+                    msg=f"{batch_runner.describe_batch} is being run concurrently"
+                )
+            )
+            self._submit(pool, [batch_runner], batch_results.append)
+        else:
+            fire_event(
+                MicrobatchExecutionDebug(
+                    msg=f"{batch_runner.describe_batch} is being run sequentially"
+                )
+            )
+            batch_results.append(self.call_runner(batch_runner))
+            relation_exists = batch_runner.relation_exists
+
+        return relation_exists
+
     def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]:
         package_name = hook.package_name
         if package_name == self.config.project_name:
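
The hook handling in `_submit_batch` is what makes the serial first/last split matter: pre_hook should fire only with batch 0 and post_hook only with the final batch, which is also why the last batch is submitted only after every other result has landed. Note that as committed, the `elif` means a middle batch (neither first nor last) clears its pre_hook but keeps its post_hook. The sketch below uses two independent checks instead, which matches the commit message's stated intent; `trim_hooks_for_batch` and the tiny node/config classes are hypothetical names for illustration, not dbt-core's API.

from copy import deepcopy
from typing import List


class HookConfig:
    def __init__(self, pre_hook: List[str], post_hook: List[str]) -> None:
        self.pre_hook = pre_hook
        self.post_hook = post_hook


class Node:
    def __init__(self, config: HookConfig) -> None:
        self.config = config


def trim_hooks_for_batch(node: Node, batch_idx: int, num_batches: int) -> Node:
    # Deep-copy so per-batch hook edits never leak into the shared node
    node_copy = deepcopy(node)
    if batch_idx != 0:
        node_copy.config.pre_hook = []  # pre_hook only on the first batch
    if batch_idx != num_batches - 1:
        node_copy.config.post_hook = []  # post_hook only on the last batch
    return node_copy


node = Node(HookConfig(pre_hook=["setup"], post_hook=["grant select"]))
for idx in range(3):
    trimmed = trim_hooks_for_batch(node, idx, 3)
    print(idx, trimmed.config.pre_hook, trimmed.config.post_hook)
# 0 ['setup'] []
# 1 [] []                  <- with the `elif` above, this would be [] ['grant select']
# 2 [] ['grant select']

A single-batch model (batch 0 is also the last batch) keeps both hooks under either variant, so the difference only surfaces with three or more batches.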

View File

@@ -875,7 +875,7 @@ class TestMicrobatchCanRunParallelOrSequential(BaseMicrobatchTest):
     def test_microbatch(
         self, mocker: MockerFixture, project, batch_exc_catcher: EventCatcher
     ) -> None:
-        mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner._should_run_in_parallel")
+        mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner.should_run_in_parallel")
         # Should be run in parallel
         mocked_srip.return_value = True

View File

@@ -264,7 +264,7 @@ class TestMicrobatchModelRunner:
             (False, False, False, True, False),
         ],
     )
-    def test__should_run_in_parallel(
+    def test_should_run_in_parallel(
         self,
         mocker: MockerFixture,
         model_runner: MicrobatchModelRunner,
@@ -276,11 +276,13 @@ class TestMicrobatchModelRunner:
     ) -> None:
         model_runner.node._has_this = has_this
         model_runner.node.config = ModelConfig(concurrent_batches=concurrent_batches)
+        model_runner.set_relation_exists(has_relation)
+
         mocked_supports = mocker.patch.object(model_runner.adapter, "supports")
         mocked_supports.return_value = adapter_microbatch_concurrency

-        # Assert result of _should_run_in_parallel
-        assert model_runner._should_run_in_parallel(has_relation) == expectation
+        # Assert result of should_run_in_parallel
+        assert model_runner.should_run_in_parallel() == expectation


 class TestRunTask: