mirror of
https://github.com/dbt-labs/dbt-core
synced 2025-12-18 22:41:27 +00:00
Compare commits
12 Commits
update-ind
...
ct-2911-in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed8d9dbf8f | ||
|
|
0b31227c56 | ||
|
|
b9b4661f4d | ||
|
|
7e4bf98461 | ||
|
|
b14c3a09a7 | ||
|
|
a8a2331a93 | ||
|
|
68968a74d7 | ||
|
|
31af2b9979 | ||
|
|
4a8317d974 | ||
|
|
a2f197851c | ||
|
|
c578971be4 | ||
|
|
49179cb7fb |
6
.changes/unreleased/Features-20230802-145011.yaml
Normal file
6
.changes/unreleased/Features-20230802-145011.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
kind: Features
|
||||
body: Initial implementation of unit testing
|
||||
time: 2023-08-02T14:50:11.391992-04:00
|
||||
custom:
|
||||
Author: gshank
|
||||
Issue: "8287"
|
||||
@@ -12,7 +12,7 @@ class RelationConfigChangeAction(StrEnum):
|
||||
drop = "drop"
|
||||
|
||||
|
||||
@dataclass(frozen=True, eq=True, unsafe_hash=True)
|
||||
@dataclass(frozen=True, eq=True, unsafe_hash=True) # type: ignore
|
||||
class RelationConfigChange(RelationConfigBase, ABC):
|
||||
action: RelationConfigChangeAction
|
||||
context: Hashable # this is usually a RelationConfig, e.g. IndexConfig, but shouldn't be limited
|
||||
|
||||
@@ -392,6 +392,7 @@ def command_args(command: CliCommand) -> ArgsList:
|
||||
CliCommand.SOURCE_FRESHNESS: cli.freshness,
|
||||
CliCommand.TEST: cli.test,
|
||||
CliCommand.RETRY: cli.retry,
|
||||
CliCommand.UNIT_TEST: cli.unit_test,
|
||||
}
|
||||
click_cmd: Optional[ClickCommand] = CMD_DICT.get(command, None)
|
||||
if click_cmd is None:
|
||||
|
||||
@@ -39,6 +39,7 @@ from dbt.task.serve import ServeTask
|
||||
from dbt.task.show import ShowTask
|
||||
from dbt.task.snapshot import SnapshotTask
|
||||
from dbt.task.test import TestTask
|
||||
from dbt.task.unit_test import UnitTestTask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -846,6 +847,52 @@ def test(ctx, **kwargs):
|
||||
return results, success
|
||||
|
||||
|
||||
# dbt unit-test
|
||||
@cli.command("unit-test")
|
||||
@click.pass_context
|
||||
@p.defer
|
||||
@p.deprecated_defer
|
||||
@p.exclude
|
||||
@p.fail_fast
|
||||
@p.favor_state
|
||||
@p.deprecated_favor_state
|
||||
@p.indirect_selection
|
||||
@p.show_output_format
|
||||
@p.profile
|
||||
@p.profiles_dir
|
||||
@p.project_dir
|
||||
@p.select
|
||||
@p.selector
|
||||
@p.state
|
||||
@p.defer_state
|
||||
@p.deprecated_state
|
||||
@p.store_failures
|
||||
@p.target
|
||||
@p.target_path
|
||||
@p.threads
|
||||
@p.vars
|
||||
@p.version_check
|
||||
@requires.postflight
|
||||
@requires.preflight
|
||||
@requires.profile
|
||||
@requires.project
|
||||
@requires.runtime_config
|
||||
@requires.manifest
|
||||
@requires.unit_test_collection
|
||||
def unit_test(ctx, **kwargs):
|
||||
"""Runs tests on data in deployed models. Run this after `dbt run`"""
|
||||
task = UnitTestTask(
|
||||
ctx.obj["flags"],
|
||||
ctx.obj["runtime_config"],
|
||||
ctx.obj["manifest"],
|
||||
ctx.obj["unit_test_collection"],
|
||||
)
|
||||
|
||||
results = task.run()
|
||||
success = task.interpret_results(results)
|
||||
return results, success
|
||||
|
||||
|
||||
# Support running as a module
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@@ -23,6 +23,7 @@ from dbt.parser.manifest import ManifestLoader, write_manifest
|
||||
from dbt.profiler import profiler
|
||||
from dbt.tracking import active_user, initialize_from_flags, track_run
|
||||
from dbt.utils import cast_dict_to_dict_of_strings
|
||||
from dbt.parser.unit_tests import UnitTestManifestLoader
|
||||
from dbt.plugins import set_up_plugin_manager, get_plugin_manager
|
||||
|
||||
from click import Context
|
||||
@@ -265,3 +266,25 @@ def manifest(*args0, write=True, write_perf_info=False):
|
||||
if len(args0) == 0:
|
||||
return outer_wrapper
|
||||
return outer_wrapper(args0[0])
|
||||
|
||||
|
||||
def unit_test_collection(func):
|
||||
"""A decorator used by click command functions for generating a unit test collection provided a manifest"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
ctx = args[0]
|
||||
assert isinstance(ctx, Context)
|
||||
|
||||
req_strs = ["manifest", "runtime_config"]
|
||||
reqs = [ctx.obj.get(req_str) for req_str in req_strs]
|
||||
|
||||
if None in reqs:
|
||||
raise DbtProjectError("manifest and runtime_config required for unit_test_collection")
|
||||
|
||||
collection = UnitTestManifestLoader.load(ctx.obj["manifest"], ctx.obj["runtime_config"])
|
||||
|
||||
ctx.obj["unit_test_collection"] = collection
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return update_wrapper(wrapper, func)
|
||||
|
||||
@@ -24,6 +24,7 @@ class Command(Enum):
|
||||
SOURCE_FRESHNESS = "freshness"
|
||||
TEST = "test"
|
||||
RETRY = "retry"
|
||||
UNIT_TEST = "unit-test"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, s: str) -> "Command":
|
||||
|
||||
@@ -330,6 +330,26 @@ class MacroGenerator(BaseMacroGenerator):
|
||||
return self.call_macro(*args, **kwargs)
|
||||
|
||||
|
||||
class UnitTestMacroGenerator(MacroGenerator):
|
||||
# this makes UnitTestMacroGenerator objects callable like functions
|
||||
def __init__(
|
||||
self,
|
||||
macro_generator: MacroGenerator,
|
||||
call_return_value: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
macro_generator.macro,
|
||||
macro_generator.context,
|
||||
macro_generator.node,
|
||||
macro_generator.stack,
|
||||
)
|
||||
self.call_return_value = call_return_value
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
with self.track_call():
|
||||
return self.call_return_value
|
||||
|
||||
|
||||
class QueryStringGenerator(BaseMacroGenerator):
|
||||
def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
@@ -12,7 +12,10 @@ from dbt.flags import get_flags
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.clients import jinja
|
||||
from dbt.clients.system import make_directory
|
||||
from dbt.context.providers import generate_runtime_model_context
|
||||
from dbt.context.providers import (
|
||||
generate_runtime_model_context,
|
||||
generate_runtime_unit_test_context,
|
||||
)
|
||||
from dbt.contracts.graph.manifest import Manifest, UniqueID
|
||||
from dbt.contracts.graph.nodes import (
|
||||
ManifestNode,
|
||||
@@ -21,6 +24,7 @@ from dbt.contracts.graph.nodes import (
|
||||
GraphMemberNode,
|
||||
InjectedCTE,
|
||||
SeedNode,
|
||||
UnitTestNode,
|
||||
)
|
||||
from dbt.exceptions import (
|
||||
GraphDependencyNotFoundError,
|
||||
@@ -44,6 +48,7 @@ def print_compile_stats(stats):
|
||||
names = {
|
||||
NodeType.Model: "model",
|
||||
NodeType.Test: "test",
|
||||
NodeType.Unit: "unit test",
|
||||
NodeType.Snapshot: "snapshot",
|
||||
NodeType.Analysis: "analysis",
|
||||
NodeType.Macro: "macro",
|
||||
@@ -289,8 +294,10 @@ class Compiler:
|
||||
manifest: Manifest,
|
||||
extra_context: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
context = generate_runtime_model_context(node, self.config, manifest)
|
||||
if isinstance(node, UnitTestNode):
|
||||
context = generate_runtime_unit_test_context(node, self.config, manifest)
|
||||
else:
|
||||
context = generate_runtime_model_context(node, self.config, manifest)
|
||||
context.update(extra_context)
|
||||
|
||||
if isinstance(node, GenericTestNode):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from copy import deepcopy
|
||||
import os
|
||||
from typing import (
|
||||
Callable,
|
||||
@@ -17,7 +18,7 @@ from typing_extensions import Protocol
|
||||
from dbt.adapters.base.column import Column
|
||||
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
|
||||
from dbt.clients import agate_helper
|
||||
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
|
||||
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
|
||||
from dbt.config import RuntimeConfig, Project
|
||||
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
|
||||
from dbt.context.base import contextmember, contextproperty, Var
|
||||
@@ -39,6 +40,7 @@ from dbt.contracts.graph.nodes import (
|
||||
RefArgs,
|
||||
AccessType,
|
||||
SemanticModel,
|
||||
UnitTestNode,
|
||||
)
|
||||
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
|
||||
from dbt.contracts.graph.unparsed import NodeVersion
|
||||
@@ -566,6 +568,17 @@ class OperationRefResolver(RuntimeRefResolver):
|
||||
return super().create_relation(target_model)
|
||||
|
||||
|
||||
class RuntimeUnitTestRefResolver(RuntimeRefResolver):
|
||||
def resolve(
|
||||
self,
|
||||
target_name: str,
|
||||
target_package: Optional[str] = None,
|
||||
target_version: Optional[NodeVersion] = None,
|
||||
) -> RelationProxy:
|
||||
target_name = f"{self.model.name}__{target_name}"
|
||||
return super().resolve(target_name, target_package, target_version)
|
||||
|
||||
|
||||
# `source` implementations
|
||||
class ParseSourceResolver(BaseSourceResolver):
|
||||
def resolve(self, source_name: str, table_name: str):
|
||||
@@ -670,6 +683,22 @@ class RuntimeVar(ModelConfiguredVar):
|
||||
pass
|
||||
|
||||
|
||||
class UnitTestVar(RuntimeVar):
|
||||
def __init__(
|
||||
self,
|
||||
context: Dict[str, Any],
|
||||
config: RuntimeConfig,
|
||||
node: Resource,
|
||||
) -> None:
|
||||
config_copy = None
|
||||
assert isinstance(node, UnitTestNode)
|
||||
if node.overrides and node.overrides.vars:
|
||||
config_copy = deepcopy(config)
|
||||
config_copy.cli_vars.update(node.overrides.vars)
|
||||
|
||||
super().__init__(context, config_copy or config, node=node)
|
||||
|
||||
|
||||
# Providers
|
||||
class Provider(Protocol):
|
||||
execute: bool
|
||||
@@ -711,6 +740,16 @@ class RuntimeProvider(Provider):
|
||||
metric = RuntimeMetricResolver
|
||||
|
||||
|
||||
class RuntimeUnitTestProvider(Provider):
|
||||
execute = True
|
||||
Config = RuntimeConfigObject
|
||||
DatabaseWrapper = RuntimeDatabaseWrapper
|
||||
Var = UnitTestVar
|
||||
ref = RuntimeUnitTestRefResolver
|
||||
source = RuntimeSourceResolver # TODO: RuntimeUnitTestSourceResolver
|
||||
metric = RuntimeMetricResolver
|
||||
|
||||
|
||||
class OperationProvider(RuntimeProvider):
|
||||
ref = OperationRefResolver
|
||||
|
||||
@@ -1360,7 +1399,7 @@ class ModelContext(ProviderContext):
|
||||
|
||||
@contextproperty
|
||||
def pre_hooks(self) -> List[Dict[str, Any]]:
|
||||
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
|
||||
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
|
||||
return []
|
||||
# TODO CT-211
|
||||
return [
|
||||
@@ -1369,7 +1408,7 @@ class ModelContext(ProviderContext):
|
||||
|
||||
@contextproperty
|
||||
def post_hooks(self) -> List[Dict[str, Any]]:
|
||||
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
|
||||
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
|
||||
return []
|
||||
# TODO CT-211
|
||||
return [
|
||||
@@ -1462,6 +1501,25 @@ class ModelContext(ProviderContext):
|
||||
return None
|
||||
|
||||
|
||||
class UnitTestContext(ModelContext):
|
||||
model: UnitTestNode
|
||||
|
||||
@contextmember
|
||||
def env_var(self, var: str, default: Optional[str] = None) -> str:
|
||||
"""The env_var() function. Return the overriden unit test environment variable named 'var'.
|
||||
|
||||
If there is no unit test override, return the environment variable named 'var'.
|
||||
|
||||
If there is no such environment variable set, return the default.
|
||||
|
||||
If the default is None, raise an exception for an undefined variable.
|
||||
"""
|
||||
if self.model.overrides and var in self.model.overrides.env_vars:
|
||||
return self.model.overrides.env_vars[var]
|
||||
else:
|
||||
return super().env_var(var, default)
|
||||
|
||||
|
||||
# This is called by '_context_for', used in 'render_with_context'
|
||||
def generate_parser_model_context(
|
||||
model: ManifestNode,
|
||||
@@ -1506,6 +1564,24 @@ def generate_runtime_macro_context(
|
||||
return ctx.to_dict()
|
||||
|
||||
|
||||
def generate_runtime_unit_test_context(
|
||||
unit_test: UnitTestNode,
|
||||
config: RuntimeConfig,
|
||||
manifest: Manifest,
|
||||
) -> Dict[str, Any]:
|
||||
ctx = UnitTestContext(unit_test, config, manifest, RuntimeUnitTestProvider(), None)
|
||||
ctx_dict = ctx.to_dict()
|
||||
|
||||
if unit_test.overrides and unit_test.overrides.macros:
|
||||
for macro_name, macro_value in unit_test.overrides.macros.items():
|
||||
context_value = ctx_dict.get(macro_name)
|
||||
if isinstance(context_value, MacroGenerator):
|
||||
ctx_dict[macro_name] = UnitTestMacroGenerator(context_value, macro_value)
|
||||
else:
|
||||
ctx_dict[macro_name] = macro_value
|
||||
return ctx_dict
|
||||
|
||||
|
||||
class ExposureRefResolver(BaseResolver):
|
||||
def __call__(self, *args, **kwargs) -> str:
|
||||
package = None
|
||||
|
||||
@@ -665,6 +665,7 @@ RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
|
||||
NodeType.Source: SourceConfig,
|
||||
NodeType.Seed: SeedConfig,
|
||||
NodeType.Test: TestConfig,
|
||||
NodeType.Unit: TestConfig,
|
||||
NodeType.Model: NodeConfig,
|
||||
NodeType.Snapshot: SnapshotConfig,
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ from dbt.contracts.graph.unparsed import (
|
||||
UnparsedSourceDefinition,
|
||||
UnparsedSourceTableDefinition,
|
||||
UnparsedColumn,
|
||||
UnparsedUnitTestOverrides,
|
||||
)
|
||||
from dbt.contracts.graph.node_args import ModelNodeArgs
|
||||
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
|
||||
@@ -983,6 +984,13 @@ class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
|
||||
return "generic"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnitTestNode(CompiledNode):
|
||||
resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]})
|
||||
attached_node: Optional[str] = None
|
||||
overrides: Optional[UnparsedUnitTestOverrides] = None
|
||||
|
||||
|
||||
# ====================================
|
||||
# Snapshot node
|
||||
# ====================================
|
||||
@@ -1239,6 +1247,10 @@ class SourceDefinition(NodeInfoMixin, ParsedSourceMandatory):
|
||||
def search_name(self):
|
||||
return f"{self.source_name}.{self.name}"
|
||||
|
||||
@property
|
||||
def group(self):
|
||||
return None
|
||||
|
||||
|
||||
# ====================================
|
||||
# Exposure node
|
||||
@@ -1602,6 +1614,10 @@ class SemanticModel(GraphNode):
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def group(self):
|
||||
return None
|
||||
|
||||
|
||||
# ====================================
|
||||
# Patches
|
||||
@@ -1651,6 +1667,7 @@ ManifestSQLNode = Union[
|
||||
SqlNode,
|
||||
GenericTestNode,
|
||||
SnapshotNode,
|
||||
UnitTestNode,
|
||||
]
|
||||
|
||||
# All SQL nodes plus SeedNode (csv files)
|
||||
@@ -1680,7 +1697,4 @@ Resource = Union[
|
||||
Group,
|
||||
]
|
||||
|
||||
TestNode = Union[
|
||||
SingularTestNode,
|
||||
GenericTestNode,
|
||||
]
|
||||
TestNode = Union[SingularTestNode, GenericTestNode]
|
||||
|
||||
@@ -671,6 +671,34 @@ class UnparsedGroup(dbtClassMixin, Replaceable):
|
||||
raise ValidationError("Group owner must have at least one of 'name' or 'email'.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedInputFixture(dbtClassMixin):
|
||||
input: str
|
||||
rows: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedUnitTestOverrides(dbtClassMixin):
|
||||
macros: Dict[str, Any] = field(default_factory=dict)
|
||||
vars: Dict[str, Any] = field(default_factory=dict)
|
||||
env_vars: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedUnitTestCase(dbtClassMixin):
|
||||
name: str
|
||||
given: Sequence[UnparsedInputFixture]
|
||||
expect: List[Dict[str, Any]]
|
||||
description: str = ""
|
||||
overrides: Optional[UnparsedUnitTestOverrides] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnparsedUnitTestSuite(dbtClassMixin):
|
||||
model: str # name of the model being unit tested
|
||||
tests: Sequence[UnparsedUnitTestCase]
|
||||
|
||||
|
||||
#
|
||||
# semantic interfaces unparsed objects
|
||||
#
|
||||
|
||||
@@ -2158,7 +2158,7 @@ class SQLCompiledPath(InfoLevel):
|
||||
return "Z026"
|
||||
|
||||
def message(self) -> str:
|
||||
return f" compiled Code at {self.path}"
|
||||
return f" compiled code at {self.path}"
|
||||
|
||||
|
||||
class CheckNodeTestFailure(InfoLevel):
|
||||
|
||||
@@ -178,7 +178,7 @@ class DbtDatabaseError(DbtRuntimeError):
|
||||
lines = []
|
||||
|
||||
if hasattr(self.node, "build_path") and self.node.build_path:
|
||||
lines.append(f"compiled Code at {self.node.build_path}")
|
||||
lines.append(f"compiled code at {self.node.build_path}")
|
||||
|
||||
return lines + DbtRuntimeError.process_stack(self)
|
||||
|
||||
|
||||
@@ -12,3 +12,31 @@
|
||||
{{ "limit " ~ limit if limit != none }}
|
||||
) dbt_internal_test
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{% macro get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
|
||||
{{ adapter.dispatch('get_unit_test_sql', 'dbt')(main_sql, expected_fixture_sql, expected_column_names) }}
|
||||
{%- endmacro %}
|
||||
|
||||
{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
|
||||
-- Build actual result given inputs
|
||||
with dbt_internal_unit_test_actual AS (
|
||||
select
|
||||
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as actual_or_expected
|
||||
from (
|
||||
{{ main_sql }}
|
||||
) _dbt_internal_unit_test_actual
|
||||
),
|
||||
-- Build expected result
|
||||
dbt_internal_unit_test_expected AS (
|
||||
select
|
||||
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as actual_or_expected
|
||||
from (
|
||||
{{ expected_fixture_sql }}
|
||||
) _dbt_internal_unit_test_expected
|
||||
)
|
||||
-- Union actual and expected results
|
||||
select * from dbt_internal_unit_test_actual
|
||||
union all
|
||||
select * from dbt_internal_unit_test_expected
|
||||
{%- endmacro %}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
{%- materialization unit, default -%}
|
||||
|
||||
{% set relations = [] %}
|
||||
|
||||
{% set expected_rows = config.get('expected_rows') %}
|
||||
{% set tested_expected_column_names = expected_rows[0].keys() if (expected_rows | length ) > 0 else get_columns_in_query(sql) %} %}
|
||||
|
||||
{%- set target_relation = this.incorporate(type='table') -%}
|
||||
{%- set temp_relation = make_temp_relation(target_relation)-%}
|
||||
{% do run_query(get_create_table_as_sql(True, temp_relation, get_empty_subquery_sql(sql))) %}
|
||||
{%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%}
|
||||
{%- set column_name_to_data_types = {} -%}
|
||||
{%- for column in columns_in_relation -%}
|
||||
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %}
|
||||
|
||||
{% call statement('main', fetch_result=True) -%}
|
||||
|
||||
{{ unit_test_sql }}
|
||||
|
||||
{%- endcall %}
|
||||
|
||||
{% do adapter.drop_relation(temp_relation) %}
|
||||
|
||||
{{ return({'relations': relations}) }}
|
||||
|
||||
{%- endmaterialization -%}
|
||||
@@ -0,0 +1,77 @@
|
||||
{% macro get_fixture_sql(rows, column_name_to_data_types) %}
|
||||
-- Fixture for {{ model.name }}
|
||||
{% set default_row = {} %}
|
||||
|
||||
{%- if not column_name_to_data_types -%}
|
||||
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
|
||||
{%- set column_name_to_data_types = {} -%}
|
||||
{%- for column in columns_in_relation -%}
|
||||
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if not column_name_to_data_types -%}
|
||||
{{ exceptions.raise_compiler_error("columns not available for" ~ model.name) }}
|
||||
{%- endif -%}
|
||||
|
||||
{%- for column_name, column_type in column_name_to_data_types.items() -%}
|
||||
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- for row in rows -%}
|
||||
{%- do format_row(row, column_name_to_data_types) -%}
|
||||
|
||||
{%- set default_row_copy = default_row.copy() -%}
|
||||
{%- do default_row_copy.update(row) -%}
|
||||
select
|
||||
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if not loop.last %}
|
||||
union all
|
||||
{% endif %}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if (rows | length) == 0 -%}
|
||||
select
|
||||
{%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %}
|
||||
{%- endfor %}
|
||||
limit 0
|
||||
{%- endif -%}
|
||||
{% endmacro %}
|
||||
|
||||
|
||||
{% macro get_expected_sql(rows, column_name_to_data_types) %}
|
||||
|
||||
{%- if (rows | length) == 0 -%}
|
||||
select * FROM dbt_internal_unit_test_actual
|
||||
limit 0
|
||||
{%- else -%}
|
||||
{%- for row in rows -%}
|
||||
{%- do format_row(row, column_name_to_data_types) -%}
|
||||
select
|
||||
{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if not loop.last %}
|
||||
union all
|
||||
{% endif %}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{% endmacro %}
|
||||
|
||||
{%- macro format_row(row, column_name_to_data_types) -%}
|
||||
|
||||
{#-- wrap yaml strings in quotes, apply cast --#}
|
||||
{%- for column_name, column_value in row.items() -%}
|
||||
{% set row_update = {column_name: column_value} %}
|
||||
{%- if column_value is string -%}
|
||||
{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%}
|
||||
{%- elif column_value is none -%}
|
||||
{%- set row_update = {column_name: safe_cast('null', column_name_to_data_types[column_name]) } -%}
|
||||
{%- else -%}
|
||||
{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%}
|
||||
{%- endif -%}
|
||||
{%- do row.update(row_update) -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- endmacro -%}
|
||||
@@ -34,6 +34,7 @@ class NodeType(StrEnum):
|
||||
Metric = "metric"
|
||||
Group = "group"
|
||||
SemanticModel = "semantic_model"
|
||||
Unit = "unit_test"
|
||||
|
||||
@classmethod
|
||||
def executable(cls) -> List["NodeType"]:
|
||||
@@ -47,6 +48,7 @@ class NodeType(StrEnum):
|
||||
cls.Documentation,
|
||||
cls.RPCCall,
|
||||
cls.SqlOperation,
|
||||
cls.Unit,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -138,6 +138,11 @@ class SchemaParser(SimpleParser[YamlBlock, ModelNode]):
|
||||
self.root_project, self.project.project_name, self.schema_yaml_vars
|
||||
)
|
||||
|
||||
# This is unnecessary, but mypy was requiring it. Clean up parser code so
|
||||
# we don't have to do this.
|
||||
def parse_from_dict(self, dct):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_compiled_path(cls, block: FileBlock) -> str:
|
||||
# should this raise an error?
|
||||
@@ -297,7 +302,7 @@ class YamlReader(metaclass=ABCMeta):
|
||||
if coerce_dict_str(entry) is None:
|
||||
raise YamlParseListError(path, self.key, data, "expected a dict with string keys")
|
||||
|
||||
if "name" not in entry:
|
||||
if "name" not in entry and "model" not in entry:
|
||||
raise ParsingError("Entry did not contain a name")
|
||||
|
||||
# Render the data (except for tests and descriptions).
|
||||
|
||||
215
core/dbt/parser/unit_tests.py
Normal file
215
core/dbt/parser/unit_tests.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from dbt.contracts.graph.unparsed import UnparsedUnitTestSuite
|
||||
from dbt.contracts.graph.model_config import NodeConfig
|
||||
from dbt_extractor import py_extract_from_source # type: ignore
|
||||
from dbt.contracts.graph.nodes import (
|
||||
ModelNode,
|
||||
UnitTestNode,
|
||||
RefArgs,
|
||||
)
|
||||
from dbt.contracts.graph.manifest import Manifest
|
||||
from dbt.parser.schemas import (
|
||||
SchemaParser,
|
||||
YamlBlock,
|
||||
ValidationError,
|
||||
JSONValidationError,
|
||||
YamlParseDictError,
|
||||
YamlReader,
|
||||
)
|
||||
|
||||
from dbt.exceptions import (
|
||||
ParsingError,
|
||||
)
|
||||
from dbt.parser.search import FileBlock
|
||||
|
||||
from dbt.contracts.files import FileHash, SchemaSourceFile
|
||||
from dbt.node_types import NodeType
|
||||
|
||||
from dbt.context.providers import generate_parse_exposure, get_rendered
|
||||
from typing import List
|
||||
|
||||
|
||||
def _is_model_node(node_id, manifest):
|
||||
return manifest.nodes[node_id].resource_type == NodeType.Model
|
||||
|
||||
|
||||
class UnitTestManifestLoader:
|
||||
@classmethod
|
||||
def load(cls, manifest, root_project) -> Manifest:
|
||||
unit_test_manifest = Manifest(macros=manifest.macros)
|
||||
for file in manifest.files.values():
|
||||
block = FileBlock(file)
|
||||
if isinstance(file, SchemaSourceFile):
|
||||
dct = file.dict_from_yaml
|
||||
if "unit" in dct:
|
||||
yaml_block = YamlBlock.from_file_block(block, dct)
|
||||
# TODO: first root_project should be project, or we should only parse unit tests from root_project
|
||||
schema_parser = SchemaParser(root_project, manifest, root_project)
|
||||
parser = UnitTestParser(schema_parser, yaml_block, unit_test_manifest)
|
||||
parser.parse()
|
||||
|
||||
model_to_unit_tests = {}
|
||||
for node in unit_test_manifest.nodes.values():
|
||||
if isinstance(node, UnitTestNode):
|
||||
model_name = node.name.split("__")[0]
|
||||
if model_name not in model_to_unit_tests:
|
||||
model_to_unit_tests[model_name] = [node.unique_id]
|
||||
else:
|
||||
model_to_unit_tests[model_name].append(node.unique_id)
|
||||
|
||||
for node in unit_test_manifest.nodes.values():
|
||||
if isinstance(node, UnitTestNode):
|
||||
# a unit test should depend on its fixture nodes, and any unit tests on its ref'd nodes
|
||||
for ref in node.refs:
|
||||
for unique_id in model_to_unit_tests.get(ref.name, []):
|
||||
node.depends_on.nodes.append(unique_id)
|
||||
return unit_test_manifest
|
||||
|
||||
|
||||
class UnitTestParser(YamlReader):
|
||||
def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, unit_test_manifest: Manifest):
|
||||
super().__init__(schema_parser, yaml, "unit")
|
||||
self.yaml = yaml
|
||||
self.unit_test_manifest = unit_test_manifest
|
||||
|
||||
def parse_unit_test(self, unparsed: UnparsedUnitTestSuite):
|
||||
package_name = self.project.project_name
|
||||
path = self.yaml.path.relative_path
|
||||
# TODO: fix
|
||||
checksum = "f8f57c9e32eafaacfb002a4d03a47ffb412178f58f49ba58fd6f436f09f8a1d6"
|
||||
unit_test_node_ids = []
|
||||
for unit_test in unparsed.tests:
|
||||
input_nodes = []
|
||||
original_input_nodes = []
|
||||
for given in unit_test.given:
|
||||
original_input_node = self._get_original_input_node(given.input)
|
||||
original_input_nodes.append(original_input_node)
|
||||
|
||||
original_input_node_columns = None
|
||||
if (
|
||||
original_input_node.resource_type == NodeType.Model
|
||||
and original_input_node.config.contract.enforced
|
||||
):
|
||||
original_input_node_columns = {
|
||||
column.name: column.data_type for column in original_input_node.columns
|
||||
}
|
||||
|
||||
# TODO: package_name?
|
||||
input_name = f"{unparsed.model}__{unit_test.name}__{original_input_node.name}"
|
||||
input_unique_id = f"model.{package_name}.{input_name}"
|
||||
|
||||
input_node = ModelNode(
|
||||
raw_code=self._build_raw_code(given.rows, original_input_node_columns),
|
||||
resource_type=NodeType.Model,
|
||||
package_name=package_name,
|
||||
path=path,
|
||||
# original_file_path=self.yaml.path.original_file_path,
|
||||
original_file_path=f"models_unit_test/{input_name}.sql",
|
||||
unique_id=input_unique_id,
|
||||
name=input_name,
|
||||
config=NodeConfig(materialized="ephemeral"),
|
||||
database=original_input_node.database,
|
||||
schema=original_input_node.schema,
|
||||
alias=original_input_node.alias,
|
||||
fqn=input_unique_id.split("."),
|
||||
checksum=FileHash(name="sha256", checksum=checksum),
|
||||
)
|
||||
input_nodes.append(input_node)
|
||||
|
||||
actual_node = self.manifest.ref_lookup.perform_lookup(
|
||||
f"model.{package_name}.{unparsed.model}", self.manifest
|
||||
)
|
||||
unit_test_unique_id = f"unit.{package_name}.{unit_test.name}.{unparsed.model}"
|
||||
unit_test_node = UnitTestNode(
|
||||
resource_type=NodeType.Unit,
|
||||
package_name=package_name,
|
||||
path=f"{unparsed.model}.sql",
|
||||
# original_file_path=self.yaml.path.original_file_path,
|
||||
original_file_path=f"models_unit_test/{unparsed.model}.sql",
|
||||
unique_id=unit_test_unique_id,
|
||||
name=f"{unparsed.model}__{unit_test.name}",
|
||||
# TODO: merge with node config
|
||||
config=NodeConfig(materialized="unit", _extra={"expected_rows": unit_test.expect}),
|
||||
raw_code=actual_node.raw_code,
|
||||
database=actual_node.database,
|
||||
schema=actual_node.schema,
|
||||
alias=f"{unparsed.model}__{unit_test.name}",
|
||||
fqn=unit_test_unique_id.split("."),
|
||||
checksum=FileHash(name="sha256", checksum=checksum),
|
||||
attached_node=actual_node.unique_id,
|
||||
overrides=unit_test.overrides,
|
||||
)
|
||||
|
||||
# TODO: generalize this method
|
||||
ctx = generate_parse_exposure(
|
||||
unit_test_node, # type: ignore
|
||||
self.root_project,
|
||||
self.schema_parser.manifest,
|
||||
package_name,
|
||||
)
|
||||
get_rendered(unit_test_node.raw_code, ctx, unit_test_node, capture_macros=True)
|
||||
# unit_test_node now has a populated refs/sources
|
||||
|
||||
# during compilation, refs will resolve to fixtures,
|
||||
# so add original input node ids to depends on explicitly to preserve lineage
|
||||
for original_input_node in original_input_nodes:
|
||||
# TODO: consider nulling out the original_input_node.raw_code
|
||||
self.unit_test_manifest.nodes[original_input_node.unique_id] = original_input_node
|
||||
unit_test_node.depends_on.nodes.append(original_input_node.unique_id)
|
||||
|
||||
self.unit_test_manifest.nodes[unit_test_node.unique_id] = unit_test_node
|
||||
# self.unit_test_manifest.nodes[actual_node.unique_id] = actual_node
|
||||
for input_node in input_nodes:
|
||||
self.unit_test_manifest.nodes[input_node.unique_id] = input_node
|
||||
# should be a process_refs / process_sources call isntead?
|
||||
unit_test_node.depends_on.nodes.append(input_node.unique_id)
|
||||
unit_test_node_ids.append(unit_test_node.unique_id)
|
||||
|
||||
# find out all nodes that are referenced but not in unittest manifest
|
||||
all_depends_on = set()
|
||||
for node_id in self.unit_test_manifest.nodes:
|
||||
if _is_model_node(node_id, self.unit_test_manifest):
|
||||
all_depends_on.update(self.unit_test_manifest.nodes[node_id].depends_on.nodes) # type: ignore
|
||||
not_in_manifest = all_depends_on - set(self.unit_test_manifest.nodes.keys())
|
||||
|
||||
# copy those node also over into unit_test_manifest
|
||||
for node_id in not_in_manifest:
|
||||
self.unit_test_manifest.nodes[node_id] = self.manifest.nodes[node_id]
|
||||
|
||||
def parse(self):
|
||||
for data in self.get_key_dicts():
|
||||
try:
|
||||
UnparsedUnitTestSuite.validate(data)
|
||||
unparsed = UnparsedUnitTestSuite.from_dict(data)
|
||||
except (ValidationError, JSONValidationError) as exc:
|
||||
raise YamlParseDictError(self.yaml.path, self.key, data, exc)
|
||||
|
||||
self.parse_unit_test(unparsed)
|
||||
|
||||
def _build_raw_code(self, rows, column_name_to_data_types) -> str:
|
||||
return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
|
||||
rows=rows, column_name_to_data_types=column_name_to_data_types
|
||||
)
|
||||
|
||||
def _get_original_input_node(self, input: str):
|
||||
statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}")
|
||||
if statically_parsed["refs"]:
|
||||
# set refs and sources on the node object
|
||||
refs: List[RefArgs] = []
|
||||
for ref in statically_parsed["refs"]:
|
||||
name = ref.get("name")
|
||||
package = ref.get("package")
|
||||
version = ref.get("version")
|
||||
refs.append(RefArgs(name, package, version))
|
||||
# TODO: disabled lookup, versioned lookup, public models
|
||||
original_input_node = self.manifest.ref_lookup.find(
|
||||
name, package, version, self.manifest
|
||||
)
|
||||
elif statically_parsed["sources"]:
|
||||
input_package_name, input_source_name = statically_parsed["sources"][0]
|
||||
original_input_node = self.manifest.source_lookup.find(
|
||||
input_source_name, input_package_name, self.manifest
|
||||
)
|
||||
else:
|
||||
raise ParsingError("given input must be ref or source")
|
||||
|
||||
return original_input_node
|
||||
@@ -17,6 +17,7 @@ from dbt.task.run_operation import RunOperationTask
|
||||
from dbt.task.seed import SeedTask
|
||||
from dbt.task.snapshot import SnapshotTask
|
||||
from dbt.task.test import TestTask
|
||||
from dbt.task.unit_test import UnitTestTask
|
||||
|
||||
RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
|
||||
|
||||
@@ -30,6 +31,7 @@ TASK_DICT = {
|
||||
"test": TestTask,
|
||||
"run": RunTask,
|
||||
"run-operation": RunOperationTask,
|
||||
"unit-test": UnitTestTask,
|
||||
}
|
||||
|
||||
CMD_DICT = {
|
||||
@@ -42,6 +44,7 @@ CMD_DICT = {
|
||||
"test": CliCommand.TEST,
|
||||
"run": CliCommand.RUN,
|
||||
"run-operation": CliCommand.RUN_OPERATION,
|
||||
"unit-test": CliCommand.UNIT_TEST,
|
||||
}
|
||||
|
||||
|
||||
|
||||
192
core/dbt/task/unit_test.py
Normal file
192
core/dbt/task/unit_test.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from dataclasses import dataclass
|
||||
from dbt.dataclass_schema import dbtClassMixin
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
import io
|
||||
|
||||
from .compile import CompileRunner
|
||||
from .run import RunTask
|
||||
|
||||
from dbt.contracts.graph.nodes import UnitTestNode
|
||||
from dbt.contracts.graph.manifest import Manifest
|
||||
from dbt.contracts.results import TestStatus, RunResult
|
||||
from dbt.context.providers import generate_runtime_model_context
|
||||
from dbt.clients.jinja import MacroGenerator
|
||||
from dbt.events.functions import fire_event
|
||||
from dbt.events.types import (
|
||||
LogTestResult,
|
||||
LogStartLine,
|
||||
)
|
||||
from dbt.graph import ResourceTypeSelector
|
||||
from dbt.exceptions import (
|
||||
DbtInternalError,
|
||||
MissingMaterializationError,
|
||||
)
|
||||
from dbt.node_types import NodeType
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnitTestResultData(dbtClassMixin):
|
||||
should_error: bool
|
||||
adapter_response: Dict[str, Any]
|
||||
diff: Optional[str] = None
|
||||
|
||||
|
||||
class UnitTestRunner(CompileRunner):
|
||||
def describe_node(self):
|
||||
return f"{self.node.resource_type} {self.node.name}"
|
||||
|
||||
def print_result_line(self, result):
|
||||
model = result.node
|
||||
|
||||
fire_event(
|
||||
LogTestResult(
|
||||
name=model.name,
|
||||
status=str(result.status),
|
||||
index=self.node_index,
|
||||
num_models=self.num_nodes,
|
||||
execution_time=result.execution_time,
|
||||
node_info=model.node_info,
|
||||
num_failures=result.failures,
|
||||
),
|
||||
level=LogTestResult.status_to_level(str(result.status)),
|
||||
)
|
||||
|
||||
def print_start_line(self):
|
||||
fire_event(
|
||||
LogStartLine(
|
||||
description=self.describe_node(),
|
||||
index=self.node_index,
|
||||
total=self.num_nodes,
|
||||
node_info=self.node.node_info,
|
||||
)
|
||||
)
|
||||
|
||||
def before_execute(self):
|
||||
self.print_start_line()
|
||||
|
||||
def execute_unit_test(self, node: UnitTestNode, manifest: Manifest) -> UnitTestResultData:
|
||||
# generate_runtime_unit_test_context not strictly needed - this is to run the 'unit' materialization, not compile the node.compield_code
|
||||
context = generate_runtime_model_context(node, self.config, manifest)
|
||||
|
||||
materialization_macro = manifest.find_materialization_macro_by_name(
|
||||
self.config.project_name, node.get_materialization(), self.adapter.type()
|
||||
)
|
||||
|
||||
if materialization_macro is None:
|
||||
raise MissingMaterializationError(
|
||||
materialization=node.get_materialization(), adapter_type=self.adapter.type()
|
||||
)
|
||||
|
||||
if "config" not in context:
|
||||
raise DbtInternalError(
|
||||
"Invalid materialization context generated, missing config: {}".format(context)
|
||||
)
|
||||
|
||||
# generate materialization macro
|
||||
macro_func = MacroGenerator(materialization_macro, context)
|
||||
# execute materialization macro
|
||||
macro_func()
|
||||
# load results from context
|
||||
# could eventually be returned directly by materialization
|
||||
result = context["load_result"]("main")
|
||||
adapter_response = result["response"].to_dict(omit_none=True)
|
||||
table = result["table"]
|
||||
actual = self._get_unit_test_table(table, "actual")
|
||||
expected = self._get_unit_test_table(table, "expected")
|
||||
should_error = actual.rows != expected.rows
|
||||
diff = None
|
||||
if should_error:
|
||||
actual_output = self._agate_table_to_str(actual)
|
||||
expected_output = self._agate_table_to_str(expected)
|
||||
|
||||
diff = f"\n\nActual:\n{actual_output}\n\nExpected:\n{expected_output}\n"
|
||||
return UnitTestResultData(
|
||||
diff=diff,
|
||||
should_error=should_error,
|
||||
adapter_response=adapter_response,
|
||||
)
|
||||
|
||||
def execute(self, node: UnitTestNode, manifest: Manifest):
|
||||
result = self.execute_unit_test(node, manifest)
|
||||
thread_id = threading.current_thread().name
|
||||
|
||||
status = TestStatus.Pass
|
||||
message = None
|
||||
failures = 0
|
||||
if result.should_error:
|
||||
status = TestStatus.Fail
|
||||
message = result.diff
|
||||
failures = 1
|
||||
|
||||
return RunResult(
|
||||
node=node,
|
||||
status=status,
|
||||
timing=[],
|
||||
thread_id=thread_id,
|
||||
execution_time=0,
|
||||
message=message,
|
||||
adapter_response=result.adapter_response,
|
||||
failures=failures,
|
||||
)
|
||||
|
||||
def after_execute(self, result):
|
||||
self.print_result_line(result)
|
||||
|
||||
def _get_unit_test_table(self, result_table, actual_or_expected: str):
|
||||
unit_test_table = result_table.where(
|
||||
lambda row: row["actual_or_expected"] == actual_or_expected
|
||||
)
|
||||
columns = list(unit_test_table.columns.keys())
|
||||
columns.remove("actual_or_expected")
|
||||
return unit_test_table.select(columns)
|
||||
|
||||
def _agate_table_to_str(self, table) -> str:
|
||||
# Hack to get Agate table output as string
|
||||
output = io.StringIO()
|
||||
if self.config.args.output == "json":
|
||||
table.to_json(path=output)
|
||||
else:
|
||||
table.print_table(output=output, max_rows=None)
|
||||
return output.getvalue().strip()
|
||||
|
||||
|
||||
class UnitTestSelector(ResourceTypeSelector):
|
||||
def __init__(self, graph, manifest, previous_state):
|
||||
super().__init__(
|
||||
graph=graph,
|
||||
manifest=manifest,
|
||||
previous_state=previous_state,
|
||||
resource_types=[NodeType.Unit],
|
||||
)
|
||||
|
||||
|
||||
class UnitTestTask(RunTask):
|
||||
"""
|
||||
Unit testing:
|
||||
Read schema files + custom data tests and validate that
|
||||
constraints are satisfied.
|
||||
"""
|
||||
|
||||
def __init__(self, args, config, manifest, collection):
|
||||
# This will initialize the RunTask with the unit test manifest ("collection") as the manifest
|
||||
super().__init__(args, config, collection)
|
||||
self.collection = collection
|
||||
self.original_manifest = manifest
|
||||
|
||||
__test__ = False
|
||||
|
||||
def raise_on_first_error(self):
|
||||
return False
|
||||
|
||||
def get_node_selector(self) -> UnitTestSelector:
|
||||
if self.manifest is None or self.graph is None:
|
||||
raise DbtInternalError("manifest and graph must be set to get perform node selection")
|
||||
return UnitTestSelector(
|
||||
graph=self.graph,
|
||||
manifest=self.manifest,
|
||||
previous_state=self.previous_state,
|
||||
)
|
||||
|
||||
def get_runner_type(self, _):
|
||||
return UnitTestRunner
|
||||
122
tests/functional/unit_testing/test_unit_testing.py
Normal file
122
tests/functional/unit_testing/test_unit_testing.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import pytest
|
||||
from dbt.tests.util import run_dbt
|
||||
|
||||
my_model_sql = """
|
||||
SELECT
|
||||
a+b as c,
|
||||
concat(string_a, string_b) as string_c,
|
||||
not_testing, date_a,
|
||||
{{ dbt.string_literal(type_numeric()) }} as macro_call,
|
||||
{{ dbt.string_literal(var('my_test')) }} as var_call,
|
||||
{{ dbt.string_literal(env_var('MY_TEST', 'default')) }} as env_var_call,
|
||||
{{ dbt.string_literal(invocation_id) }} as invocation_id
|
||||
FROM {{ ref('my_model_a')}} my_model_a
|
||||
JOIN {{ ref('my_model_b' )}} my_model_b
|
||||
ON my_model_a.id = my_model_b.id
|
||||
"""
|
||||
|
||||
my_model_a_sql = """
|
||||
SELECT
|
||||
1 as a,
|
||||
1 as id,
|
||||
2 as not_testing,
|
||||
'a' as string_a,
|
||||
DATE '2020-01-02' as date_a
|
||||
"""
|
||||
|
||||
my_model_b_sql = """
|
||||
SELECT
|
||||
2 as b,
|
||||
1 as id,
|
||||
2 as c,
|
||||
'b' as string_b
|
||||
"""
|
||||
|
||||
test_my_model_yml = """
|
||||
unit:
|
||||
- model: my_model
|
||||
tests:
|
||||
- name: test_my_model
|
||||
given:
|
||||
- input: ref('my_model_a')
|
||||
rows:
|
||||
- {id: 1, a: 1}
|
||||
- input: ref('my_model_b')
|
||||
rows:
|
||||
- {id: 1, b: 2}
|
||||
- {id: 2, b: 2}
|
||||
expect:
|
||||
- {c: 2}
|
||||
|
||||
- name: test_my_model_empty
|
||||
given:
|
||||
- input: ref('my_model_a')
|
||||
rows: []
|
||||
- input: ref('my_model_b')
|
||||
rows:
|
||||
- {id: 1, b: 2}
|
||||
- {id: 2, b: 2}
|
||||
expect: []
|
||||
- name: test_my_model_overrides
|
||||
given:
|
||||
- input: ref('my_model_a')
|
||||
rows:
|
||||
- {id: 1, a: 1}
|
||||
- input: ref('my_model_b')
|
||||
rows:
|
||||
- {id: 1, b: 2}
|
||||
- {id: 2, b: 2}
|
||||
overrides:
|
||||
macros:
|
||||
type_numeric: override
|
||||
invocation_id: 123
|
||||
vars:
|
||||
my_test: var_override
|
||||
env_vars:
|
||||
MY_TEST: env_var_override
|
||||
expect:
|
||||
- {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123}
|
||||
- name: test_my_model_string_concat
|
||||
given:
|
||||
- input: ref('my_model_a')
|
||||
rows:
|
||||
- {id: 1, string_a: a}
|
||||
- input: ref('my_model_b')
|
||||
rows:
|
||||
- {id: 1, string_b: b}
|
||||
expect:
|
||||
- {string_c: ab}
|
||||
- name: test_my_model_datetime
|
||||
given:
|
||||
- input: ref('my_model_a')
|
||||
rows:
|
||||
- {id: 1, date_a: "2020-01-01"}
|
||||
- input: ref('my_model_b')
|
||||
rows:
|
||||
- {id: 1}
|
||||
expect:
|
||||
- {date_a: "2020-01-01"}
|
||||
"""
|
||||
|
||||
|
||||
class TestUnitTests:
|
||||
@pytest.fixture(scope="class")
|
||||
def models(self):
|
||||
return {
|
||||
"my_model.sql": my_model_sql,
|
||||
"my_model_a.sql": my_model_a_sql,
|
||||
"my_model_b.sql": my_model_b_sql,
|
||||
"test_my_model.yml": test_my_model_yml,
|
||||
}
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def project_config_update(self):
|
||||
return {"vars": {"my_test": "my_test_var"}}
|
||||
|
||||
def test_basic(self, project):
|
||||
run_dbt(["deps"])
|
||||
results = run_dbt(["run"])
|
||||
assert len(results) == 3
|
||||
|
||||
results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False)
|
||||
assert len(results) == 5
|
||||
@@ -17,6 +17,7 @@ node_type_pluralizations = {
|
||||
NodeType.Metric: "metrics",
|
||||
NodeType.Group: "groups",
|
||||
NodeType.SemanticModel: "semantic_models",
|
||||
NodeType.Unit: "unit_tests",
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user