Compare commits

...

12 Commits

Author SHA1 Message Date
Gerda Shank
ed8d9dbf8f Put back original v9 manifest schema 2023-08-11 18:16:33 -04:00
Gerda Shank
0b31227c56 Merge branch 'main' into ct-2911-initial_unit_testing 2023-08-11 18:15:03 -04:00
Gerda Shank
b9b4661f4d Remove unused methods and removed unused packages.yml from test 2023-08-07 15:46:58 -04:00
Gerda Shank
7e4bf98461 Replace space with underscor in unit test node type 2023-08-04 10:00:40 -04:00
Gerda Shank
b14c3a09a7 Update to changed statically_parsed interface 2023-08-03 11:54:08 -04:00
Gerda Shank
a8a2331a93 Fix Command and CLICommand enums 2023-08-02 17:45:24 -04:00
Gerda Shank
68968a74d7 mypy. Don't know why these changed... 2023-08-02 15:55:39 -04:00
Gerda Shank
31af2b9979 Add test_unit_testing.py 2023-08-02 15:29:35 -04:00
Gerda Shank
4a8317d974 Changie 2023-08-02 14:50:18 -04:00
Gerda Shank
a2f197851c Merge branch 'main' into ct-2911-initial_unit_testing 2023-08-02 14:49:30 -04:00
Gerda Shank
c578971be4 comment 2023-07-26 08:02:22 -04:00
Gerda Shank
49179cb7fb Squashed commit of the following:
commit 88eb4d6b69
Author: Gerda Shank <gerda@dbtlabs.com>
Date:   Thu Jul 20 13:25:49 2023 -0400

    add @defer_state to unit-test in core/dbt/cli/main.py

commit 1336104236
Author: Gerda Shank <gerda@dbtlabs.com>
Date:   Thu Jul 20 11:24:32 2023 -0400

    kludge to make mypy happy about parse_from_dict

commit 2b264b248e
Merge: 4dc6dd395 eeb057085
Author: Gerda Shank <gerda@dbtlabs.com>
Date:   Thu Jul 20 11:04:00 2023 -0400

    Merge branch 'main' into arky/poc-unit-testing

commit 4dc6dd395e
Author: Chenyu Li <chenyu.li@dbtlabs.com>
Date:   Sat May 6 13:15:05 2023 -0700

    copy nodes being referenced over

commit e203efa95b
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Tue May 9 10:11:22 2023 -0400

    postgres fix

commit 05bd06f401
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Mon May 8 02:38:26 2023 -0400

    null input values

commit e56538d9bb
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Mon May 8 02:18:17 2023 -0400

    override jinja context properties

commit 43ba24fbc3
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Mon May 8 02:07:35 2023 -0400

    hacky lineage between unit tests

commit 8d586d0aab
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Mon May 8 01:12:17 2023 -0400

    fix tests

commit 19793eb9fb
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Mon May 8 00:20:17 2023 -0400

    fix --output json

commit 4cf47109f5
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Mon May 8 00:10:52 2023 -0400

    remove manifest.add_unit_test

commit 58aa0bdb3b
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Mon May 8 00:06:39 2023 -0400

    undo test_type:unit selector

commit f8ce09ca6c
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Mon May 8 00:02:51 2023 -0400

    jinja refactoring

commit 2c301161cb
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Sun May 7 23:35:07 2023 -0400

    more decoupling of unit test from other parsing/execution + overrides support

commit ccc3ad3886
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Fri May 5 23:14:10 2023 -0400

    fix manifest artifact

commit 2c953b227c
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Fri May 5 23:12:11 2023 -0400

    move parsing to UnitTestManifestLoader + requires

commit 2184a4da05
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Wed May 3 10:33:30 2023 -0400

    better type handling

commit f8bdd8b19b
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Sun Apr 30 22:00:52 2023 -0400

    dbt.string_literal

commit 8043106b9f
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Sun Apr 30 21:50:57 2023 -0400

    polish up spec, get column schema from relation

commit 0ab9222eda
Author: Michelle Ark <michelle.ark@dbtlabs.com>
Date:   Sun Apr 30 14:19:50 2023 -0400

    first pass
2023-07-25 14:36:59 -04:00
24 changed files with 912 additions and 14 deletions

View File

@@ -0,0 +1,6 @@
kind: Features
body: Initial implementation of unit testing
time: 2023-08-02T14:50:11.391992-04:00
custom:
Author: gshank
Issue: "8287"

View File

@@ -12,7 +12,7 @@ class RelationConfigChangeAction(StrEnum):
drop = "drop"
@dataclass(frozen=True, eq=True, unsafe_hash=True)
@dataclass(frozen=True, eq=True, unsafe_hash=True) # type: ignore
class RelationConfigChange(RelationConfigBase, ABC):
action: RelationConfigChangeAction
context: Hashable # this is usually a RelationConfig, e.g. IndexConfig, but shouldn't be limited

View File

@@ -392,6 +392,7 @@ def command_args(command: CliCommand) -> ArgsList:
CliCommand.SOURCE_FRESHNESS: cli.freshness,
CliCommand.TEST: cli.test,
CliCommand.RETRY: cli.retry,
CliCommand.UNIT_TEST: cli.unit_test,
}
click_cmd: Optional[ClickCommand] = CMD_DICT.get(command, None)
if click_cmd is None:

View File

@@ -39,6 +39,7 @@ from dbt.task.serve import ServeTask
from dbt.task.show import ShowTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.task.unit_test import UnitTestTask
@dataclass
@@ -846,6 +847,52 @@ def test(ctx, **kwargs):
return results, success
# dbt unit-test
@cli.command("unit-test")
@click.pass_context
@p.defer
@p.deprecated_defer
@p.exclude
@p.fail_fast
@p.favor_state
@p.deprecated_favor_state
@p.indirect_selection
@p.show_output_format
@p.profile
@p.profiles_dir
@p.project_dir
@p.select
@p.selector
@p.state
@p.defer_state
@p.deprecated_state
@p.store_failures
@p.target
@p.target_path
@p.threads
@p.vars
@p.version_check
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
@requires.unit_test_collection
def unit_test(ctx, **kwargs):
    """Runs unit tests for models: each model is built from the fixture rows
    given in YAML and its output is compared with the expected rows."""
    # The requires.* decorators above have already stored these objects on ctx.obj;
    # "unit_test_collection" is the manifest produced by requires.unit_test_collection.
    task = UnitTestTask(
        ctx.obj["flags"],
        ctx.obj["runtime_config"],
        ctx.obj["manifest"],
        ctx.obj["unit_test_collection"],
    )

    results = task.run()
    success = task.interpret_results(results)
    # Same (results, success) return shape as the other commands in this module.
    return results, success
