Compare commits

...

1 Commit

Author        SHA1        Message                 Date
Michelle Ark  2a146bb51a  first stab: microbatch  2024-07-23 13:56:59 -04:00
7 changed files with 63 additions and 7 deletions

View File

@@ -219,6 +219,7 @@ class CompiledResource(ParsedResource):
     extra_ctes: List[InjectedCTE] = field(default_factory=list)
     _pre_injected_sql: Optional[str] = None
     contract: Contract = field(default_factory=Contract)
+    event_time: Optional[str] = None
 
     def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
         dct = super().__post_serialize__(dct, context)

View File

@@ -70,3 +70,4 @@ class SourceDefinition(ParsedSourceMandatory):
     unrendered_config: Dict[str, Any] = field(default_factory=dict)
     relation_name: Optional[str] = None
     created_at: float = field(default_factory=lambda: time.time())
+    event_time: Optional[str] = None
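
Together with the node-level change above, these two one-line additions let both models and sources name the timestamp column that records when each row's event occurred. A minimal sketch of the intent (illustrative names only, not dbt's actual API):

from dataclasses import dataclass
from typing import Optional


@dataclass
class Resource:
    # Name of the event-time column on the materialized relation; None means
    # the resource does not opt in to event-time filtering.
    event_time: Optional[str] = None


print(Resource().event_time)              # None -> no filtering applied
print(Resource(event_time="created_at"))  # opts the resource in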

View File

@@ -20,6 +20,7 @@ from typing_extensions import Protocol
 from dbt import selected_resources
 from dbt.adapters.base.column import Column
+from dbt.adapters.base.relation import EventTimeFilter
 from dbt.adapters.contracts.connection import AdapterResponse
 from dbt.adapters.exceptions import MissingConfigError
 from dbt.adapters.factory import (
@@ -230,6 +231,21 @@ class BaseResolver(metaclass=abc.ABCMeta):
     def resolve_limit(self) -> Optional[int]:
         return 0 if getattr(self.config.args, "EMPTY", False) else None
 
+    @property
+    def resolve_event_time_filter(self) -> Optional[EventTimeFilter]:
+        field_name = getattr(self.model, "event_time")
+        start_time = getattr(self.model, "start_time")
+        end_time = getattr(self.model, "end_time")
+
+        if start_time and end_time and field_name:
+            return EventTimeFilter(
+                field_name=field_name,
+                start_time=start_time,
+                end_time=end_time,
+            )
+
+        return None
+
     @abc.abstractmethod
     def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]:
         pass
@@ -545,7 +561,11 @@ class RuntimeRefResolver(BaseRefResolver):
     def create_relation(self, target_model: ManifestNode) -> RelationProxy:
         if target_model.is_ephemeral_model:
             self.model.set_cte(target_model.unique_id, None)
-            return self.Relation.create_ephemeral_from(target_model, limit=self.resolve_limit)
+            return self.Relation.create_ephemeral_from(
+                target_model,
+                limit=self.resolve_limit,
+                event_time_filter=self.resolve_event_time_filter,
+            )
         elif (
             hasattr(target_model, "defer_relation")
             and target_model.defer_relation
@@ -563,10 +583,18 @@ class RuntimeRefResolver(BaseRefResolver):
             )
         ):
             return self.Relation.create_from(
-                self.config, target_model.defer_relation, limit=self.resolve_limit
+                self.config,
+                target_model.defer_relation,
+                limit=self.resolve_limit,
+                event_time_filter=self.resolve_event_time_filter,
             )
         else:
-            return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit)
+            return self.Relation.create_from(
+                self.config,
+                target_model,
+                limit=self.resolve_limit,
+                event_time_filter=self.resolve_event_time_filter,
+            )
 
     def validate(
         self,
@@ -633,7 +661,12 @@ class RuntimeSourceResolver(BaseSourceResolver):
                 target_kind="source",
                 disabled=(isinstance(target_source, Disabled)),
             )
-        return self.Relation.create_from(self.config, target_source, limit=self.resolve_limit)
+        return self.Relation.create_from(
+            self.config,
+            target_source,
+            limit=self.resolve_limit,
+            event_time_filter=self.resolve_event_time_filter,
+        )
 
 
 class RuntimeUnitTestSourceResolver(BaseSourceResolver):
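
The net effect of this file: ref() and source() resolutions at runtime can now carry an EventTimeFilter built from the upstream resource's event_time column and the current node's start_time/end_time window. How such a filter might ultimately be rendered into SQL is sketched below; the render helper and WHERE-clause shape are assumptions for illustration, not the dbt-adapters implementation.

from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass
class EventTimeFilter:
    # Mirrors the fields used above; the real class lives in dbt.adapters.base.relation.
    field_name: str
    start_time: datetime
    end_time: datetime


def render_event_time_filter(relation_sql: str, f: Optional[EventTimeFilter]) -> str:
    """Hypothetical helper: wrap a relation in a subquery that keeps only the batch window."""
    if f is None:
        return relation_sql
    return (
        f"(select * from {relation_sql} "
        f"where {f.field_name} >= '{f.start_time.isoformat()}' "
        f"and {f.field_name} < '{f.end_time.isoformat()}') _filtered"
    )


# Example: a three-day window over an upstream table keyed on created_at
flt = EventTimeFilter("created_at", datetime(2024, 7, 20), datetime(2024, 7, 23))
print(render_event_time_filter('"analytics"."orders"', flt))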

View File

@@ -379,6 +379,10 @@ class CompiledNode(CompiledResource, ParsedNode):
"""Contains attributes necessary for SQL files and nodes with refs, sources, etc,
so all ManifestNodes except SeedNode."""
# TODO: should these go here? and get set during execution?
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
@property
def empty(self):
return not self.raw_code.strip()

View File

@@ -1,7 +1,7 @@
 import functools
 import threading
 import time
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from dbt import tracking, utils
@@ -214,6 +214,14 @@ class ModelRunner(CompileRunner):
         )
 
     def before_execute(self):
+        if self.node.config.get("microbatch"):
+            # TODO: actually use partition_grain
+            # partition_grain = self.node.config.get("partition_grain")
+            lookback = self.node.config.get("lookback")
+            self.node.end_time = datetime.now()
+            self.node.start_time = self.node.end_time - timedelta(days=lookback)
+            self.node.start_time = self.node.start_time.replace(minute=0, hour=0, second=0, microsecond=0)
+
         self.print_start_line()
 
     def after_execute(self, result):
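
The batch window here is simply "now" back to "lookback days ago, truncated to midnight". A standalone sketch of that arithmetic (note that datetime.replace returns a new object, which is why the truncation above assigns the result back):

from datetime import datetime, timedelta


def microbatch_window(lookback_days: int, now: datetime) -> tuple:
    """Return (start_time, end_time) for a lookback window ending now.

    start_time is truncated to midnight so a batch always begins on a day boundary.
    """
    end_time = now
    start_time = (end_time - timedelta(days=lookback_days)).replace(
        hour=0, minute=0, second=0, microsecond=0
    )
    return start_time, end_time


start, end = microbatch_window(2, datetime(2024, 7, 23, 13, 56, 59))
print(start)  # 2024-07-21 00:00:00
print(end)    # 2024-07-23 13:56:59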

View File

@@ -41,9 +41,12 @@ class TestRuntimeRefResolver:
         mock_db_wrapper = mock.Mock()
         mock_db_wrapper.Relation = BaseRelation
 
+        mock_model = mock.Mock()
+        mock_model.event_time = None
+
         return RuntimeRefResolver(
             db_wrapper=mock_db_wrapper,
-            model=mock.Mock(),
+            model=mock_model,
             config=mock.Mock(),
             manifest=mock.Mock(),
         )
@@ -82,9 +85,12 @@ class TestRuntimeSourceResolver:
         mock_db_wrapper = mock.Mock()
         mock_db_wrapper.Relation = BaseRelation
 
+        mock_model = mock.Mock()
+        mock_model.event_time = None
+
         return RuntimeSourceResolver(
             db_wrapper=mock_db_wrapper,
-            model=mock.Mock(),
+            model=mock_model,
             config=mock.Mock(),
             manifest=mock.Mock(),
         )
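
With event_time pinned to None, both fixtures keep resolve_event_time_filter returning None, so the existing create_from assertions are unaffected. A test for the positive path might look roughly like this; it is a sketch against the resolver code above (including its EventTimeFilter keyword names), not part of the commit:

from datetime import datetime
from unittest import mock

from dbt.adapters.base.relation import BaseRelation
from dbt.context.providers import RuntimeRefResolver


def test_resolve_event_time_filter_when_window_is_set():
    mock_db_wrapper = mock.Mock()
    mock_db_wrapper.Relation = BaseRelation

    # Model that opts in to event-time filtering and has a batch window set
    mock_model = mock.Mock()
    mock_model.event_time = "created_at"
    mock_model.start_time = datetime(2024, 7, 21)
    mock_model.end_time = datetime(2024, 7, 23)

    resolver = RuntimeRefResolver(
        db_wrapper=mock_db_wrapper,
        model=mock_model,
        config=mock.Mock(),
        manifest=mock.Mock(),
    )

    event_time_filter = resolver.resolve_event_time_filter
    assert event_time_filter is not None
    assert event_time_filter.field_name == "created_at"
    assert event_time_filter.start_time == datetime(2024, 7, 21)
    assert event_time_filter.end_time == datetime(2024, 7, 23)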

View File

@@ -94,6 +94,9 @@ REQUIRED_PARSED_NODE_KEYS = frozenset(
"constraints",
"deprecation_date",
"defer_relation",
"event_time",
"start_time",
"end_time",
}
)