Compare commits

...

7 Commits

Author SHA1 Message Date
Michelle Ark
002c3c4088 add unit test for Flags initialized from ProjectFlags with project_only_flags' 2024-01-23 14:58:10 -05:00
Michelle Ark
1bd8aeb518 Merge branch 'main' into feature/source-freshness-hooks 2024-01-23 10:37:21 -05:00
Michelle Ark
dd5377ff25 Merge branch 'main' into feature/source-freshness-hooks 2024-01-12 15:15:23 -05:00
Michelle Ark
870b1d3c69 rename to project_only_flags 2024-01-12 13:35:26 -05:00
Michelle Ark
0fd1381d7f add flags.source_freshness_run_project_hooks and tests 2024-01-12 10:58:13 -05:00
Ofek Weiss
73ef35f306 Handle errors in on-run-end hooks 2024-01-10 16:23:52 -05:00
Ofek Weiss
fe66aba2a5 added hook support for dbt source freshness 2024-01-10 16:23:22 -05:00
8 changed files with 137 additions and 10 deletions

View File

@@ -0,0 +1,6 @@
kind: Features
body: Added hook support for `dbt source freshness`
time: 2023-12-31T17:12:05.587185+02:00
custom:
Author: ofek1weiss
Issue: "5609"

View File

@@ -236,8 +236,8 @@ class Flags:
# Add entire invocation command to flags # Add entire invocation command to flags
object.__setattr__(self, "INVOCATION_COMMAND", "dbt " + " ".join(sys.argv[1:])) object.__setattr__(self, "INVOCATION_COMMAND", "dbt " + " ".join(sys.argv[1:]))
# Overwrite default assignments with user config if available.
if project_flags: if project_flags:
# Overwrite default assignments with project flags if available.
param_assigned_from_default_copy = params_assigned_from_default.copy() param_assigned_from_default_copy = params_assigned_from_default.copy()
for param_assigned_from_default in params_assigned_from_default: for param_assigned_from_default in params_assigned_from_default:
project_flags_param_value = getattr( project_flags_param_value = getattr(
@@ -252,6 +252,13 @@ class Flags:
param_assigned_from_default_copy.remove(param_assigned_from_default) param_assigned_from_default_copy.remove(param_assigned_from_default)
params_assigned_from_default = param_assigned_from_default_copy params_assigned_from_default = param_assigned_from_default_copy
# Add project-level flags that are not available as CLI options / env vars
for (
project_level_flag_name,
project_level_flag_value,
) in project_flags.project_only_flags.items():
object.__setattr__(self, project_level_flag_name.upper(), project_level_flag_value)
# Set hard coded flags. # Set hard coded flags.
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name) object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)

View File

@@ -307,6 +307,7 @@ class ProjectFlags(ExtensibleDbtClassMixin, Replaceable):
populate_cache: Optional[bool] = None populate_cache: Optional[bool] = None
printer_width: Optional[int] = None printer_width: Optional[int] = None
send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS
source_freshness_run_project_hooks: bool = False
static_parser: Optional[bool] = None static_parser: Optional[bool] = None
use_colors: Optional[bool] = None use_colors: Optional[bool] = None
use_colors_file: Optional[bool] = None use_colors_file: Optional[bool] = None
@@ -316,6 +317,10 @@ class ProjectFlags(ExtensibleDbtClassMixin, Replaceable):
warn_error_options: Optional[Dict[str, Union[str, List[str]]]] = None warn_error_options: Optional[Dict[str, Union[str, List[str]]]] = None
write_json: Optional[bool] = None write_json: Optional[bool] = None
@property
def project_only_flags(self) -> Dict[str, Any]:
return {"source_freshness_run_project_hooks": self.source_freshness_run_project_hooks}
@dataclass @dataclass
class ProfileConfig(dbtClassMixin, Replaceable): class ProfileConfig(dbtClassMixin, Replaceable):

View File