# Support running as a module
if __name__ == "__main__":
cli()

View File

@@ -23,6 +23,7 @@ from dbt.parser.manifest import ManifestLoader, write_manifest
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.utils import cast_dict_to_dict_of_strings
from dbt.parser.unit_tests import UnitTestManifestLoader
from dbt.plugins import set_up_plugin_manager, get_plugin_manager
from click import Context
@@ -265,3 +266,25 @@ def manifest(*args0, write=True, write_perf_info=False):
if len(args0) == 0:
return outer_wrapper
return outer_wrapper(args0[0])
def unit_test_collection(func):
    """A decorator used by click command functions for generating a unit test collection provided a manifest"""

    def wrapper(*args, **kwargs):
        # The click context is always the first positional argument.
        ctx = args[0]
        assert isinstance(ctx, Context)

        # Both objects must already have been placed on ctx.obj by earlier decorators.
        required = [ctx.obj.get(key) for key in ("manifest", "runtime_config")]
        if None in required:
            raise DbtProjectError("manifest and runtime_config required for unit_test_collection")

        ctx.obj["unit_test_collection"] = UnitTestManifestLoader.load(
            ctx.obj["manifest"], ctx.obj["runtime_config"]
        )
        return func(*args, **kwargs)

    return update_wrapper(wrapper, func)

View File

@@ -24,6 +24,7 @@ class Command(Enum):
SOURCE_FRESHNESS = "freshness"
TEST = "test"
RETRY = "retry"
UNIT_TEST = "unit-test"
@classmethod
def from_str(cls, s: str) -> "Command":

View File

@@ -330,6 +330,26 @@ class MacroGenerator(BaseMacroGenerator):
return self.call_macro(*args, **kwargs)
class UnitTestMacroGenerator(MacroGenerator):
    """Macro generator that returns a preset value instead of rendering its macro."""

    def __init__(
        self,
        macro_generator: MacroGenerator,
        call_return_value: Any,
    ) -> None:
        # Copy the wrapped generator's macro, context, node and stack.
        wrapped = macro_generator
        super().__init__(wrapped.macro, wrapped.context, wrapped.node, wrapped.stack)
        self.call_return_value = call_return_value

    # this makes UnitTestMacroGenerator objects callable like functions
    def __call__(self, *args, **kwargs):
        # Arguments are accepted but ignored; only the preset value is returned.
        with self.track_call():
            preset = self.call_return_value
        return preset
class QueryStringGenerator(BaseMacroGenerator):
def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
super().__init__(context)

View File

