Compare commits

...

1 Commits

Author SHA1 Message Date
Michelle Ark
2a146bb51a first stab: microbatch 2024-07-23 13:56:59 -04:00
7 changed files with 63 additions and 7 deletions

View File

@@ -219,6 +219,7 @@ class CompiledResource(ParsedResource):
extra_ctes: List[InjectedCTE] = field(default_factory=list) extra_ctes: List[InjectedCTE] = field(default_factory=list)
_pre_injected_sql: Optional[str] = None _pre_injected_sql: Optional[str] = None
contract: Contract = field(default_factory=Contract) contract: Contract = field(default_factory=Contract)
event_time: Optional[str] = None
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context) dct = super().__post_serialize__(dct, context)

View File

@@ -70,3 +70,4 @@ class SourceDefinition(ParsedSourceMandatory):
unrendered_config: Dict[str, Any] = field(default_factory=dict) unrendered_config: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None relation_name: Optional[str] = None
created_at: float = field(default_factory=lambda: time.time()) created_at: float = field(default_factory=lambda: time.time())
event_time: Optional[str] = None

View File

@@ -20,6 +20,7 @@ from typing_extensions import Protocol
from dbt import selected_resources from dbt import selected_resources
from dbt.adapters.base.column import Column from dbt.adapters.base.column import Column
from dbt.adapters.base.relation import EventTimeFilter
from dbt.adapters.contracts.connection import AdapterResponse from dbt.adapters.contracts.connection import AdapterResponse
from dbt.adapters.exceptions import MissingConfigError from dbt.adapters.exceptions import MissingConfigError
from dbt.adapters.factory import ( from dbt.adapters.factory import (
@@ -230,6 +231,21 @@ class BaseResolver(metaclass=abc.ABCMeta):
def resolve_limit(self) -> Optional[int]: def resolve_limit(self) -> Optional[int]:
return 0 if getattr(self.config.args, "EMPTY", False) else None return 0 if getattr(self.config.args, "EMPTY", False) else None
@property
def resolve_event_time_filter(self) -> Optional[EventTimeFilter]:
field_name = getattr(self.model, "event_time")
start_time = getattr(self.model, "start_time")
end_time = getattr(self.model, "end_time")
if start_time and end_time and field_name:
return EventTimeFilter(
field_name=field_name,
start_time=start_time,
end_time=end_time,
)
return None
@abc.abstractmethod @abc.abstractmethod
def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]: def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]:
pass pass
@@ -545,7 +561,11 @@ class RuntimeRefResolver(BaseRefResolver):
def create_relation(self, target_model: ManifestNode) -> RelationProxy: def create_relation(self, target_model: ManifestNode) -> RelationProxy:
if target_model.is_ephemeral_model: if target_model.is_ephemeral_model:
self.model.set_cte(target_model.unique_id, None) self.model.set_cte(target_model.unique_id, None)
return self.Relation.create_ephemeral_from(target_model, limit=self.resolve_limit) return self.Relation.create_ephemeral_from(
target_model,
limit=self.resolve_limit,
event_time_filter=self.resolve_event_time_filter,
)
elif ( elif (
hasattr(target_model, "defer_relation") hasattr(target_model, "defer_relation")
and target_model.defer_relation and target_model.defer_relation
@@ -563,10 +583,18 @@ class RuntimeRefResolver(BaseRefResolver):
) )
): ):
return self.Relation.create_from( return self.Relation.create_from(
self.config, target_model.defer_relation, limit=self.resolve_limit self.config,
target_model.defer_relation,
limit=self.resolve_limit,
event_time_filter=self.resolve_event_time_filter,
) )
else: else:
return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit) return self.Relation.create_from(
self.config,
target_model,
limit=self.resolve_limit,
event_time_filter=self.resolve_event_time_filter,
)
def validate( def validate(
self, self,
@@ -633,7 +661,12 @@ class RuntimeSourceResolver(BaseSourceResolver):
target_kind="source", target_kind="source",
disabled=(isinstance(target_source, Disabled)), disabled=(isinstance(target_source, Disabled)),
) )
return self.Relation.create_from(self.config, target_source, limit=self.resolve_limit) return self.Relation.create_from(
self.config,
target_source,
limit=self.resolve_limit,
event_time_filter=self.resolve_event_time_filter,
)
class RuntimeUnitTestSourceResolver(BaseSourceResolver): class RuntimeUnitTestSourceResolver(BaseSourceResolver):

View File

@@ -379,6 +379,10 @@ class CompiledNode(CompiledResource, ParsedNode):
"""Contains attributes necessary for SQL files and nodes with refs, sources, etc, """Contains attributes necessary for SQL files and nodes with refs, sources, etc,
so all ManifestNodes except SeedNode.""" so all ManifestNodes except SeedNode."""
# TODO: should these go here? and get set during execution?
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
@property @property
def empty(self): def empty(self):
return not self.raw_code.strip() return not self.raw_code.strip()

View File

@@ -1,7 +1,7 @@
import functools import functools
import threading import threading
import time import time
from datetime import datetime from datetime import datetime, timedelta
from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple
from dbt import tracking, utils from dbt import tracking, utils
@@ -214,6 +214,14 @@ class ModelRunner(CompileRunner):
) )
def before_execute(self): def before_execute(self):
if self.node.config.get("microbatch"):
# TODO: actually use partition_grain
# partition_grain = self.node.config.get("partition_grain")
lookback = self.node.config.get("lookback")
self.node.end_time = datetime.now()
self.node.start_time = self.node.end_time - timedelta(days=lookback)
self.node.start_time.replace(minute=0, hour=0, second=0, microsecond=0)
self.print_start_line() self.print_start_line()
def after_execute(self, result): def after_execute(self, result):

View File

@@ -41,9 +41,12 @@ class TestRuntimeRefResolver:
mock_db_wrapper = mock.Mock() mock_db_wrapper = mock.Mock()
mock_db_wrapper.Relation = BaseRelation mock_db_wrapper.Relation = BaseRelation
mock_model = mock.Mock()
mock_model.event_time = None
return RuntimeRefResolver( return RuntimeRefResolver(
db_wrapper=mock_db_wrapper, db_wrapper=mock_db_wrapper,
model=mock.Mock(), model=mock_model,
config=mock.Mock(), config=mock.Mock(),
manifest=mock.Mock(), manifest=mock.Mock(),
) )
@@ -82,9 +85,12 @@ class TestRuntimeSourceResolver:
mock_db_wrapper = mock.Mock() mock_db_wrapper = mock.Mock()
mock_db_wrapper.Relation = BaseRelation mock_db_wrapper.Relation = BaseRelation
mock_model = mock.Mock()
mock_model.event_time = None
return RuntimeSourceResolver( return RuntimeSourceResolver(
db_wrapper=mock_db_wrapper, db_wrapper=mock_db_wrapper,
model=mock.Mock(), model=mock_model,
config=mock.Mock(), config=mock.Mock(),
manifest=mock.Mock(), manifest=mock.Mock(),
) )

View File

@@ -94,6 +94,9 @@ REQUIRED_PARSED_NODE_KEYS = frozenset(
"constraints", "constraints",
"deprecation_date", "deprecation_date",
"defer_relation", "defer_relation",
"event_time",
"start_time",
"end_time",
} }
) )