@@ -1,13 +1,13 @@
import os import os
import threading import threading
import time import time
from typing import Optional from typing import Optional, List
from .base import BaseRunner from .base import BaseRunner
from .printer import ( from .printer import (
print_run_result_error, print_run_result_error,
) )
from .runnable import GraphRunnableTask from .run import RunTask
from dbt.artifacts.freshness import ( from dbt.artifacts.freshness import (
FreshnessResult, FreshnessResult,
@@ -23,11 +23,12 @@ from dbt.events.types import (
LogStartLine, LogStartLine,
LogFreshnessResult, LogFreshnessResult,
) )
from dbt.node_types import NodeType from dbt.contracts.results import RunStatus
from dbt.node_types import NodeType, RunHookType
from dbt.adapters.capability import Capability from dbt.adapters.capability import Capability
from dbt.adapters.contracts.connection import AdapterResponse from dbt.adapters.contracts.connection import AdapterResponse
from dbt.contracts.graph.nodes import SourceDefinition from dbt.contracts.graph.nodes import SourceDefinition, HookNode
from dbt_common.events.base_types import EventLevel from dbt_common.events.base_types import EventLevel
from dbt.graph import ResourceTypeSelector from dbt.graph import ResourceTypeSelector
@@ -170,7 +171,7 @@ class FreshnessSelector(ResourceTypeSelector):
return node.has_freshness return node.has_freshness
class FreshnessTask(GraphRunnableTask): class FreshnessTask(RunTask):
def result_path(self): def result_path(self):
if self.args.output: if self.args.output:
return os.path.realpath(self.args.output) return os.path.realpath(self.args.output)
@@ -200,7 +201,17 @@ class FreshnessTask(GraphRunnableTask):
def task_end_messages(self, results): def task_end_messages(self, results):
for result in results: for result in results:
if result.status in (FreshnessStatus.Error, FreshnessStatus.RuntimeErr): if result.status in (
FreshnessStatus.Error,
FreshnessStatus.RuntimeErr,
RunStatus.Error,
):
print_run_result_error(result) print_run_result_error(result)
fire_event(FreshnessCheckComplete()) fire_event(FreshnessCheckComplete())
def get_hooks_by_type(self, hook_type: RunHookType) -> List[HookNode]:
if self.args.source_freshness_run_project_hooks:
return super().get_hooks_by_type(hook_type)
else:
return []

View File

@@ -2,7 +2,7 @@ import os
import pytest import pytest
import yaml import yaml
from dbt.tests.util import run_dbt from dbt.tests.util import run_dbt, run_dbt_and_capture
from tests.functional.sources.fixtures import ( from tests.functional.sources.fixtures import (
models_schema_yml, models_schema_yml,
models_view_model_sql, models_view_model_sql,
@@ -57,10 +57,17 @@ class BaseSourcesTest:
}, },
} }
def run_dbt_with_vars(self, project, cmd, *args, **kwargs): def _extend_cmd_with_vars(self, project, cmd):
vars_dict = { vars_dict = {
"test_run_schema": project.test_schema, "test_run_schema": project.test_schema,
"test_loaded_at": project.adapter.quote("updated_at"), "test_loaded_at": project.adapter.quote("updated_at"),
} }
cmd.extend(["--vars", yaml.safe_dump(vars_dict)]) cmd.extend(["--vars", yaml.safe_dump(vars_dict)])
def run_dbt_with_vars(self, project, cmd, *args, **kwargs):
self._extend_cmd_with_vars(project, cmd)
return run_dbt(cmd, *args, **kwargs) return run_dbt(cmd, *args, **kwargs)
def run_dbt_and_capture_with_vars(self, project, cmd, *args, **kwargs):
self._extend_cmd_with_vars(project, cmd)
return run_dbt_and_capture(cmd, *args, **kwargs)

View File