@@ -12,7 +12,10 @@ from dbt.flags import get_flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model_context
from dbt.context.providers import (
generate_runtime_model_context,
generate_runtime_unit_test_context,
)
from dbt.contracts.graph.manifest import Manifest, UniqueID
from dbt.contracts.graph.nodes import (
ManifestNode,
@@ -21,6 +24,7 @@ from dbt.contracts.graph.nodes import (
GraphMemberNode,
InjectedCTE,
SeedNode,
UnitTestNode,
)
from dbt.exceptions import (
GraphDependencyNotFoundError,
@@ -44,6 +48,7 @@ def print_compile_stats(stats):
names = {
NodeType.Model: "model",
NodeType.Test: "test",
NodeType.Unit: "unit test",
NodeType.Snapshot: "snapshot",
NodeType.Analysis: "analysis",
NodeType.Macro: "macro",
@@ -289,8 +294,10 @@ class Compiler:
manifest: Manifest,
extra_context: Dict[str, Any],
) -> Dict[str, Any]:
context = generate_runtime_model_context(node, self.config, manifest)
if isinstance(node, UnitTestNode):
context = generate_runtime_unit_test_context(node, self.config, manifest)
else:
context = generate_runtime_model_context(node, self.config, manifest)
context.update(extra_context)
if isinstance(node, GenericTestNode):

View File

@@ -1,4 +1,5 @@
import abc
from copy import deepcopy
import os
from typing import (
Callable,
@@ -17,7 +18,7 @@ from typing_extensions import Protocol
from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
from dbt.config import RuntimeConfig, Project
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.context.base import contextmember, contextproperty, Var
@@ -39,6 +40,7 @@ from dbt.contracts.graph.nodes import (
RefArgs,
AccessType,
SemanticModel,
UnitTestNode,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
@@ -566,6 +568,17 @@ class OperationRefResolver(RuntimeRefResolver):
return super().create_relation(target_model)
class RuntimeUnitTestRefResolver(RuntimeRefResolver):
    """Ref resolver that prefixes the ref'd name with the current node's name."""

    def resolve(
        self,
        target_name: str,
        target_package: Optional[str] = None,
        target_version: Optional[NodeVersion] = None,
    ) -> RelationProxy:
        # Look up "<self.model.name>__<target_name>" instead of the ref'd name itself.
        prefixed_name = f"{self.model.name}__{target_name}"
        return super().resolve(prefixed_name, target_package, target_version)
# `source` implementations
class ParseSourceResolver(BaseSourceResolver):
def resolve(self, source_name: str, table_name: str):
@@ -670,6 +683,22 @@ class RuntimeVar(ModelConfiguredVar):
pass
class UnitTestVar(RuntimeVar):
    """Runtime ``var`` implementation that applies a unit test's ``overrides.vars``."""

    def __init__(
        self,
        context: Dict[str, Any],
        config: RuntimeConfig,
        node: Resource,
    ) -> None:
        assert isinstance(node, UnitTestNode)

        overridden_config = None
        overrides = node.overrides
        if overrides and overrides.vars:
            # Update a deep copy so the config object passed in is not mutated.
            overridden_config = deepcopy(config)
            overridden_config.cli_vars.update(overrides.vars)

        super().__init__(context, overridden_config or config, node=node)
# Providers
class Provider(Protocol):
execute: bool
@@ -711,6 +740,16 @@ class RuntimeProvider(Provider):
metric = RuntimeMetricResolver
class RuntimeUnitTestProvider(Provider):
    # Provider used to build the unit test context; Var and ref use the
    # unit-test-specific classes defined in this module.
    execute = True
    Config = RuntimeConfigObject
    DatabaseWrapper = RuntimeDatabaseWrapper
    Var = UnitTestVar
    ref = RuntimeUnitTestRefResolver
    source = RuntimeSourceResolver  # TODO: RuntimeUnitTestSourceResolver
    metric = RuntimeMetricResolver
class OperationProvider(RuntimeProvider):
ref = OperationRefResolver
@@ -1360,7 +1399,7 @@ class ModelContext(ProviderContext):
@contextproperty
def pre_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
@@ -1369,7 +1408,7 @@ class ModelContext(ProviderContext):
@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
@@ -1462,6 +1501,25 @@ class ModelContext(ProviderContext):
return None
class UnitTestContext(ModelContext):
    model: UnitTestNode

    @contextmember
    def env_var(self, var: str, default: Optional[str] = None) -> str:
        """The env_var() function. Return the overridden unit test environment variable named 'var'.

        If there is no unit test override, return the environment variable named 'var'.

        If there is no such environment variable set, return the default.

        If the default is None, raise an exception for an undefined variable.
        """
        overrides = self.model.overrides
        # No override for this variable: defer to the regular env_var() lookup.
        if not overrides or var not in overrides.env_vars:
            return super().env_var(var, default)
        return overrides.env_vars[var]
# This is called by '_context_for', used in 'render_with_context'
def generate_parser_model_context(
model: ManifestNode,
@@ -1506,6 +1564,24 @@ def generate_runtime_macro_context(
return ctx.to_dict()
def generate_runtime_unit_test_context(
    unit_test: UnitTestNode,
    config: RuntimeConfig,
    manifest: Manifest,
) -> Dict[str, Any]:
    """Build the runtime context dict for a unit test node, applying macro overrides."""
    ctx_dict = UnitTestContext(
        unit_test, config, manifest, RuntimeUnitTestProvider(), None
    ).to_dict()

    overrides = unit_test.overrides
    if not (overrides and overrides.macros):
        return ctx_dict

    for name, override_value in overrides.macros.items():
        current = ctx_dict.get(name)
        if isinstance(current, MacroGenerator):
            # Keep the entry callable, but make every call return the override value.
            ctx_dict[name] = UnitTestMacroGenerator(current, override_value)
        else:
            # Not a macro generator (or not present): store the override value directly.
            ctx_dict[name] = override_value
    return ctx_dict
class ExposureRefResolver(BaseResolver):
def __call__(self, *args, **kwargs) -> str:
package = None

View File

@@ -665,6 +665,7 @@ RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
NodeType.Source: SourceConfig,
NodeType.Seed: SeedConfig,
NodeType.Test: TestConfig,
NodeType.Unit: TestConfig,
NodeType.Model: NodeConfig,
NodeType.Snapshot: SnapshotConfig,
}

View File

@@ -34,6 +34,7 @@ from dbt.contracts.graph.unparsed import (
UnparsedSourceDefinition,
UnparsedSourceTableDefinition,
UnparsedColumn,
UnparsedUnitTestOverrides,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
@@ -983,6 +984,13 @@ class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
return "generic"
@dataclass
class UnitTestNode(CompiledNode):
    # Compiled node for a single unit test case.
    # Restricted to the unit test node type.
    resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]})
    # NOTE(review): presumably the unique_id of the model under test - confirm against the parser.
    attached_node: Optional[str] = None
    # Optional macros / vars / env_vars overrides for this test.
    overrides: Optional[UnparsedUnitTestOverrides] = None
# ====================================
# Snapshot node
# ====================================
@@ -1239,6 +1247,10 @@ class SourceDefinition(NodeInfoMixin, ParsedSourceMandatory):
def search_name(self):
return f"{self.source_name}.{self.name}"
@property
def group(self):
    """Always None for this node type."""
    return None
# ====================================
# Exposure node
@@ -1602,6 +1614,10 @@ class SemanticModel(GraphNode):
else None
)
@property
def group(self):
    """Always None for this node type."""
    return None
# ====================================
# Patches
@@ -1651,6 +1667,7 @@ ManifestSQLNode = Union[
SqlNode,
GenericTestNode,
SnapshotNode,
UnitTestNode,
]
# All SQL nodes plus SeedNode (csv files)
@@ -1680,7 +1697,4 @@ Resource = Union[
Group,
]
TestNode = Union[
SingularTestNode,
GenericTestNode,
]
TestNode = Union[SingularTestNode, GenericTestNode]

View File

@@ -671,6 +671,34 @@ class UnparsedGroup(dbtClassMixin, Replaceable):
raise ValidationError("Group owner must have at least one of 'name' or 'email'.")
@dataclass
class UnparsedInputFixture(dbtClassMixin):
    # One "given" entry of a unit test case.
    # The input expression as written in YAML.
    input: str
    # Fixture rows, one dict per row; defaults to no rows.
    rows: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class UnparsedUnitTestOverrides(dbtClassMixin):
    # Per-test overrides; every mapping defaults to empty.
    # name -> override value, for macros
    macros: Dict[str, Any] = field(default_factory=dict)
    # name -> override value, for vars
    vars: Dict[str, Any] = field(default_factory=dict)
    # name -> override value, for environment variables
    env_vars: Dict[str, Any] = field(default_factory=dict)
@dataclass
class UnparsedUnitTestCase(dbtClassMixin):
    # A single unit test case as written in YAML.
    name: str
    # Input fixtures for the test.
    given: Sequence[UnparsedInputFixture]
    # Expected output rows, one dict per row.
    expect: List[Dict[str, Any]]
    description: str = ""
    # Optional macros / vars / env_vars overrides for this case.
    overrides: Optional[UnparsedUnitTestOverrides] = None
@dataclass
class UnparsedUnitTestSuite(dbtClassMixin):
    # All unit test cases declared for one model.
    model: str  # name of the model being unit tested
    # The test cases for that model.
    tests: Sequence[UnparsedUnitTestCase]
#
# semantic interfaces unparsed objects
#

View File

@@ -2158,7 +2158,7 @@ class SQLCompiledPath(InfoLevel):
return "Z026"
def message(self) -> str:
return f" compiled Code at {self.path}"
return f" compiled code at {self.path}"
class CheckNodeTestFailure(InfoLevel):

View File

@@ -178,7 +178,7 @@ class DbtDatabaseError(DbtRuntimeError):
lines = []
if hasattr(self.node, "build_path") and self.node.build_path:
lines.append(f"compiled Code at {self.node.build_path}")
lines.append(f"compiled code at {self.node.build_path}")
return lines + DbtRuntimeError.process_stack(self)

View File

@@ -12,3 +12,31 @@
{{ "limit " ~ limit if limit != none }}
) dbt_internal_test
{%- endmacro %}
{#- Entry point: dispatches to the adapter-specific get_unit_test_sql implementation (default__get_unit_test_sql unless an adapter overrides it). -#}
{% macro get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
{{ adapter.dispatch('get_unit_test_sql', 'dbt')(main_sql, expected_fixture_sql, expected_column_names) }}
{%- endmacro %}
{#-
  Default implementation. Selects expected_column_names from main_sql and from
  expected_fixture_sql, tags every row with an actual_or_expected column holding
  'actual' or 'expected', and returns the "union all" of both result sets.
-#}
{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
-- Build actual result given inputs
with dbt_internal_unit_test_actual AS (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as actual_or_expected
from (
{{ main_sql }}
) _dbt_internal_unit_test_actual
),
-- Build expected result
dbt_internal_unit_test_expected AS (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as actual_or_expected
from (
{{ expected_fixture_sql }}
) _dbt_internal_unit_test_expected
)
-- Union actual and expected results
select * from dbt_internal_unit_test_actual
union all
select * from dbt_internal_unit_test_expected
{%- endmacro %}

View File

@@ -0,0 +1,29 @@
{%- materialization unit, default -%}
{% set relations = [] %}
{% set expected_rows = config.get('expected_rows') %}
{#- Compare only the columns named in the first expected row; with no expected rows, fall back to the columns of the model query. -#}
{% set tested_expected_column_names = expected_rows[0].keys() if (expected_rows | length ) > 0 else get_columns_in_query(sql) %}
{%- set target_relation = this.incorporate(type='table') -%}
{%- set temp_relation = make_temp_relation(target_relation)-%}
{#- Create a temporary relation from the empty-subquery form of the model SQL so the column data types can be read from it. -#}
{% do run_query(get_create_table_as_sql(True, temp_relation, get_empty_subquery_sql(sql))) %}
{%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %}
{#- The 'main' statement result is what the caller loads after the materialization has run. -#}
{% call statement('main', fetch_result=True) -%}
{{ unit_test_sql }}
{%- endcall %}
{% do adapter.drop_relation(temp_relation) %}
{{ return({'relations': relations}) }}
{%- endmaterialization -%}

View File

@@ -0,0 +1,77 @@
{#-
  Renders a select statement producing the fixture rows.
  rows: list of dicts (column name -> value); columns missing from a row are
  filled with a null cast to the column's data type.
  column_name_to_data_types: dict of column name -> data type; when empty or
  none, the types are read from the relation `this` points at.
  With no rows, a single all-null row with "limit 0" is rendered.
  Raises a compiler error when no column types can be determined.
-#}
{% macro get_fixture_sql(rows, column_name_to_data_types) %}
-- Fixture for {{ model.name }}
{% set default_row = {} %}
{%- if not column_name_to_data_types -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{%- endif -%}
{%- if not column_name_to_data_types -%}
{{ exceptions.raise_compiler_error("columns not available for " ~ model.name) }}
{%- endif -%}
{%- for column_name, column_type in column_name_to_data_types.items() -%}
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
{%- endfor -%}
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
select
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
{% endif %}
{%- endfor -%}
{%- if (rows | length) == 0 -%}
select
{%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- endfor %}
limit 0
{%- endif -%}
{% endmacro %}
{#-
  Renders a select statement producing the expected rows.
  rows: list of dicts (column name -> value), formatted in place by format_row.
  With no rows, renders an empty selection from dbt_internal_unit_test_actual
  (a name the surrounding query is expected to define).
-#}
{% macro get_expected_sql(rows, column_name_to_data_types) %}
{%- if (rows | length) == 0 -%}
select * FROM dbt_internal_unit_test_actual
limit 0
{%- else -%}
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
select
{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
{% endif %}
{%- endfor -%}
{%- endif -%}
{% endmacro %}
{#-
  Rewrites every value of `row` in place as a safe_cast() expression to the
  data type registered for its column in column_name_to_data_types:
  strings are first wrapped with dbt.string_literal, none becomes 'null',
  any other value is passed to safe_cast unchanged.
-#}
{%- macro format_row(row, column_name_to_data_types) -%}
{#-- wrap yaml strings in quotes, apply cast --#}
{%- for column_name, column_value in row.items() -%}
{% set row_update = {column_name: column_value} %}
{%- if column_value is string -%}
{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%}
{%- elif column_value is none -%}
{%- set row_update = {column_name: safe_cast('null', column_name_to_data_types[column_name]) } -%}
{%- else -%}
{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%}
{%- endif -%}
{%- do row.update(row_update) -%}
{%- endfor -%}
{%- endmacro -%}

View File

@@ -34,6 +34,7 @@ class NodeType(StrEnum):
Metric = "metric"
Group = "group"
SemanticModel = "semantic_model"
Unit = "unit_test"
@classmethod
def executable(cls) -> List["NodeType"]:
@@ -47,6 +48,7 @@ class NodeType(StrEnum):
cls.Documentation,
cls.RPCCall,
cls.SqlOperation,
cls.Unit,
]
@classmethod

View File

@@ -138,6 +138,11 @@ class SchemaParser(SimpleParser[YamlBlock, ModelNode]):
self.root_project, self.project.project_name, self.schema_yaml_vars
)
# This is unnecessary, but mypy was requiring it. Clean up parser code so
# we don't have to do this.
def parse_from_dict(self, dct):
    """Do nothing; ``dct`` is ignored and None is returned."""
    return None
@classmethod
def get_compiled_path(cls, block: FileBlock) -> str:
# should this raise an error?
@@ -297,7 +302,7 @@ class YamlReader(metaclass=ABCMeta):
if coerce_dict_str(entry) is None:
raise YamlParseListError(path, self.key, data, "expected a dict with string keys")
if "name" not in entry:
if "name" not in entry and "model" not in entry:
raise ParsingError("Entry did not contain a name")
# Render the data (except for tests and descriptions).

View File

@@ -0,0 +1,215 @@
from dbt.contracts.graph.unparsed import UnparsedUnitTestSuite
from dbt.contracts.graph.model_config import NodeConfig
from dbt_extractor import py_extract_from_source # type: ignore
from dbt.contracts.graph.nodes import (
ModelNode,
UnitTestNode,
RefArgs,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.parser.schemas import (
SchemaParser,
YamlBlock,
ValidationError,
JSONValidationError,
YamlParseDictError,
YamlReader,
)
from dbt.exceptions import (
ParsingError,
)
from dbt.parser.search import FileBlock
from dbt.contracts.files import FileHash, SchemaSourceFile
from dbt.node_types import NodeType
from dbt.context.providers import generate_parse_exposure, get_rendered
from typing import List
def _is_model_node(node_id, manifest):
    """Return whether the node stored under ``node_id`` in ``manifest.nodes`` is a model."""
    node = manifest.nodes[node_id]
    return node.resource_type == NodeType.Model
class UnitTestManifestLoader:
    """Builds a separate manifest holding the unit tests declared in schema files."""

    @classmethod
    def load(cls, manifest, root_project) -> Manifest:
        """Parse every ``unit`` YAML section of ``manifest`` into a new Manifest.

        The new manifest shares the macros of ``manifest``. After parsing, each unit
        test additionally depends on the unit tests of the models it refs.
        """
        unit_test_manifest = Manifest(macros=manifest.macros)
        for source_file in manifest.files.values():
            block = FileBlock(source_file)
            if not isinstance(source_file, SchemaSourceFile):
                continue
            dct = source_file.dict_from_yaml
            if "unit" not in dct:
                continue
            yaml_block = YamlBlock.from_file_block(block, dct)
            # TODO: first root_project should be project, or we should only parse unit tests from root_project
            schema_parser = SchemaParser(root_project, manifest, root_project)
            UnitTestParser(schema_parser, yaml_block, unit_test_manifest).parse()

        unit_test_nodes = [
            node for node in unit_test_manifest.nodes.values() if isinstance(node, UnitTestNode)
        ]

        # Group unit test ids by the text before the first "__" of the node name.
        # NOTE(review): a model whose own name contains "__" would be keyed by a
        # truncated name here - confirm whether that can occur.
        model_to_unit_tests = {}
        for node in unit_test_nodes:
            model_name = node.name.split("__")[0]
            model_to_unit_tests.setdefault(model_name, []).append(node.unique_id)

        # a unit test should depend on its fixture nodes, and any unit tests on its ref'd nodes
        for node in unit_test_nodes:
            for ref in node.refs:
                for unique_id in model_to_unit_tests.get(ref.name, []):
                    node.depends_on.nodes.append(unique_id)

        return unit_test_manifest
class UnitTestParser(YamlReader):
    """Parses the ``unit`` entries of one schema YAML block into unit test nodes.

    For every test case, one ephemeral fixture ModelNode is created per ``given``
    input and one UnitTestNode is created with the raw code of the tested model.
    All of them are stored in ``unit_test_manifest``.
    """

    def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, unit_test_manifest: Manifest):
        super().__init__(schema_parser, yaml, "unit")
        self.yaml = yaml
        self.unit_test_manifest = unit_test_manifest

    def parse_unit_test(self, unparsed: UnparsedUnitTestSuite):
        """Create fixture and unit test nodes for every test case in ``unparsed``.

        Nodes are written to ``self.unit_test_manifest.nodes``. Finally, any node a
        model in that manifest depends on, but which is missing from it, is copied
        over from ``self.manifest``.
        """
        package_name = self.project.project_name
        path = self.yaml.path.relative_path
        # TODO: fix
        checksum = "f8f57c9e32eafaacfb002a4d03a47ffb412178f58f49ba58fd6f436f09f8a1d6"

        for unit_test in unparsed.tests:
            input_nodes = []
            original_input_nodes = []
            for given in unit_test.given:
                original_input_node = self._get_original_input_node(given.input)
                original_input_nodes.append(original_input_node)

                # Column types are only taken from the node when it is a model with an
                # enforced contract; otherwise None is passed to get_fixture_sql.
                original_input_node_columns = None
                if (
                    original_input_node.resource_type == NodeType.Model
                    and original_input_node.config.contract.enforced
                ):
                    # NOTE(review): assumes iterating `columns` yields objects with
                    # .name / .data_type - confirm against the node definition.
                    original_input_node_columns = {
                        column.name: column.data_type for column in original_input_node.columns
                    }

                # TODO: package_name?
                input_name = f"{unparsed.model}__{unit_test.name}__{original_input_node.name}"
                input_unique_id = f"model.{package_name}.{input_name}"

                # The fixture takes the database/schema/alias of the node it stands in for.
                input_node = ModelNode(
                    raw_code=self._build_raw_code(given.rows, original_input_node_columns),
                    resource_type=NodeType.Model,
                    package_name=package_name,
                    path=path,
                    # original_file_path=self.yaml.path.original_file_path,
                    original_file_path=f"models_unit_test/{input_name}.sql",
                    unique_id=input_unique_id,
                    name=input_name,
                    config=NodeConfig(materialized="ephemeral"),
                    database=original_input_node.database,
                    schema=original_input_node.schema,
                    alias=original_input_node.alias,
                    fqn=input_unique_id.split("."),
                    checksum=FileHash(name="sha256", checksum=checksum),
                )
                input_nodes.append(input_node)

            actual_node = self.manifest.ref_lookup.perform_lookup(
                f"model.{package_name}.{unparsed.model}", self.manifest
            )

            unit_test_unique_id = f"unit.{package_name}.{unit_test.name}.{unparsed.model}"
            unit_test_node = UnitTestNode(
                resource_type=NodeType.Unit,
                package_name=package_name,
                path=f"{unparsed.model}.sql",
                # original_file_path=self.yaml.path.original_file_path,
                original_file_path=f"models_unit_test/{unparsed.model}.sql",
                unique_id=unit_test_unique_id,
                name=f"{unparsed.model}__{unit_test.name}",
                # TODO: merge with node config
                config=NodeConfig(materialized="unit", _extra={"expected_rows": unit_test.expect}),
                raw_code=actual_node.raw_code,
                database=actual_node.database,
                schema=actual_node.schema,
                alias=f"{unparsed.model}__{unit_test.name}",
                fqn=unit_test_unique_id.split("."),
                checksum=FileHash(name="sha256", checksum=checksum),
                attached_node=actual_node.unique_id,
                overrides=unit_test.overrides,
            )

            # TODO: generalize this method
            ctx = generate_parse_exposure(
                unit_test_node,  # type: ignore
                self.root_project,
                self.schema_parser.manifest,
                package_name,
            )
            get_rendered(unit_test_node.raw_code, ctx, unit_test_node, capture_macros=True)
            # unit_test_node now has a populated refs/sources

            # during compilation, refs will resolve to fixtures,
            # so add original input node ids to depends on explicitly to preserve lineage
            for original_input_node in original_input_nodes:
                # TODO: consider nulling out the original_input_node.raw_code
                self.unit_test_manifest.nodes[original_input_node.unique_id] = original_input_node
                unit_test_node.depends_on.nodes.append(original_input_node.unique_id)

            self.unit_test_manifest.nodes[unit_test_node.unique_id] = unit_test_node
            # self.unit_test_manifest.nodes[actual_node.unique_id] = actual_node
            for input_node in input_nodes:
                self.unit_test_manifest.nodes[input_node.unique_id] = input_node
                # should be a process_refs / process_sources call instead?
                unit_test_node.depends_on.nodes.append(input_node.unique_id)

        # find out all nodes that are referenced but not in unittest manifest
        all_depends_on = set()
        for node_id in self.unit_test_manifest.nodes:
            if _is_model_node(node_id, self.unit_test_manifest):
                all_depends_on.update(self.unit_test_manifest.nodes[node_id].depends_on.nodes)  # type: ignore
        not_in_manifest = all_depends_on - set(self.unit_test_manifest.nodes.keys())

        # copy those node also over into unit_test_manifest
        for node_id in not_in_manifest:
            self.unit_test_manifest.nodes[node_id] = self.manifest.nodes[node_id]

    def parse(self):
        """Validate and parse every entry under this reader's key.

        Raises:
            YamlParseDictError: if an entry fails UnparsedUnitTestSuite validation.
        """
        for data in self.get_key_dicts():
            try:
                UnparsedUnitTestSuite.validate(data)
                unparsed = UnparsedUnitTestSuite.from_dict(data)
            except (ValidationError, JSONValidationError) as exc:
                raise YamlParseDictError(self.yaml.path, self.key, data, exc)
            self.parse_unit_test(unparsed)

    def _build_raw_code(self, rows, column_name_to_data_types) -> str:
        """Return the Jinja call to get_fixture_sql with both arguments inlined via str.format."""
        # The quadrupled braces render as literal "{{" and "}}" after str.format.
        return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
            rows=rows, column_name_to_data_types=column_name_to_data_types
        )

    def _get_original_input_node(self, input: str):
        """Resolve a ``given.input`` expression to the node it refers to.

        ``input`` is statically parsed as a Jinja expression. If it contains refs, the
        node found for the last ref is returned; otherwise the first source is used.

        Raises:
            ParsingError: if ``input`` contains neither a ref nor a source, or if the
                lookup does not return a node.
        """
        statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}")
        if statically_parsed["refs"]:
            for ref in statically_parsed["refs"]:
                name = ref.get("name")
                package = ref.get("package")
                version = ref.get("version")
                # TODO: disabled lookup, versioned lookup, public models
                original_input_node = self.manifest.ref_lookup.find(
                    name, package, version, self.manifest
                )
        elif statically_parsed["sources"]:
            input_package_name, input_source_name = statically_parsed["sources"][0]
            original_input_node = self.manifest.source_lookup.find(
                input_source_name, input_package_name, self.manifest
            )
        else:
            raise ParsingError("given input must be ref or source")

        # Fail with a parsing error here rather than with an AttributeError on None
        # when the caller reads attributes of the returned node.
        if original_input_node is None:
            raise ParsingError(f"Unable to find the node for unit test input '{input}'")

        return original_input_node

View File

@@ -17,6 +17,7 @@ from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.task.unit_test import UnitTestTask
RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
@@ -30,6 +31,7 @@ TASK_DICT = {
"test": TestTask,
"run": RunTask,
"run-operation": RunOperationTask,
"unit-test": UnitTestTask,
}
CMD_DICT = {
@@ -42,6 +44,7 @@ CMD_DICT = {
"test": CliCommand.TEST,
"run": CliCommand.RUN,
"run-operation": CliCommand.RUN_OPERATION,
"unit-test": CliCommand.UNIT_TEST,
}

192
core/dbt/task/unit_test.py Normal file
View File

@@ -0,0 +1,192 @@
from dataclasses import dataclass
from dbt.dataclass_schema import dbtClassMixin
import threading
from typing import Dict, Any, Optional
import io
from .compile import CompileRunner
from .run import RunTask
from dbt.contracts.graph.nodes import UnitTestNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import TestStatus, RunResult
from dbt.context.providers import generate_runtime_model_context
from dbt.clients.jinja import MacroGenerator
from dbt.events.functions import fire_event
from dbt.events.types import (
LogTestResult,
LogStartLine,
)
from dbt.graph import ResourceTypeSelector
from dbt.exceptions import (
DbtInternalError,
MissingMaterializationError,
)
from dbt.node_types import NodeType
@dataclass
class UnitTestResultData(dbtClassMixin):
    """Outcome of executing one unit test, as produced by ``execute_unit_test``."""

    # True when the "actual" rows differ from the "expected" rows.
    should_error: bool
    # The adapter's response for the executed materialization, as a plain dict.
    adapter_response: Dict[str, Any]
    # Rendered "Actual / Expected" text; left as None when the rows match.
    diff: Optional[str] = None
class UnitTestRunner(CompileRunner):
    """Runner that executes a single unit test node and reports pass/fail.

    The node's materialization macro is executed, the resulting table is split
    into its "actual" and "expected" halves (via the ``actual_or_expected``
    column), and the test fails when the two halves differ.
    """

    def describe_node(self):
        """Return the label shown in log lines: ``"<resource_type> <name>"``."""
        return f"{self.node.resource_type} {self.node.name}"

    def print_result_line(self, result):
        """Fire a ``LogTestResult`` event for ``result``.

        The log level is derived from the result status through
        ``LogTestResult.status_to_level``.
        """
        model = result.node
        fire_event(
            LogTestResult(
                name=model.name,
                status=str(result.status),
                index=self.node_index,
                num_models=self.num_nodes,
                execution_time=result.execution_time,
                node_info=model.node_info,
                num_failures=result.failures,
            ),
            level=LogTestResult.status_to_level(str(result.status)),
        )

    def print_start_line(self):
        """Fire a ``LogStartLine`` event announcing that this node is starting."""
        fire_event(
            LogStartLine(
                description=self.describe_node(),
                index=self.node_index,
                total=self.num_nodes,
                node_info=self.node.node_info,
            )
        )

    def before_execute(self):
        """Hook run before ``execute``; only logs the start line."""
        self.print_start_line()

    def execute_unit_test(self, node: UnitTestNode, manifest: Manifest) -> UnitTestResultData:
        """Run the node's materialization and compare actual vs. expected rows.

        Args:
            node: the unit test node to execute.
            manifest: manifest used to build the context and to look up the
                materialization macro.

        Returns:
            A ``UnitTestResultData`` whose ``should_error`` is True when the
            actual rows differ from the expected rows; ``diff`` is only
            populated in that case.

        Raises:
            MissingMaterializationError: no materialization macro was found for
                the node's materialization and the current adapter type.
            DbtInternalError: the generated context has no ``"config"`` entry.
        """
        # generate_runtime_unit_test_context not strictly needed - this is to run the 'unit' materialization, not compile the node.compiled_code
        context = generate_runtime_model_context(node, self.config, manifest)

        materialization_macro = manifest.find_materialization_macro_by_name(
            self.config.project_name, node.get_materialization(), self.adapter.type()
        )

        if materialization_macro is None:
            raise MissingMaterializationError(
                materialization=node.get_materialization(), adapter_type=self.adapter.type()
            )

        if "config" not in context:
            raise DbtInternalError(
                "Invalid materialization context generated, missing config: {}".format(context)
            )

        # generate materialization macro
        macro_func = MacroGenerator(materialization_macro, context)

        # execute materialization macro
        macro_func()

        # load results from context
        # could eventually be returned directly by materialization
        result = context["load_result"]("main")
        adapter_response = result["response"].to_dict(omit_none=True)
        table = result["table"]

        # The result table carries both halves, tagged by "actual_or_expected".
        actual = self._get_unit_test_table(table, "actual")
        expected = self._get_unit_test_table(table, "expected")
        should_error = actual.rows != expected.rows

        diff = None
        if should_error:
            actual_output = self._agate_table_to_str(actual)
            expected_output = self._agate_table_to_str(expected)
            diff = f"\n\nActual:\n{actual_output}\n\nExpected:\n{expected_output}\n"

        return UnitTestResultData(
            diff=diff,
            should_error=should_error,
            adapter_response=adapter_response,
        )

    def execute(self, node: UnitTestNode, manifest: Manifest):
        """Execute the unit test and wrap the outcome in a ``RunResult``.

        A mismatch yields ``TestStatus.Fail`` with the diff as the message and
        ``failures=1``; otherwise ``TestStatus.Pass`` with no message and
        ``failures=0``.
        """
        result = self.execute_unit_test(node, manifest)
        thread_id = threading.current_thread().name

        status = TestStatus.Pass
        message = None
        failures = 0
        if result.should_error:
            status = TestStatus.Fail
            message = result.diff
            failures = 1

        # NOTE(review): timing and execution_time are placeholders here;
        # presumably the base runner fills them in afterwards - confirm.
        return RunResult(
            node=node,
            status=status,
            timing=[],
            thread_id=thread_id,
            execution_time=0,
            message=message,
            adapter_response=result.adapter_response,
            failures=failures,
        )

    def after_execute(self, result):
        """Hook run after ``execute``; only logs the result line."""
        self.print_result_line(result)

    def _get_unit_test_table(self, result_table, actual_or_expected: str):
        """Return the rows tagged ``actual_or_expected``, without the tag column.

        Args:
            result_table: table containing an ``actual_or_expected`` column.
            actual_or_expected: tag value to keep (called with ``"actual"`` and
                ``"expected"``).
        """
        unit_test_table = result_table.where(
            lambda row: row["actual_or_expected"] == actual_or_expected
        )
        columns = list(unit_test_table.columns.keys())
        # Drop the tag column so that it does not take part in the comparison.
        columns.remove("actual_or_expected")
        return unit_test_table.select(columns)

    def _agate_table_to_str(self, table) -> str:
        """Render ``table`` as a string: JSON when the ``output`` arg is
        ``"json"``, otherwise a printed text table with no row limit."""
        # Hack to get Agate table output as string
        output = io.StringIO()
        # The args object may not define `output` for this command; treat a
        # missing attribute like any non-"json" value instead of raising
        # AttributeError while building the failure diff.
        output_format = getattr(self.config.args, "output", None)
        if output_format == "json":
            table.to_json(path=output)
        else:
            table.print_table(output=output, max_rows=None)
        return output.getvalue().strip()
class UnitTestSelector(ResourceTypeSelector):
    """Resource-type selector that is limited to ``NodeType.Unit`` nodes."""

    def __init__(self, graph, manifest, previous_state):
        # Only unit test nodes are eligible for selection.
        unit_only = [NodeType.Unit]
        super().__init__(
            graph=graph,
            manifest=manifest,
            previous_state=previous_state,
            resource_types=unit_only,
        )
class UnitTestTask(RunTask):
    """
    Unit testing:
        Run the selected unit test nodes with ``UnitTestRunner``. Each test
        passes when the rows it produces match the rows it expects.
    """

    # Tell pytest not to collect this class as a test case.
    __test__ = False

    def __init__(self, args, config, manifest, collection):
        """Set up the task.

        Args:
            args: parsed command arguments, forwarded to ``RunTask``.
            config: runtime configuration, forwarded to ``RunTask``.
            manifest: the project manifest; kept as ``original_manifest``.
            collection: the unit test manifest; this is what ``RunTask`` is
                given as its manifest.
        """
        # This will initialize the RunTask with the unit test manifest ("collection") as the manifest
        super().__init__(args, config, collection)
        self.collection = collection
        self.original_manifest = manifest

    def raise_on_first_error(self):
        """Always return False.

        NOTE(review): presumably this keeps one failing unit test from
        aborting the remaining ones - confirm against the base task.
        """
        return False

    def get_node_selector(self) -> UnitTestSelector:
        """Return a ``UnitTestSelector`` over this task's graph and manifest.

        Raises:
            DbtInternalError: the manifest or the graph has not been set yet.
        """
        if self.manifest is None or self.graph is None:
            raise DbtInternalError("manifest and graph must be set to perform node selection")
        return UnitTestSelector(
            graph=self.graph,
            manifest=self.manifest,
            previous_state=self.previous_state,
        )

    def get_runner_type(self, _):
        """Return the runner class for every node: ``UnitTestRunner``."""
        return UnitTestRunner

View File

@@ -0,0 +1,122 @@
import pytest
from dbt.tests.util import run_dbt
# Model under test: joins my_model_a and my_model_b on id and exposes an
# arithmetic column (c), a string concatenation (string_c), pass-through
# columns, and string literals built from a macro call, a var, an env_var and
# invocation_id.
my_model_sql = """
SELECT
a+b as c,
concat(string_a, string_b) as string_c,
not_testing, date_a,
{{ dbt.string_literal(type_numeric()) }} as macro_call,
{{ dbt.string_literal(var('my_test')) }} as var_call,
{{ dbt.string_literal(env_var('MY_TEST', 'default')) }} as env_var_call,
{{ dbt.string_literal(invocation_id) }} as invocation_id
FROM {{ ref('my_model_a')}} my_model_a
JOIN {{ ref('my_model_b' )}} my_model_b
ON my_model_a.id = my_model_b.id
"""
# First input model referenced by my_model: a single constant row.
my_model_a_sql = """
SELECT
1 as a,
1 as id,
2 as not_testing,
'a' as string_a,
DATE '2020-01-02' as date_a
"""
# Second input model referenced by my_model: a single constant row.
my_model_b_sql = """
SELECT
2 as b,
1 as id,
2 as c,
'b' as string_b
"""
# Unit test definitions for my_model. Five tests are declared: test_my_model,
# test_my_model_empty, test_my_model_overrides, test_my_model_string_concat
# and test_my_model_datetime.
# NOTE(review): test_my_model gives a=1 and b=2 but expects c=2; presumably
# this is the deliberately failing case behind expect_pass=False - confirm.
# NOTE(review): YAML nesting depends on indentation, which is not visible in
# this view of the file - confirm the literal against the original source.
test_my_model_yml = """
unit:
- model: my_model
tests:
- name: test_my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, a: 1}
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
expect:
- {c: 2}
- name: test_my_model_empty
given:
- input: ref('my_model_a')
rows: []
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
expect: []
- name: test_my_model_overrides
given:
- input: ref('my_model_a')
rows:
- {id: 1, a: 1}
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
overrides:
macros:
type_numeric: override
invocation_id: 123
vars:
my_test: var_override
env_vars:
MY_TEST: env_var_override
expect:
- {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123}
- name: test_my_model_string_concat
given:
- input: ref('my_model_a')
rows:
- {id: 1, string_a: a}
- input: ref('my_model_b')
rows:
- {id: 1, string_b: b}
expect:
- {string_c: ab}
- name: test_my_model_datetime
given:
- input: ref('my_model_a')
rows:
- {id: 1, date_a: "2020-01-01"}
- input: ref('my_model_b')
rows:
- {id: 1}
expect:
- {date_a: "2020-01-01"}
"""
class TestUnitTests:
    """Functional test: build three models, then run their five unit tests."""

    @pytest.fixture(scope="class")
    def models(self):
        # File name -> file contents for the project's models directory.
        model_files = {}
        model_files["my_model.sql"] = my_model_sql
        model_files["my_model_a.sql"] = my_model_a_sql
        model_files["my_model_b.sql"] = my_model_b_sql
        model_files["test_my_model.yml"] = test_my_model_yml
        return model_files

    @pytest.fixture(scope="class")
    def project_config_update(self):
        # Project-level value for the `my_test` var read by my_model.
        project_vars = {"my_test": "my_test_var"}
        return {"vars": project_vars}

    def test_basic(self, project):
        run_dbt(["deps"])

        # All three models should build.
        run_results = run_dbt(["run"])
        assert len(run_results) == 3

        # The unit-test invocation is not expected to pass overall, but it
        # must still return one result per declared unit test.
        unit_test_args = ["unit-test", "--select", "my_model"]
        unit_results = run_dbt(unit_test_args, expect_pass=False)
        assert len(unit_results) == 5

View File

@@ -17,6 +17,7 @@ node_type_pluralizations = {
NodeType.Metric: "metrics",
NodeType.Group: "groups",
NodeType.SemanticModel: "semantic_models",
NodeType.Unit: "unit_tests",
}