Add batch context object to microbatch jinja context (#11031)
* Add `batch_id` to jinja context of microbatch batches
* Add changie doc
* Update `format_batch_start` to assume `batch_start` is always provided
* Add "runtime only" property `batch_context` to `ModelNode`
By "runtime only" we mean that the property doesn't exist on the artifact
schema and thus won't be written out to the manifest artifact.
* Begin populating `batch_context` during materialization execution for microbatch batches
* Fix circular import
* Fixup MicrobatchBuilder.batch_id property method
* Ensure MicrobatchModelRunner doesn't double compile batches
We were compiling the node for each batch _twice_. Apart from making microbatch
models more expensive than they needed to be, the double compilation was mostly
harmless. However, the first compilation happened _before_ we had added the
batch context information to the model node for the batch. This caused models
that try to access the `batch_context` information on the model to blow up,
which was undesirable. As such, we now skip the first compilation, similar to
how SavedQuery nodes skip compilation.
* Add `__post_serialize__` method to `BatchContext` to ensure correct dict shape
This is weird, but necessary, I apologize. Mashumaro handles the
dictification of this class via a compile-time generated `to_dict`
method based off of the _typing_ of the class. By default `datetime`
types are converted to strings. We don't want that; we want them to
stay datetimes (see the sketch after this list).
* Update tests to check for `batch_context`
* Update `resolve_event_time_filter` to use new `batch_context`
* Stop testing for batchless compiled code for microbatch models
In 45daec72f4 we stopped the extra compilation that was happening per batch
prior to the batch_context being loaded. Stopping this extra compilation means
that compiled SQL for the microbatch model without the event time filter /
batch context is no longer produced. We have discussed this and _believe_ it is
okay, given that this is a new node type that has not hit GA yet.
* Rename `ModelNode.batch_context` to `ModelNode.batch`
* Rename `build_batch_context` to `build_jinja_context_for_batch`
The name `build_batch_context` was confusing as
1) We have a `BatchContext` object, which the method was not building
2) The method builds the jinja context for the batch
As such it felt appropriate to rename the method to more accurately
communicate what it does.
* Rename test macro `invalid_batch_context_macro_sql` to `invalid_batch_jinja_context_macro_sql`
This rename makes it clearer that the jinja context for a batch is being
checked, as a batch_context now has a slightly different connotation.
* Update changie doc
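
For illustration, a minimal sketch of the Mashumaro default behavior described
above, using plain `DataClassDictMixin` and a hypothetical `Event` class
(dbt's `dbtClassMixin` wraps similar machinery):

from dataclasses import dataclass
from datetime import datetime

from mashumaro import DataClassDictMixin


@dataclass
class Event(DataClassDictMixin):
    occurred_at: datetime


# `to_dict` is generated at class-definition time from the field typing;
# datetime fields are serialized to ISO 8601 strings by default.
print(Event(occurred_at=datetime(2020, 1, 1)).to_dict())
# {'occurred_at': '2020-01-01T00:00:00'}

Overriding `__post_serialize__`, as `BatchContext` does in the diff below, runs
after the generated serialization and can put real `datetime` objects back into
the emitted dict.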
.changes/unreleased/Features-20241121-125630.yaml (new file, +6)
@@ -0,0 +1,6 @@
+kind: Features
+body: Add `batch` context object to model jinja context
+time: 2024-11-21T12:56:30.715473-06:00
+custom:
+  Author: QMalcolm
+  Issue: "11025"
@@ -244,9 +244,10 @@ class BaseResolver(metaclass=abc.ABCMeta):
             and self.model.config.materialized == "incremental"
             and self.model.config.incremental_strategy == "microbatch"
             and self.manifest.use_microbatch_batches(project_name=self.config.project_name)
+            and self.model.batch is not None
         ):
-            start = self.model.config.get("__dbt_internal_microbatch_event_time_start")
-            end = self.model.config.get("__dbt_internal_microbatch_event_time_end")
+            start = self.model.batch.event_time_start
+            end = self.model.batch.event_time_end

            if start is not None or end is not None:
                event_time_filter = EventTimeFilter(
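
To make the new guard concrete: during the initial, batch-agnostic compilation
`model.batch` is `None`, so refs resolve without an event time filter; once the
runner attaches a `BatchContext` and re-compiles, the filter bounds come from
the batch. A tiny illustrative sketch (hypothetical stand-in, not dbt-core
code):

from datetime import datetime
from typing import Optional, Tuple


def event_time_bounds(batch) -> Optional[Tuple[datetime, datetime]]:
    # Mirrors the `and self.model.batch is not None` clause above.
    if batch is None:
        return None  # initial compile: no filter applied
    return (batch.event_time_start, batch.event_time_end)  # per-batch re-compile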
@@ -93,6 +93,7 @@ from dbt_common.contracts.constraints import (
     ConstraintType,
     ModelLevelConstraint,
 )
+from dbt_common.dataclass_schema import dbtClassMixin
 from dbt_common.events.contextvars import set_log_contextvars
 from dbt_common.events.functions import warn_or_error

@@ -442,9 +443,30 @@ class HookNode(HookNodeResource, CompiledNode):
         return HookNodeResource


+@dataclass
+class BatchContext(dbtClassMixin):
+    id: str
+    event_time_start: datetime
+    event_time_end: datetime
+
+    def __post_serialize__(self, data, context):
+        # This is insane, but necessary, I apologize. Mashumaro handles the
+        # dictification of this class via a compile time generated `to_dict`
+        # method based off of the _typing_ of the class. By default `datetime`
+        # types are converted to strings. We don't want that, we want them to
+        # stay datetimes.
+        # Note: This is safe because the `BatchContext` isn't part of the artifact
+        # and thus doesn't get written out.
+        new_data = super().__post_serialize__(data, context)
+        new_data["event_time_start"] = self.event_time_start
+        new_data["event_time_end"] = self.event_time_end
+        return new_data
+
+
 @dataclass
 class ModelNode(ModelResource, CompiledNode):
     previous_batch_results: Optional[BatchResults] = None
+    batch: Optional[BatchContext] = None
     _has_this: Optional[bool] = None

     def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
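
A quick sketch of what that override buys us, assuming the `BatchContext`
definition above (illustrative values):

from datetime import datetime

batch = BatchContext(
    id="20200101",
    event_time_start=datetime(2020, 1, 1),
    event_time_end=datetime(2020, 1, 2),
)
data = batch.to_dict()
# Without the override these would be ISO 8601 strings; with it they stay datetimes.
assert isinstance(data["event_time_start"], datetime)
assert isinstance(data["event_time_end"], datetime)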
@@ -100,25 +100,25 @@ class MicrobatchBuilder:

         return batches

-    def build_batch_context(self, incremental_batch: bool) -> Dict[str, Any]:
+    def build_jinja_context_for_batch(self, incremental_batch: bool) -> Dict[str, Any]:
         """
         Create context with entries that reflect microbatch model + incremental execution state

         Assumes self.model has been (re)-compiled with necessary batch filters applied.
         """
-        batch_context: Dict[str, Any] = {}
+        jinja_context: Dict[str, Any] = {}

         # Microbatch model properties
-        batch_context["model"] = self.model.to_dict()
-        batch_context["sql"] = self.model.compiled_code
-        batch_context["compiled_code"] = self.model.compiled_code
+        jinja_context["model"] = self.model.to_dict()
+        jinja_context["sql"] = self.model.compiled_code
+        jinja_context["compiled_code"] = self.model.compiled_code

         # Add incremental context variables for batches running incrementally
         if incremental_batch:
-            batch_context["is_incremental"] = lambda: True
-            batch_context["should_full_refresh"] = lambda: False
+            jinja_context["is_incremental"] = lambda: True
+            jinja_context["should_full_refresh"] = lambda: False

-        return batch_context
+        return jinja_context

     @staticmethod
     def offset_timestamp(timestamp: datetime, batch_size: BatchSize, offset: int) -> datetime:
@@ -193,12 +193,11 @@ class MicrobatchBuilder:
         return truncated

     @staticmethod
-    def format_batch_start(
-        batch_start: Optional[datetime], batch_size: BatchSize
-    ) -> Optional[str]:
-        if batch_start is None:
-            return batch_start
-
+    def batch_id(start_time: datetime, batch_size: BatchSize) -> str:
+        return MicrobatchBuilder.format_batch_start(start_time, batch_size).replace("-", "")
+
+    @staticmethod
+    def format_batch_start(batch_start: datetime, batch_size: BatchSize) -> str:
         return str(
             batch_start.date() if (batch_start and batch_size != BatchSize.hour) else batch_start
         )
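
As a worked example of the two static methods above (the expected values line
up with the `batch.id` assertions in the functional tests later in this diff):

from datetime import datetime

# Daily and larger batch sizes format down to the date...
MicrobatchBuilder.format_batch_start(datetime(2020, 1, 3, 13), BatchSize.day)   # "2020-01-03"
# ...and the batch id is that date with the dashes stripped.
MicrobatchBuilder.batch_id(datetime(2020, 1, 3, 13), BatchSize.day)             # "20200103"
# Hourly batches keep the full timestamp.
MicrobatchBuilder.format_batch_start(datetime(2020, 1, 3, 13), BatchSize.hour)  # "2020-01-03 13:00:00"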
@@ -27,7 +27,7 @@ from dbt.clients.jinja import MacroGenerator
 from dbt.config import RuntimeConfig
 from dbt.context.providers import generate_runtime_model_context
 from dbt.contracts.graph.manifest import Manifest
-from dbt.contracts.graph.nodes import HookNode, ModelNode, ResultNode
+from dbt.contracts.graph.nodes import BatchContext, HookNode, ModelNode, ResultNode
 from dbt.events.types import (
     GenericExceptionOnRun,
     LogHookEndLine,
@@ -341,6 +341,13 @@ class MicrobatchModelRunner(ModelRunner):
         self.batches: Dict[int, BatchType] = {}
         self.relation_exists: bool = False

+    def compile(self, manifest: Manifest):
+        # The default compile function is _always_ called. However, we do our
+        # compilation _later_ in `_execute_microbatch_materialization`. This
+        # meant the node was being compiled _twice_ for each batch. To get around
+        # this, we've overridden the default compile method to do nothing
+        return self.node
+
     def set_batch_idx(self, batch_idx: int) -> None:
         self.batch_idx = batch_idx
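
The override above follows the pattern the commit message credits to SavedQuery
nodes: the framework calls `compile` unconditionally, so a runner opts out by
returning its node untouched. A minimal, framework-free sketch of the idea
(the `BaseRunner` here is a hypothetical stand-in, not dbt-core's actual
class):

class BaseRunner:
    def __init__(self, node):
        self.node = node

    def compile(self, manifest):
        # In the real framework this is where expensive compilation happens,
        # and it is invoked for every runner before execution.
        raise NotImplementedError


class SkipCompileRunner(BaseRunner):
    def compile(self, manifest):
        # No-op: compilation is deferred until per-batch context is attached.
        return self.node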
@@ -353,7 +360,7 @@ class MicrobatchModelRunner(ModelRunner):
     def describe_node(self) -> str:
         return f"{self.node.language} microbatch model {self.get_node_representation()}"

-    def describe_batch(self, batch_start: Optional[datetime]) -> str:
+    def describe_batch(self, batch_start: datetime) -> str:
         # Only visualize date if batch_start year/month/day
         formatted_batch_start = MicrobatchBuilder.format_batch_start(
             batch_start, self.node.config.batch_size
@@ -530,10 +537,16 @@ class MicrobatchModelRunner(ModelRunner):
         # call materialization_macro to get a batch-level run result
         start_time = time.perf_counter()
         try:
-            # Set start/end in context prior to re-compiling
+            # LEGACY: Set start/end in context prior to re-compiling (Will be removed for 1.10+)
+            # TODO: REMOVE before 1.10 GA
             model.config["__dbt_internal_microbatch_event_time_start"] = batch[0]
             model.config["__dbt_internal_microbatch_event_time_end"] = batch[1]
-
+            # Create batch context on model node prior to re-compiling
+            model.batch = BatchContext(
+                id=MicrobatchBuilder.batch_id(batch[0], model.config.batch_size),
+                event_time_start=batch[0],
+                event_time_end=batch[1],
+            )
             # Recompile node to re-resolve refs with event time filters rendered, update context
             self.compiler.compile_node(
                 model,
@@ -544,10 +557,10 @@ class MicrobatchModelRunner(ModelRunner):
                 ),
             )
             # Update jinja context with batch context members
-            batch_context = microbatch_builder.build_batch_context(
+            jinja_context = microbatch_builder.build_jinja_context_for_batch(
                 incremental_batch=self.relation_exists
             )
-            context.update(batch_context)
+            context.update(jinja_context)

             # Materialize batch and cache any materialized relations
             result = MacroGenerator(
@@ -64,8 +64,8 @@ microbatch_yearly_model_downstream_sql = """
 select * from {{ ref('microbatch_model') }}
 """

-invalid_batch_context_macro_sql = """
-{% macro check_invalid_batch_context() %}
+invalid_batch_jinja_context_macro_sql = """
+{% macro check_invalid_batch_jinja_context() %}

 {% if model is not mapping %}
   {{ exceptions.raise_compiler_error("`model` is invalid: expected mapping type") }}
@@ -83,9 +83,9 @@ invalid_batch_context_macro_sql = """
 """

 microbatch_model_with_context_checks_sql = """
-{{ config(pre_hook="{{ check_invalid_batch_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
+{{ config(pre_hook="{{ check_invalid_batch_jinja_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}

-{{ check_invalid_batch_context() }}
+{{ check_invalid_batch_jinja_context() }}
 select * from {{ ref('input_model') }}
 """
@@ -404,7 +404,7 @@ class TestMicrobatchJinjaContext(BaseMicrobatchTest):

     @pytest.fixture(scope="class")
     def macros(self):
-        return {"check_batch_context.sql": invalid_batch_context_macro_sql}
+        return {"check_batch_jinja_context.sql": invalid_batch_jinja_context_macro_sql}

     @pytest.fixture(scope="class")
     def models(self):
@@ -498,6 +498,13 @@ microbatch_model_context_vars = """
 {{ config(materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
 {{ log("start: "~ model.config.__dbt_internal_microbatch_event_time_start, info=True)}}
 {{ log("end: "~ model.config.__dbt_internal_microbatch_event_time_end, info=True)}}
+{% if model.batch %}
+{{ log("batch.event_time_start: "~ model.batch.event_time_start, info=True)}}
+{{ log("batch.event_time_end: "~ model.batch.event_time_end, info=True)}}
+{{ log("batch.id: "~ model.batch.id, info=True)}}
+{{ log("start timezone: "~ model.batch.event_time_start.tzinfo, info=True)}}
+{{ log("end timezone: "~ model.batch.event_time_end.tzinfo, info=True)}}
+{% endif %}
 select * from {{ ref('input_model') }}
 """

@@ -516,12 +523,23 @@ class TestMicrobatchJinjaContextVarsAvailable(BaseMicrobatchTest):

         assert "start: 2020-01-01 00:00:00+00:00" in logs
         assert "end: 2020-01-02 00:00:00+00:00" in logs
+        assert "batch.event_time_start: 2020-01-01 00:00:00+00:00" in logs
+        assert "batch.event_time_end: 2020-01-02 00:00:00+00:00" in logs
+        assert "batch.id: 20200101" in logs
+        assert "start timezone: UTC" in logs
+        assert "end timezone: UTC" in logs

         assert "start: 2020-01-02 00:00:00+00:00" in logs
         assert "end: 2020-01-03 00:00:00+00:00" in logs
+        assert "batch.event_time_start: 2020-01-02 00:00:00+00:00" in logs
+        assert "batch.event_time_end: 2020-01-03 00:00:00+00:00" in logs
+        assert "batch.id: 20200102" in logs

         assert "start: 2020-01-03 00:00:00+00:00" in logs
         assert "end: 2020-01-03 13:57:00+00:00" in logs
+        assert "batch.event_time_start: 2020-01-03 00:00:00+00:00" in logs
+        assert "batch.event_time_end: 2020-01-03 13:57:00+00:00" in logs
+        assert "batch.id: 20200103" in logs


 microbatch_model_failing_incremental_partition_sql = """
@@ -675,16 +693,6 @@ class TestMicrobatchCompiledRunPaths(BaseMicrobatchTest):
         with patch_microbatch_end_time("2020-01-03 13:57:00"):
             run_dbt(["run"])

-        # Compiled paths - compiled model without filter only
-        assert read_file(
-            project.project_root,
-            "target",
-            "compiled",
-            "test",
-            "models",
-            "microbatch_model.sql",
-        )
-
         # Compiled paths - batch compilations
         assert read_file(
             project.project_root,
@@ -96,6 +96,7 @@ REQUIRED_PARSED_NODE_KEYS = frozenset(
         "deprecation_date",
         "defer_relation",
         "time_spine",
+        "batch",
     }
 )

@@ -489,11 +489,11 @@ class TestMicrobatchBuilder:
         assert len(actual_batches) == len(expected_batches)
         assert actual_batches == expected_batches

-    def test_build_batch_context_incremental_batch(self, microbatch_model):
+    def test_build_jinja_context_for_incremental_batch(self, microbatch_model):
         microbatch_builder = MicrobatchBuilder(
             model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
         )
-        context = microbatch_builder.build_batch_context(incremental_batch=True)
+        context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=True)

         assert context["model"] == microbatch_model.to_dict()
         assert context["sql"] == microbatch_model.compiled_code
@@ -502,11 +502,11 @@ class TestMicrobatchBuilder:
         assert context["is_incremental"]() is True
         assert context["should_full_refresh"]() is False

-    def test_build_batch_context_incremental_batch_false(self, microbatch_model):
+    def test_build_jinja_context_for_incremental_batch_false(self, microbatch_model):
         microbatch_builder = MicrobatchBuilder(
             model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
         )
-        context = microbatch_builder.build_batch_context(incremental_batch=False)
+        context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=False)

         assert context["model"] == microbatch_model.to_dict()
         assert context["sql"] == microbatch_model.compiled_code
@@ -605,7 +605,6 @@ class TestMicrobatchBuilder:
     @pytest.mark.parametrize(
         "batch_size,batch_start,expected_formatted_batch_start",
         [
-            (None, None, None),
             (BatchSize.year, datetime(2020, 1, 1, 1), "2020-01-01"),
             (BatchSize.month, datetime(2020, 1, 1, 1), "2020-01-01"),
             (BatchSize.day, datetime(2020, 1, 1, 1), "2020-01-01"),