@@ -400,3 +400,85 @@ class TestMetadataFreshnessFails:
runner.invoke(["parse"]) runner.invoke(["parse"])
assert got_warning assert got_warning
class TestHooksInSourceFreshness(SuccessfulSourceFreshnessTest):
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"config-version": 2,
"on-run-start": ["{{ log('on-run-start hooks called') }}"],
"on-run-end": ["{{ log('on-run-end hooks called') }}"],
"flags": {
"source_freshness_run_project_hooks": True,
},
}
def test_hooks_do_run_for_source_freshness(
self,
project,
):
_, log_output = self.run_dbt_and_capture_with_vars(
project,
[
"source",
"freshness",
],
expect_pass=False,
)
assert "on-run-start" in log_output
assert "on-run-end" in log_output
class TestHooksInSourceFreshnessDisabled(SuccessfulSourceFreshnessTest):
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"config-version": 2,
"on-run-start": ["{{ log('on-run-start hooks called') }}"],
"on-run-end": ["{{ log('on-run-end hooks called') }}"],
"flags": {
"source_freshness_run_project_hooks": False,
},
}
def test_hooks_do_run_for_source_freshness(
self,
project,
):
_, log_output = self.run_dbt_and_capture_with_vars(
project,
[
"source",
"freshness",
],
expect_pass=False,
)
assert "on-run-start" not in log_output
assert "on-run-end" not in log_output
class TestHooksInSourceFreshnessDefault(SuccessfulSourceFreshnessTest):
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"config-version": 2,
"on-run-start": ["{{ log('on-run-start hooks called') }}"],
"on-run-end": ["{{ log('on-run-end hooks called') }}"],
}
def test_hooks_do_run_for_source_freshness(
self,
project,
):
_, log_output = self.run_dbt_and_capture_with_vars(
project,
[
"source",
"freshness",
],
expect_pass=False,
)
# default behaviour - no hooks run in source freshness
assert "on-run-start" not in log_output
assert "on-run-end" not in log_output

View File

@@ -366,6 +366,14 @@ class TestFlags:
assert flags_a.USE_COLORS == flags_b.USE_COLORS assert flags_a.USE_COLORS == flags_b.USE_COLORS
def test_set_project_only_flags(self, project_flags, run_context):
flags = Flags(run_context, project_flags)
for project_only_flag, project_only_flag_value in project_flags.project_only_flags.items():
assert getattr(flags, project_only_flag) == project_only_flag_value
# sanity check: ensure project_only_flag is not part of the click context
assert project_only_flag not in run_context.params
def _create_flags_from_dict(self, cmd, d): def _create_flags_from_dict(self, cmd, d):
write_file("", "profiles.yml") write_file("", "profiles.yml")
result = Flags.from_dict(cmd, d) result = Flags.from_dict(cmd, d)

View File

@@ -16,6 +16,7 @@ import dbt.parser.manifest
from dbt import tracking from dbt import tracking
from dbt.contracts.files import SourceFile, FileHash, FilePath from dbt.contracts.files import SourceFile, FileHash, FilePath
from dbt.contracts.graph.manifest import MacroManifest, ManifestStateCheck from dbt.contracts.graph.manifest import MacroManifest, ManifestStateCheck
from dbt.contracts.project import ProjectFlags
from dbt.graph import NodeSelector, parse_difference from dbt.graph import NodeSelector, parse_difference
from dbt.events.logging import setup_event_logger from dbt.events.logging import setup_event_logger
from dbt.mp_context import get_mp_context from dbt.mp_context import get_mp_context
@@ -130,7 +131,7 @@ class GraphTest(unittest.TestCase):
cfg.update(extra_cfg) cfg.update(extra_cfg)
config = config_from_parts_or_dicts(project=cfg, profile=self.profile) config = config_from_parts_or_dicts(project=cfg, profile=self.profile)
dbt.flags.set_from_args(Namespace(), config) dbt.flags.set_from_args(Namespace(), ProjectFlags())
setup_event_logger(dbt.flags.get_flags()) setup_event_logger(dbt.flags.get_flags())
object.__setattr__(dbt.flags.get_flags(), "PARTIAL_PARSE", False) object.__setattr__(dbt.flags.get_flags(), "PARTIAL_PARSE", False)
return config return config