Compare commits

...

42 Commits

Author SHA1 Message Date
Michelle Ark
3bf34dc369 update test command 2023-12-04 15:53:05 +09:00
Michelle Ark
b461ac3f3b Merge branch 'unit_testing_feature_branch' into support-complex-types-unit-testing 2023-12-04 15:50:19 +09:00
Michelle Ark
70c54a0cfc use TYPE_LABELS, expand test coverage 2023-12-04 12:02:10 +09:00
Jeremy Cohen
be01871a26 Lower as can be 2023-12-04 11:12:32 +09:00
Jeremy Cohen
bf8322158b Case-insensitive comparisons 2023-12-04 11:12:18 +09:00
Gerda Shank
ca82f54808 Enable unit testing in non-root packages (#9184) 2023-11-30 14:42:52 -05:00
Gerda Shank
bf6bffad94 Merge branch 'main' into unit_testing_feature_branch 2023-11-30 12:26:16 -05:00
Gerda Shank
3f1ed23c11 Move unit testing to test and build commands (#9108)
* Switch to using 'test' command instead of 'unit-test'

* Remove old unit test

* Put daff changes into task/test.py

* changie

* test_type:unit

* Add unit test to build and make test

* Select test_type:data

* Add unit tets to test_graph_selector_methods.py

* Fix fqn to incude path components

* Update build test

* Remove part of message in test_csv_fixtures.py that's different on Windows

* Rename build test directory
2023-11-27 09:50:39 -05:00
Michelle Ark
8197fa7b4d add adapter-zone test, expand data type coverage 2023-11-21 15:05:25 -05:00
Michelle Ark
827e35e3ee Merge branch 'unit_testing_feature_branch' into support-complex-types-unit-testing 2023-11-21 13:15:32 -05:00
Gerda Shank
e00199186e Merge branch 'main' into unit_testing_feature_branch 2023-11-21 12:31:35 -05:00
Gerda Shank
a559259e32 Merge branch 'unit_testing_feature_branch' of github.com:dbt-labs/dbt-core into unit_testing_feature_branch 2023-11-16 13:45:51 -05:00
Gerda Shank
964a7283cb Merge branch 'main' into unit_testing_feature_branch 2023-11-16 13:45:35 -05:00
Kshitij Aranke
3432436dae Fix #8652: Use seed file from disk for unit testing if rows not specified in YAML config (#9064)
Co-authored-by: Michelle Ark <MichelleArk@users.noreply.github.com>
Fix #8652: Use seed value if rows not specified
2023-11-16 16:24:55 +00:00
Michelle Ark
f1d68f402a array support 2023-11-15 18:15:15 -05:00
Michelle Ark
35f579e3eb Use daff for diff formatting in unit testing (#8984) 2023-11-15 11:27:50 -05:00
Gerda Shank
c6be2d288f Allow use of sources as unit testing inputs (#9059) 2023-11-15 10:31:32 -05:00
Gerda Shank
436dae6bb3 Merge branch 'main' into unit_testing_feature_branch 2023-11-14 17:29:01 -05:00
Jeremy Cohen
ebf48d2b50 Unit test support for state:modified and --defer (#9032)
* Add unit tests to state:modified selection

* Get defer working too yolo

* Refactor per marky suggestion

* Add changelog

* separate out unit test state tests + fix csv fixture tests

* formatting

* detect changes to fixture files with state:modified

---------

Co-authored-by: Michelle Ark <michelle.ark@dbtlabs.com>
2023-11-14 13:42:52 -05:00
Emily Rockman
3b033ac108 csv file fixtures (#9044)
* WIP

* adding tests

* fix tests

* more tests

* fix broken tests

* fix broken tests

* change to csv extension

* fix unit test yaml

* add mixed inline and file csv test

* add mixed inline and file csv test

* additional changes

* read file directly

* some cleanup and soem test fixes - wip

* fix test

* use better file searching

* fix final test

* cleanup

* use absolute path and fix tests
2023-11-09 14:35:26 -06:00
Gerda Shank
2792e0c2ce Merge branch 'main' into unit_testing_feature_branch 2023-11-08 11:26:37 -05:00
Emily Rockman
f629baa95d convert to use unit test name at top level key (#8966)
* use unit test name as top level

* remove breakpoints

* finish converting tests

* fix unit test node name

* breakpoints

* fix partial parsing bug

* comment out duplicate test

* fix test and make unique id match other uniqu id patterns

* clean up

* fix incremental test

* Update tests/functional/unit_testing/test_unit_testing.py
2023-11-03 13:59:02 -05:00
Emily Rockman
02a3dc5be3 update unit test key: unit -> unit-tests (#8988)
* WIP

* remove breakpoint

* fix tests, fix schema
2023-11-03 13:17:18 -05:00
Michelle Ark
aa91ea4c00 Support unit testing incremental models (#8891) 2023-11-01 21:08:20 -04:00
Gerda Shank
f77c2260f2 Merge branch 'main' into unit_testing_feature_branch 2023-11-01 11:26:47 -04:00
Gerda Shank
df4e4ed388 Merge branch 'main' into unit_testing_feature_branch 2023-10-12 13:35:19 -04:00
Gerda Shank
3b6f9bdef4 Enable inline csv format in unit testing (#8743) 2023-10-05 11:17:27 -04:00
Gerda Shank
5cafb96956 Merge branch 'main' into unit_testing_feature_branch 2023-10-05 10:08:09 -04:00
Gerda Shank
bb6fd3029b Merge branch 'main' into unit_testing_feature_branch 2023-10-02 15:59:24 -04:00
Gerda Shank
ac719e441c Merge branch 'main' into unit_testing_feature_branch 2023-09-26 19:57:32 -04:00
Gerda Shank
08ef90aafa Merge branch 'main' into unit_testing_feature_branch 2023-09-22 09:55:36 -04:00
Gerda Shank
3dbf0951b2 Merge branch 'main' into unit_testing_feature_branch 2023-09-13 17:36:15 -04:00
Gerda Shank
c48e34c47a Add additional functional test for unit testing selection, artifacts, etc (#8639) 2023-09-13 10:46:00 -04:00
Michelle Ark
12342ca92b unit test config: tags & meta (#8565) 2023-09-12 10:54:11 +01:00
Gerda Shank
2b376d9dba Merge branch 'main' into unit_testing_feature_branch 2023-09-11 13:01:51 -04:00
Gerda Shank
120b36e2f5 Merge branch 'main' into unit_testing_feature_branch 2023-09-07 11:08:52 -04:00
Gerda Shank
1e64f94bf0 Merge branch 'main' into unit_testing_feature_branch 2023-08-31 09:22:53 -04:00
Gerda Shank
b3bcbd5ea4 Merge branch 'main' into unit_testing_feature_branch 2023-08-30 14:51:30 -04:00
Gerda Shank
42e66fda65 8295 unit testing artifacts (#8477) 2023-08-29 17:54:54 -04:00
Gerda Shank
7ea7069999 Merge branch 'main' into unit_testing_feature_branch 2023-08-28 16:28:32 -04:00
Gerda Shank
24abc3719a Merge branch 'main' into unit_testing_feature_branch 2023-08-23 10:15:43 -04:00
Gerda Shank
181f5209a0 Initial implementation of unit testing (from pr #2911)
Co-authored-by: Michelle Ark <michelle.ark@dbtlabs.com>
2023-08-14 16:04:23 -04:00
71 changed files with 5329 additions and 82 deletions

View File

@@ -0,0 +1,6 @@
kind: Features
body: Initial implementation of unit testing
time: 2023-08-02T14:50:11.391992-04:00
custom:
Author: gshank
Issue: "8287"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Unit test manifest artifacts and selection
time: 2023-08-28T10:18:25.958929-04:00
custom:
Author: gshank
Issue: "8295"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Support config with tags & meta for unit tests
time: 2023-09-06T23:47:41.059915-04:00
custom:
Author: michelleark
Issue: "8294"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Enable inline csv fixtures in unit tests
time: 2023-09-28T16:32:05.573776-04:00
custom:
Author: gshank
Issue: "8626"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Support unit testing incremental models
time: 2023-11-01T10:18:45.341781-04:00
custom:
Author: michelleark
Issue: "8422"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Add support of csv file fixtures to unit testing
time: 2023-11-06T19:47:52.501495-06:00
custom:
Author: emmyoop
Issue: "8290"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Unit tests support --defer and state:modified
time: 2023-11-07T23:10:06.376588-05:00
custom:
Author: jtcohen6
Issue: "8517"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Support source inputs in unit tests
time: 2023-11-11T19:11:50.870494-05:00
custom:
Author: gshank
Issue: "8507"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Use daff to render diff displayed in stdout when unit test fails
time: 2023-11-14T10:15:55.689307-05:00
custom:
Author: michelleark
Issue: "8558"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Move unit testing to test command
time: 2023-11-16T14:40:06.121336-05:00
custom:
Author: gshank
Issue: "8979"

View File

@@ -0,0 +1,6 @@
kind: Features
body: Support unit tests in non-root packages
time: 2023-11-30T13:09:48.206007-05:00
custom:
Author: gshank
Issue: "8285"

View File

@@ -0,0 +1,6 @@
kind: Fixes
body: Use seed file from disk for unit testing if rows not specified in YAML config
time: 2023-11-13T15:45:35.008565Z
custom:
Author: aranke
Issue: "8652"

View File

@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add unit testing functional tests
time: 2023-09-12T19:05:06.023126-04:00
custom:
Author: gshank
Issue: "8512"

View File

@@ -2,7 +2,13 @@ from collections.abc import Hashable
from dataclasses import dataclass, field
from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set, Union, FrozenSet
from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode
from dbt.contracts.graph.nodes import (
SourceDefinition,
ManifestNode,
ResultNode,
ParsedNode,
UnitTestSourceDefinition,
)
from dbt.contracts.relation import (
RelationType,
ComponentName,
@@ -211,7 +217,9 @@ class BaseRelation(FakeAPIObject, Hashable):
)
@classmethod
def create_from_source(cls: Type[Self], source: SourceDefinition, **kwargs: Any) -> Self:
def create_from_source(
cls: Type[Self], source: Union[SourceDefinition, UnitTestSourceDefinition], **kwargs: Any
) -> Self:
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop("column", None)
quote_policy = deep_merge(
@@ -237,7 +245,7 @@ class BaseRelation(FakeAPIObject, Hashable):
cls: Type[Self],
config: HasQuoting,
node: ManifestNode,
limit: Optional[int],
limit: Optional[int] = None,
) -> Self:
# Note that ephemeral models are based on the name.
identifier = cls.add_ephemeral_prefix(node.name)
@@ -271,8 +279,10 @@ class BaseRelation(FakeAPIObject, Hashable):
node: ResultNode,
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
if not isinstance(node, SourceDefinition):
if node.resource_type == NodeType.Source or isinstance(node, UnitTestSourceDefinition):
if not (
isinstance(node, SourceDefinition) or isinstance(node, UnitTestSourceDefinition)
):
raise DbtInternalError(
"type mismatch, expected SourceDefinition but got {}".format(type(node))
)

View File

@@ -12,7 +12,7 @@ class RelationConfigChangeAction(StrEnum):
drop = "drop"
@dataclass(frozen=True, eq=True, unsafe_hash=True)
@dataclass(frozen=True, eq=True, unsafe_hash=True) # type: ignore
class RelationConfigChange(RelationConfigBase, ABC):
action: RelationConfigChangeAction
context: Hashable # this is usually a RelationConfig, e.g. IndexConfig, but shouldn't be limited

View File

@@ -3,6 +3,7 @@ from codecs import BOM_UTF8
import agate
import datetime
import isodate
import io
import json
import dbt.utils
from typing import Iterable, List, Dict, Union, Optional, Any
@@ -137,6 +138,23 @@ def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table:
)
def json_rows_from_table(table: agate.Table) -> List[Dict[str, Any]]:
"Convert a table to a list of row dict objects"
output = io.StringIO()
table.to_json(path=output) # type: ignore
return json.loads(output.getvalue())
def list_rows_from_table(table: agate.Table) -> List[Any]:
"Convert a table to a list of lists, where the first element represents the header"
rows = [[col.name for col in table.columns]]
for row in table.rows:
rows.append(list(row.values()))
return rows
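
The `list_rows_from_table` helper above returns a header row followed by one value list per row. A minimal sketch of the same shape without agate (the table stand-in here is hypothetical, not agate's API):

```python
from typing import Any, Dict, List

def list_rows(columns: List[str], records: List[Dict[str, Any]]) -> List[List[Any]]:
    # Header first, then one list of values per record, in column order:
    # the same shape list_rows_from_table produces from an agate.Table.
    rows: List[List[Any]] = [list(columns)]
    for record in records:
        rows.append([record[c] for c in columns])
    return rows
```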
def empty_table():
"Returns an empty Agate table. To be used in place of None"

View File

@@ -330,6 +330,26 @@ class MacroGenerator(BaseMacroGenerator):
return self.call_macro(*args, **kwargs)
class UnitTestMacroGenerator(MacroGenerator):
# this makes UnitTestMacroGenerator objects callable like functions
def __init__(
self,
macro_generator: MacroGenerator,
call_return_value: Any,
) -> None:
super().__init__(
macro_generator.macro,
macro_generator.context,
macro_generator.node,
macro_generator.stack,
)
self.call_return_value = call_return_value
def __call__(self, *args, **kwargs):
with self.track_call():
return self.call_return_value
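
`UnitTestMacroGenerator` turns an overridden macro into a callable that ignores its arguments and returns the configured value. The wrapper pattern, reduced to a standalone sketch (names here are illustrative, not dbt-core's API):

```python
from typing import Any, Callable

class CannedCallable:
    """Wraps a callable but always returns a fixed value, the way
    UnitTestMacroGenerator short-circuits overridden macros."""

    def __init__(self, wrapped: Callable[..., Any], return_value: Any) -> None:
        self.wrapped = wrapped  # kept for introspection only, never called
        self.return_value = return_value

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Accept any call signature; the override value always wins.
        return self.return_value
```

So a unit test that overrides a timestamp macro can pin `current_timestamp()` to a constant regardless of how the model calls it.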
class QueryStringGenerator(BaseMacroGenerator):
def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
super().__init__(context)

View File

@@ -12,7 +12,10 @@ from dbt.flags import get_flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model_context
from dbt.context.providers import (
generate_runtime_model_context,
generate_runtime_unit_test_context,
)
from dbt.contracts.graph.manifest import Manifest, UniqueID
from dbt.contracts.graph.nodes import (
ManifestNode,
@@ -21,6 +24,8 @@ from dbt.contracts.graph.nodes import (
GraphMemberNode,
InjectedCTE,
SeedNode,
UnitTestNode,
UnitTestDefinition,
)
from dbt.exceptions import (
GraphDependencyNotFoundError,
@@ -44,6 +49,7 @@ def print_compile_stats(stats):
names = {
NodeType.Model: "model",
NodeType.Test: "test",
NodeType.Unit: "unit test",
NodeType.Snapshot: "snapshot",
NodeType.Analysis: "analysis",
NodeType.Macro: "macro",
@@ -91,6 +97,7 @@ def _generate_stats(manifest: Manifest):
stats[NodeType.Macro] += len(manifest.macros)
stats[NodeType.Group] += len(manifest.groups)
stats[NodeType.SemanticModel] += len(manifest.semantic_models)
stats[NodeType.Unit] += len(manifest.unit_tests)
# TODO: should we be counting dimensions + entities?
@@ -191,6 +198,8 @@ class Linker:
self.link_node(exposure, manifest)
for metric in manifest.metrics.values():
self.link_node(metric, manifest)
for unit_test in manifest.unit_tests.values():
self.link_node(unit_test, manifest)
for saved_query in manifest.saved_queries.values():
self.link_node(saved_query, manifest)
@@ -291,8 +300,10 @@ class Compiler:
manifest: Manifest,
extra_context: Dict[str, Any],
) -> Dict[str, Any]:
context = generate_runtime_model_context(node, self.config, manifest)
if isinstance(node, UnitTestNode):
context = generate_runtime_unit_test_context(node, self.config, manifest)
else:
context = generate_runtime_model_context(node, self.config, manifest)
context.update(extra_context)
if isinstance(node, GenericTestNode):
@@ -529,6 +540,9 @@ class Compiler:
the node's raw_code into compiled_code, and then calls the
recursive method to "prepend" the ctes.
"""
if isinstance(node, UnitTestDefinition):
return node
# Make sure Lexer for sqlparse 0.4.4 is initialized
from sqlparse.lexer import Lexer # type: ignore
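
The compiler changes above do two things: dispatch unit test nodes to their own context builder, and return `UnitTestDefinition` nodes unchanged since they carry no SQL to compile. A minimal sketch of that isinstance dispatch (class names below are stand-ins):

```python
from dataclasses import dataclass

@dataclass
class ModelStub:
    name: str

@dataclass
class UnitTestStub(ModelStub):
    pass

def context_for(node: ModelStub) -> str:
    # Unit test nodes get a dedicated context; everything else keeps
    # the regular runtime model context.
    if isinstance(node, UnitTestStub):
        return f"unit_test_context({node.name})"
    return f"model_context({node.name})"
```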

View File

@@ -441,6 +441,7 @@ class PartialProject(RenderComponents):
snapshots: Dict[str, Any]
sources: Dict[str, Any]
tests: Dict[str, Any]
unit_tests: Dict[str, Any]
metrics: Dict[str, Any]
semantic_models: Dict[str, Any]
saved_queries: Dict[str, Any]
@@ -454,6 +455,7 @@ class PartialProject(RenderComponents):
snapshots = cfg.snapshots
sources = cfg.sources
tests = cfg.tests
unit_tests = cfg.unit_tests
metrics = cfg.metrics
semantic_models = cfg.semantic_models
saved_queries = cfg.saved_queries
@@ -515,6 +517,7 @@ class PartialProject(RenderComponents):
query_comment=query_comment,
sources=sources,
tests=tests,
unit_tests=unit_tests,
metrics=metrics,
semantic_models=semantic_models,
saved_queries=saved_queries,
@@ -625,6 +628,7 @@ class Project:
snapshots: Dict[str, Any]
sources: Dict[str, Any]
tests: Dict[str, Any]
unit_tests: Dict[str, Any]
metrics: Dict[str, Any]
semantic_models: Dict[str, Any]
saved_queries: Dict[str, Any]
@@ -658,6 +662,13 @@ class Project:
generic_test_paths.append(os.path.join(test_path, "generic"))
return generic_test_paths
@property
def fixture_paths(self):
fixture_paths = []
for test_path in self.test_paths:
fixture_paths.append(os.path.join(test_path, "fixtures"))
return fixture_paths
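
The new `fixture_paths` property derives one `fixtures` subdirectory per configured test path. As a standalone function:

```python
import os
from typing import List

def fixture_paths(test_paths: List[str]) -> List[str]:
    # Each test path (e.g. "tests") gets a "fixtures" subdirectory
    # where csv unit test fixtures are looked up.
    return [os.path.join(p, "fixtures") for p in test_paths]
```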
def __str__(self):
cfg = self.to_project_config(with_packages=True)
return str(cfg)
@@ -703,6 +714,7 @@ class Project:
"snapshots": self.snapshots,
"sources": self.sources,
"tests": self.tests,
"unit-tests": self.unit_tests,
"metrics": self.metrics,
"semantic-models": self.semantic_models,
"saved-queries": self.saved_queries,

View File

@@ -166,6 +166,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
query_comment=project.query_comment,
sources=project.sources,
tests=project.tests,
unit_tests=project.unit_tests,
metrics=project.metrics,
semantic_models=project.semantic_models,
saved_queries=project.saved_queries,
@@ -324,6 +325,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
"snapshots": self._get_config_paths(self.snapshots),
"sources": self._get_config_paths(self.sources),
"tests": self._get_config_paths(self.tests),
"unit_tests": self._get_config_paths(self.unit_tests),
"metrics": self._get_config_paths(self.metrics),
"semantic_models": self._get_config_paths(self.semantic_models),
"saved_queries": self._get_config_paths(self.saved_queries),

View File

@@ -51,6 +51,8 @@ class UnrenderedConfig(ConfigSource):
model_configs = unrendered.get("saved_queries")
elif resource_type == NodeType.Exposure:
model_configs = unrendered.get("exposures")
elif resource_type == NodeType.Unit:
model_configs = unrendered.get("unit_tests")
else:
model_configs = unrendered.get("models")
if model_configs is None:
@@ -80,6 +82,8 @@ class RenderedConfig(ConfigSource):
model_configs = self.project.saved_queries
elif resource_type == NodeType.Exposure:
model_configs = self.project.exposures
elif resource_type == NodeType.Unit:
model_configs = self.project.unit_tests
else:
model_configs = self.project.models
return model_configs

View File

@@ -1,4 +1,5 @@
import abc
from copy import deepcopy
import os
from typing import (
Callable,
@@ -17,7 +18,7 @@ from typing_extensions import Protocol
from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
from dbt.config import RuntimeConfig, Project
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.context.base import contextmember, contextproperty, Var
@@ -39,6 +40,7 @@ from dbt.contracts.graph.nodes import (
RefArgs,
AccessType,
SemanticModel,
UnitTestNode,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
@@ -572,6 +574,17 @@ class OperationRefResolver(RuntimeRefResolver):
return super().create_relation(target_model)
class RuntimeUnitTestRefResolver(RuntimeRefResolver):
def resolve(
self,
target_name: str,
target_package: Optional[str] = None,
target_version: Optional[NodeVersion] = None,
) -> RelationProxy:
target_name = f"{self.model.name}__{target_name}"
return super().resolve(target_name, target_package, target_version)
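
`RuntimeUnitTestRefResolver` namespaces every `ref` under the unit test model's name before delegating to the normal resolver, so fixture CTEs cannot collide with real relations. The renaming itself is a one-liner:

```python
def unit_test_ref_name(model_name: str, target_name: str) -> str:
    # e.g. ("my_model", "upstream") resolves against "my_model__upstream"
    return f"{model_name}__{target_name}"
```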
# `source` implementations
class ParseSourceResolver(BaseSourceResolver):
def resolve(self, source_name: str, table_name: str):
@@ -599,6 +612,29 @@ class RuntimeSourceResolver(BaseSourceResolver):
return self.Relation.create_from_source(target_source, limit=self.resolve_limit)
class RuntimeUnitTestSourceResolver(BaseSourceResolver):
def resolve(self, source_name: str, table_name: str):
target_source = self.manifest.resolve_source(
source_name,
table_name,
self.current_project,
self.model.package_name,
)
if target_source is None or isinstance(target_source, Disabled):
raise TargetNotFoundError(
node=self.model,
target_name=f"{source_name}.{table_name}",
target_kind="source",
disabled=(isinstance(target_source, Disabled)),
)
# For unit tests, this isn't a "real" source, it's a ModelNode taking
# the place of a source. We don't really need to return the relation here,
# we just need to set_cte, but skipping it confuses typing. We *do* need
# the relation in the "this" property.
self.model.set_cte(target_source.unique_id, None)
return self.Relation.create_ephemeral_from_node(self.config, target_source)
# `metric` implementations
class ParseMetricResolver(BaseMetricResolver):
def resolve(self, name: str, package: Optional[str] = None) -> MetricReference:
@@ -676,6 +712,22 @@ class RuntimeVar(ModelConfiguredVar):
pass
class UnitTestVar(RuntimeVar):
def __init__(
self,
context: Dict[str, Any],
config: RuntimeConfig,
node: Resource,
) -> None:
config_copy = None
assert isinstance(node, UnitTestNode)
if node.overrides and node.overrides.vars:
config_copy = deepcopy(config)
config_copy.cli_vars.update(node.overrides.vars)
super().__init__(context, config_copy or config, node=node)
# Providers
class Provider(Protocol):
execute: bool
@@ -717,6 +769,16 @@ class RuntimeProvider(Provider):
metric = RuntimeMetricResolver
class RuntimeUnitTestProvider(Provider):
execute = True
Config = RuntimeConfigObject
DatabaseWrapper = RuntimeDatabaseWrapper
Var = UnitTestVar
ref = RuntimeUnitTestRefResolver
source = RuntimeUnitTestSourceResolver
metric = RuntimeMetricResolver
class OperationProvider(RuntimeProvider):
ref = OperationRefResolver
@@ -1388,7 +1450,7 @@ class ModelContext(ProviderContext):
@contextproperty()
def pre_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
@@ -1397,7 +1459,7 @@ class ModelContext(ProviderContext):
@contextproperty()
def post_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
@@ -1490,6 +1552,33 @@ class ModelContext(ProviderContext):
return None
class UnitTestContext(ModelContext):
model: UnitTestNode
@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
"""The env_var() function. Return the overriden unit test environment variable named 'var'.
If there is no unit test override, return the environment variable named 'var'.
If there is no such environment variable set, return the default.
If the default is None, raise an exception for an undefined variable.
"""
if self.model.overrides and var in self.model.overrides.env_vars:
return self.model.overrides.env_vars[var]
else:
return super().env_var(var, default)
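
Per the docstring above, `UnitTestContext.env_var` consults the test's overrides before falling back to the real environment. A self-contained sketch of that lookup order (not dbt-core's actual implementation, which also records usage for partial parsing):

```python
import os
from typing import Dict, Optional

def resolve_env_var(overrides: Dict[str, str], var: str, default: Optional[str] = None) -> str:
    # Order: unit test override, then the process environment, then the
    # default; an unset variable with no default is an error.
    if var in overrides:
        return overrides[var]
    if var in os.environ:
        return os.environ[var]
    if default is not None:
        return default
    raise KeyError(f"Env var required but not provided: '{var}'")
```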
@contextproperty()
def this(self) -> Optional[str]:
if self.model.this_input_node_unique_id:
this_node = self.manifest.expect(self.model.this_input_node_unique_id)
self.model.set_cte(this_node.unique_id, None) # type: ignore
return self.adapter.Relation.add_ephemeral_prefix(this_node.name)
return None
# This is called by '_context_for', used in 'render_with_context'
def generate_parser_model_context(
model: ManifestNode,
@@ -1534,6 +1623,24 @@ def generate_runtime_macro_context(
return ctx.to_dict()
def generate_runtime_unit_test_context(
unit_test: UnitTestNode,
config: RuntimeConfig,
manifest: Manifest,
) -> Dict[str, Any]:
ctx = UnitTestContext(unit_test, config, manifest, RuntimeUnitTestProvider(), None)
ctx_dict = ctx.to_dict()
if unit_test.overrides and unit_test.overrides.macros:
for macro_name, macro_value in unit_test.overrides.macros.items():
context_value = ctx_dict.get(macro_name)
if isinstance(context_value, MacroGenerator):
ctx_dict[macro_name] = UnitTestMacroGenerator(context_value, macro_value)
else:
ctx_dict[macro_name] = macro_value
return ctx_dict
class ExposureRefResolver(BaseResolver):
def __call__(self, *args, **kwargs) -> str:
package = None

View File

@@ -231,6 +231,7 @@ class SchemaSourceFile(BaseSourceFile):
# node patches contain models, seeds, snapshots, analyses
ndp: List[str] = field(default_factory=list)
semantic_models: List[str] = field(default_factory=list)
unit_tests: List[str] = field(default_factory=list)
saved_queries: List[str] = field(default_factory=list)
# any macro patches in this file by macro unique_id.
mcp: Dict[str, str] = field(default_factory=dict)

View File

@@ -42,6 +42,7 @@ from dbt.contracts.graph.nodes import (
SemanticModel,
SourceDefinition,
UnpatchedSourceDefinition,
UnitTestDefinition,
)
from dbt.contracts.graph.unparsed import SourcePatch, NodeVersion, UnparsedVersion
from dbt.contracts.graph.manifest_upgrade import upgrade_manifest_json
@@ -799,6 +800,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
disabled: MutableMapping[str, List[GraphMemberNode]] = field(default_factory=dict)
env_vars: MutableMapping[str, str] = field(default_factory=dict)
semantic_models: MutableMapping[str, SemanticModel] = field(default_factory=dict)
unit_tests: MutableMapping[str, UnitTestDefinition] = field(default_factory=dict)
saved_queries: MutableMapping[str, SavedQuery] = field(default_factory=dict)
_doc_lookup: Optional[DocLookup] = field(
@@ -960,6 +962,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
files={k: _deepcopy(v) for k, v in self.files.items()},
state_check=_deepcopy(self.state_check),
semantic_models={k: _deepcopy(v) for k, v in self.semantic_models.items()},
unit_tests={k: _deepcopy(v) for k, v in self.unit_tests.items()},
saved_queries={k: _deepcopy(v) for k, v in self.saved_queries.items()},
)
copy.build_flat_graph()
@@ -1030,6 +1033,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
parent_map=self.parent_map,
group_map=self.group_map,
semantic_models=self.semantic_models,
unit_tests=self.unit_tests,
saved_queries=self.saved_queries,
)
@@ -1049,6 +1053,8 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
return self.metrics[unique_id]
elif unique_id in self.semantic_models:
return self.semantic_models[unique_id]
elif unique_id in self.unit_tests:
return self.unit_tests[unique_id]
elif unique_id in self.saved_queries:
return self.saved_queries[unique_id]
else:
@@ -1493,6 +1499,12 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
self.semantic_models[semantic_model.unique_id] = semantic_model
source_file.semantic_models.append(semantic_model.unique_id)
def add_unit_test(self, source_file: SchemaSourceFile, unit_test: UnitTestDefinition):
if unit_test.unique_id in self.unit_tests:
raise DuplicateResourceNameError(unit_test, self.unit_tests[unit_test.unique_id])
self.unit_tests[unit_test.unique_id] = unit_test
source_file.unit_tests.append(unit_test.unique_id)
def add_saved_query(self, source_file: SchemaSourceFile, saved_query: SavedQuery) -> None:
_check_duplicates(saved_query, self.saved_queries)
self.saved_queries[saved_query.unique_id] = saved_query
@@ -1525,6 +1537,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
self.disabled,
self.env_vars,
self.semantic_models,
self.unit_tests,
self.saved_queries,
self._doc_lookup,
self._source_lookup,
@@ -1607,6 +1620,11 @@ class WritableManifest(ArtifactMixin):
description="Metadata about the manifest",
)
)
unit_tests: Mapping[UniqueID, UnitTestDefinition] = field(
metadata=dict(
description="The unit tests defined in the project",
)
)
@classmethod
def compatible_previous_versions(cls) -> Iterable[Tuple[str, int]]:

View File

@@ -145,6 +145,9 @@ def upgrade_manifest_json(manifest: dict, manifest_schema_version: int) -> dict:
manifest["groups"] = {}
if "group_map" not in manifest:
manifest["group_map"] = {}
# add unit_tests key
if "unit_tests" not in manifest:
manifest["unit_tests"] = {}
for metric_content in manifest.get("metrics", {}).values():
# handle attr renames + value translation ("expression" -> "derived")
metric_content = upgrade_ref_content(metric_content)

View File

@@ -554,6 +554,11 @@ class ModelConfig(NodeConfig):
)
@dataclass
class UnitTestNodeConfig(NodeConfig):
expected_rows: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class SeedConfig(NodeConfig):
materialized: str = "seed"
@@ -726,6 +731,18 @@ class SnapshotConfig(EmptySnapshotConfig):
return self.from_dict(data)
@dataclass
class UnitTestConfig(BaseConfig):
tags: Union[str, List[str]] = field(
default_factory=list_str,
metadata=metas(ShowBehavior.Hide, MergeBehavior.Append, CompareBehavior.Exclude),
)
meta: Dict[str, Any] = field(
default_factory=dict,
metadata=MergeBehavior.Update.meta(),
)
RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
NodeType.Metric: MetricConfig,
NodeType.SemanticModel: SemanticModelConfig,
@@ -736,6 +753,7 @@ RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
NodeType.Test: TestConfig,
NodeType.Model: NodeConfig,
NodeType.Snapshot: SnapshotConfig,
NodeType.Unit: UnitTestConfig,
}
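
`UnitTestConfig` declares merge behaviors: tags append across config layers while meta dicts update key by key. Sketched directly below (dbt's actual merging runs through `BaseConfig` and the `MergeBehavior` metadata, so this is only the effective semantics):

```python
from typing import Any, Dict, List, Union

def as_tag_list(tags: Union[str, List[str]]) -> List[str]:
    # A bare string is normalized to a single-element list.
    return [tags] if isinstance(tags, str) else list(tags)

def merge_tags(base: Union[str, List[str]], update: Union[str, List[str]]) -> List[str]:
    # MergeBehavior.Append: both sides are kept.
    return as_tag_list(base) + as_tag_list(update)

def merge_meta(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    # MergeBehavior.Update: later layers win on key conflicts.
    merged = dict(base)
    merged.update(update)
    return merged
```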

View File

@@ -35,12 +35,18 @@ from dbt.contracts.graph.unparsed import (
UnparsedSourceDefinition,
UnparsedSourceTableDefinition,
UnparsedColumn,
UnitTestOverrides,
UnitTestInputFixture,
UnitTestOutputFixture,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.graph.semantic_layer_common import WhereFilterIntersection
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
from dbt.events.functions import warn_or_error
from dbt.exceptions import ParsingError, ContractBreakingChangeError
from dbt.exceptions import (
ParsingError,
ContractBreakingChangeError,
)
from dbt.events.types import (
SeedIncreased,
SeedExceedsLimitSamePath,
@@ -72,6 +78,8 @@ from .model_config import (
EmptySnapshotConfig,
SnapshotConfig,
SemanticModelConfig,
UnitTestConfig,
UnitTestNodeConfig,
SavedQueryConfig,
)
@@ -1054,6 +1062,78 @@ class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
return "generic"
@dataclass
class UnitTestSourceDefinition(ModelNode):
source_name: str = "undefined"
quoting: Quoting = field(default_factory=Quoting)
@property
def search_name(self):
return f"{self.source_name}.{self.name}"
@dataclass
class UnitTestNode(CompiledNode):
resource_type: Literal[NodeType.Unit]
tested_node_unique_id: Optional[str] = None
this_input_node_unique_id: Optional[str] = None
overrides: Optional[UnitTestOverrides] = None
config: UnitTestNodeConfig = field(default_factory=UnitTestNodeConfig)
@dataclass
class UnitTestDefinitionMandatory:
model: str
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
@dataclass
class UnitTestDefinition(NodeInfoMixin, GraphNode, UnitTestDefinitionMandatory):
description: str = ""
overrides: Optional[UnitTestOverrides] = None
depends_on: DependsOn = field(default_factory=DependsOn)
config: UnitTestConfig = field(default_factory=UnitTestConfig)
checksum: Optional[str] = None
schema: Optional[str] = None
@property
def build_path(self):
# TODO: is this actually necessary?
return self.original_file_path
@property
def compiled_path(self):
# TODO: is this actually necessary?
return self.original_file_path
@property
def depends_on_nodes(self):
return self.depends_on.nodes
@property
def tags(self) -> List[str]:
tags = self.config.tags
return [tags] if isinstance(tags, str) else tags
def build_unit_test_checksum(self, project_root: str, fixture_paths: List[str]):
# everything except 'description'
data = f"{self.model}-{self.given}-{self.expect}-{self.overrides}"
# include underlying fixture data
for input in self.given:
if input.fixture:
data += f"-{input.get_rows(project_root, fixture_paths)}"
self.checksum = hashlib.new("sha256", data.encode("utf-8")).hexdigest()
def same_contents(self, other: Optional["UnitTestDefinition"]) -> bool:
if other is None:
return False
return self.checksum == other.checksum
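
The checksum above is what `state:modified` compares: everything except `description` is hashed, including the underlying fixture file contents, so editing a csv fixture counts as a modification while rewording a description does not. The hashing step, sketched with string inputs:

```python
import hashlib
from typing import Optional

def unit_test_checksum(model: str, given: str, expect: str,
                       overrides: Optional[str], fixture_data: str = "") -> str:
    # Everything except the description participates in the hash.
    data = f"{model}-{given}-{expect}-{overrides}{fixture_data}"
    return hashlib.sha256(data.encode("utf-8")).hexdigest()
```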
# ====================================
# Snapshot node
# ====================================
@@ -1310,6 +1390,10 @@ class SourceDefinition(NodeInfoMixin, ParsedSourceMandatory):
def search_name(self):
return f"{self.source_name}.{self.name}"
@property
def group(self):
return None
# ====================================
# Exposure node
@@ -1849,6 +1933,7 @@ ManifestSQLNode = Union[
SqlNode,
GenericTestNode,
SnapshotNode,
UnitTestNode,
]
# All SQL nodes plus SeedNode (csv files)
@@ -1869,6 +1954,7 @@ GraphMemberNode = Union[
Metric,
SavedQuery,
SemanticModel,
UnitTestDefinition,
]
# All "nodes" (or node-like objects) in this file
@@ -1879,7 +1965,4 @@ Resource = Union[
Group,
]
TestNode = Union[
SingularTestNode,
GenericTestNode,
]
TestNode = Union[SingularTestNode, GenericTestNode]

View File

@@ -1,7 +1,10 @@
import datetime
import re
import csv
from io import StringIO
from dbt import deprecations
from dbt.clients.system import find_matching
from dbt.node_types import NodeType
from dbt.contracts.graph.semantic_models import (
Defaults,
@@ -759,3 +762,101 @@ def normalize_date(d: Optional[datetime.date]) -> Optional[datetime.datetime]:
dt = dt.astimezone()
return dt
class UnitTestFormat(StrEnum):
CSV = "csv"
Dict = "dict"
class UnitTestFixture:
@property
def format(self) -> UnitTestFormat:
return UnitTestFormat.Dict
@property
def rows(self) -> Optional[Union[str, List[Dict[str, Any]]]]:
return None
@property
def fixture(self) -> Optional[str]:
return None
def get_rows(self, project_root: str, paths: List[str]) -> List[Dict[str, Any]]:
if self.format == UnitTestFormat.Dict:
assert isinstance(self.rows, List)
return self.rows
elif self.format == UnitTestFormat.CSV:
rows = []
if self.fixture is not None:
assert isinstance(self.fixture, str)
file_path = self.get_fixture_path(self.fixture, project_root, paths)
with open(file_path, newline="") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
rows.append(row)
else: # using inline csv
assert isinstance(self.rows, str)
dummy_file = StringIO(self.rows)
reader = csv.DictReader(dummy_file)
rows = []
for row in reader:
rows.append(row)
return rows
def get_fixture_path(self, fixture: str, project_root: str, paths: List[str]) -> str:
fixture_path = f"{fixture}.csv"
matches = find_matching(project_root, paths, fixture_path)
if len(matches) == 0:
raise ParsingError(f"Could not find fixture file {fixture} for unit test")
elif len(matches) > 1:
raise ParsingError(
f"Found multiple fixture files named {fixture} at {[d['relative_path'] for d in matches]}. Please use a unique name for each fixture file."
)
return matches[0]["absolute_path"]
def validate_fixture(self, fixture_type, test_name) -> None:
if self.format == UnitTestFormat.Dict and not isinstance(self.rows, list):
raise ParsingError(
f"Unit test {test_name} has {fixture_type} rows which do not match format {self.format}"
)
if self.format == UnitTestFormat.CSV and not (
isinstance(self.rows, str) or isinstance(self.fixture, str)
):
raise ParsingError(
f"Unit test {test_name} has {fixture_type} rows or fixtures which do not match format {self.format}. Expected string."
)
@dataclass
class UnitTestInputFixture(dbtClassMixin, UnitTestFixture):
input: str
rows: Optional[Union[str, List[Dict[str, Any]]]] = None
format: UnitTestFormat = UnitTestFormat.Dict
fixture: Optional[str] = None
@dataclass
class UnitTestOutputFixture(dbtClassMixin, UnitTestFixture):
rows: Optional[Union[str, List[Dict[str, Any]]]] = None
format: UnitTestFormat = UnitTestFormat.Dict
fixture: Optional[str] = None
@dataclass
class UnitTestOverrides(dbtClassMixin):
macros: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
env_vars: Dict[str, Any] = field(default_factory=dict)
@dataclass
class UnparsedUnitTest(dbtClassMixin):
name: str
model: str # name of the model being unit tested
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
overrides: Optional[UnitTestOverrides] = None
config: Dict[str, Any] = field(default_factory=dict)
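For the CSV format above, `get_rows` feeds either the fixture file or the inline string through `csv.DictReader`; the inline branch reduces to roughly this:

```python
import csv
from io import StringIO

def rows_from_inline_csv(text):
    # Mirrors the inline-csv branch of UnitTestFixture.get_rows:
    # each CSV line becomes a dict keyed by the header row
    return [dict(row) for row in csv.DictReader(StringIO(text))]

rows = rows_from_inline_csv("id,name\n1,gerda\n2,michelle\n")
# rows == [{'id': '1', 'name': 'gerda'}, {'id': '2', 'name': 'michelle'}]
```

Note that `DictReader` yields all values as strings; casting to the column's data type happens later, in the fixture SQL macros.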

View File

@@ -217,6 +217,7 @@ class Project(dbtClassMixin, Replaceable):
analyses: Dict[str, Any] = field(default_factory=dict)
sources: Dict[str, Any] = field(default_factory=dict)
tests: Dict[str, Any] = field(default_factory=dict)
unit_tests: Dict[str, Any] = field(default_factory=dict)
metrics: Dict[str, Any] = field(default_factory=dict)
semantic_models: Dict[str, Any] = field(default_factory=dict)
saved_queries: Dict[str, Any] = field(default_factory=dict)
@@ -259,6 +260,7 @@ class Project(dbtClassMixin, Replaceable):
"semantic_models": "semantic-models",
"saved_queries": "saved-queries",
"dbt_cloud": "dbt-cloud",
"unit_tests": "unit-tests",
}
@classmethod

View File

@@ -1,7 +1,12 @@
import threading
from dbt.contracts.graph.unparsed import FreshnessThreshold
from dbt.contracts.graph.nodes import CompiledNode, SourceDefinition, ResultNode
from dbt.contracts.graph.nodes import (
CompiledNode,
SourceDefinition,
ResultNode,
UnitTestDefinition,
)
from dbt.contracts.util import (
BaseArtifactMetadata,
ArtifactMixin,
@@ -153,7 +158,7 @@ class BaseResult(dbtClassMixin):
@dataclass
class NodeResult(BaseResult):
node: ResultNode
node: Union[ResultNode, UnitTestDefinition]
@dataclass

View File

@@ -2227,7 +2227,7 @@ class SQLCompiledPath(InfoLevel):
return "Z026"
def message(self) -> str:
return f" compiled Code at {self.path}"
return f" compiled code at {self.path}"
class CheckNodeTestFailure(InfoLevel):

View File

@@ -192,7 +192,7 @@ class DbtDatabaseError(DbtRuntimeError):
lines = []
if hasattr(self.node, "build_path") and self.node.build_path:
lines.append(f"compiled Code at {self.node.build_path}")
lines.append(f"compiled code at {self.node.build_path}")
return lines + DbtRuntimeError.process_stack(self)
@@ -1220,6 +1220,12 @@ class InvalidAccessTypeError(ParsingError):
super().__init__(msg=msg)
class InvalidUnitTestGivenInput(ParsingError):
def __init__(self, input: str) -> None:
msg = f"Unit test given inputs must be either a 'ref', 'source' or 'this' call. Got: '{input}'."
super().__init__(msg=msg)
class SameKeyNestedError(CompilationError):
def __init__(self) -> None:
msg = "Test cannot have the same key at the top-level and in config"

View File

@@ -31,6 +31,8 @@ def can_select_indirectly(node):
"""
if node.resource_type == NodeType.Test:
return True
elif node.resource_type == NodeType.Unit:
return True
else:
return False
@@ -171,9 +173,12 @@ class NodeSelector(MethodManager):
elif unique_id in self.manifest.semantic_models:
semantic_model = self.manifest.semantic_models[unique_id]
return semantic_model.config.enabled
elif unique_id in self.manifest.unit_tests:
return True
elif unique_id in self.manifest.saved_queries:
saved_query = self.manifest.saved_queries[unique_id]
return saved_query.config.enabled
node = self.manifest.nodes[unique_id]
if self.include_empty_nodes:
@@ -199,6 +204,8 @@ class NodeSelector(MethodManager):
node = self.manifest.metrics[unique_id]
elif unique_id in self.manifest.semantic_models:
node = self.manifest.semantic_models[unique_id]
elif unique_id in self.manifest.unit_tests:
node = self.manifest.unit_tests[unique_id]
elif unique_id in self.manifest.saved_queries:
node = self.manifest.saved_queries[unique_id]
else:
@@ -246,8 +253,11 @@ class NodeSelector(MethodManager):
)
for unique_id in self.graph.select_successors(selected):
if unique_id in self.manifest.nodes:
node = self.manifest.nodes[unique_id]
if unique_id in self.manifest.nodes or unique_id in self.manifest.unit_tests:
if unique_id in self.manifest.nodes:
node = self.manifest.nodes[unique_id]
elif unique_id in self.manifest.unit_tests:
node = self.manifest.unit_tests[unique_id] # type: ignore
if can_select_indirectly(node):
# should we add it in directly?
if indirect_selection == IndirectSelection.Eager or set(

View File

@@ -18,6 +18,7 @@ from dbt.contracts.graph.nodes import (
ResultNode,
ManifestNode,
ModelNode,
UnitTestDefinition,
SavedQuery,
SemanticModel,
)
@@ -101,7 +102,9 @@ def is_selected_node(fqn: List[str], node_selector: str, is_versioned: bool) ->
return True
SelectorTarget = Union[SourceDefinition, ManifestNode, Exposure, Metric]
SelectorTarget = Union[
SourceDefinition, ManifestNode, Exposure, Metric, SemanticModel, UnitTestDefinition
]
class SelectorMethod(metaclass=abc.ABCMeta):
@@ -148,6 +151,21 @@ class SelectorMethod(metaclass=abc.ABCMeta):
continue
yield unique_id, metric
def unit_tests(
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, UnitTestDefinition]]:
for unique_id, unit_test in self.manifest.unit_tests.items():
unique_id = UniqueId(unique_id)
if unique_id not in included_nodes:
continue
yield unique_id, unit_test
def parsed_and_unit_nodes(self, included_nodes: Set[UniqueId]):
yield from chain(
self.parsed_nodes(included_nodes),
self.unit_tests(included_nodes),
)
def semantic_model_nodes(
self, included_nodes: Set[UniqueId]
) -> Iterator[Tuple[UniqueId, SemanticModel]]:
@@ -176,6 +194,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
self.source_nodes(included_nodes),
self.exposure_nodes(included_nodes),
self.metric_nodes(included_nodes),
self.unit_tests(included_nodes),
self.semantic_model_nodes(included_nodes),
)
@@ -192,6 +211,7 @@ class SelectorMethod(metaclass=abc.ABCMeta):
self.parsed_nodes(included_nodes),
self.exposure_nodes(included_nodes),
self.metric_nodes(included_nodes),
self.unit_tests(included_nodes),
self.semantic_model_nodes(included_nodes),
self.saved_query_nodes(included_nodes),
)
@@ -519,30 +539,37 @@ class TestNameSelectorMethod(SelectorMethod):
__test__ = False
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
for node, real_node in self.parsed_nodes(included_nodes):
if real_node.resource_type == NodeType.Test and hasattr(real_node, "test_metadata"):
if fnmatch(real_node.test_metadata.name, selector): # type: ignore[union-attr]
yield node
for unique_id, node in self.parsed_and_unit_nodes(included_nodes):
if node.resource_type == NodeType.Test and hasattr(node, "test_metadata"):
if fnmatch(node.test_metadata.name, selector): # type: ignore[union-attr]
yield unique_id
elif node.resource_type == NodeType.Unit:
if fnmatch(node.name, selector):
yield unique_id
class TestTypeSelectorMethod(SelectorMethod):
__test__ = False
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
search_type: Type
search_types: List[Any]
# continue supporting 'schema' + 'data' for backwards compatibility
if selector in ("generic", "schema"):
search_type = GenericTestNode
elif selector in ("singular", "data"):
search_type = SingularTestNode
search_types = [GenericTestNode]
elif selector in ("data",):
search_types = [GenericTestNode, SingularTestNode]
elif selector in ("singular",):
search_types = [SingularTestNode]
elif selector in ("unit",):
search_types = [UnitTestDefinition]
else:
raise DbtRuntimeError(
f'Invalid test type selector {selector}: expected "generic" or ' '"singular"'
f'Invalid test type selector {selector}: expected "generic", "singular", "unit", or "data"'
)
for node, real_node in self.parsed_nodes(included_nodes):
if isinstance(real_node, search_type):
yield node
for unique_id, node in self.parsed_and_unit_nodes(included_nodes):
if isinstance(node, tuple(search_types)):
yield unique_id
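The selector-to-class mapping above (with `data` now spanning both generic and singular tests for backwards compatibility) can be sketched as a plain lookup table; the node classes here are empty stand-ins for illustration:

```python
# Hypothetical stand-ins for the node classes referenced in the selector
class GenericTestNode: ...
class SingularTestNode: ...
class UnitTestDefinition: ...

SEARCH_TYPES = {
    "generic": [GenericTestNode],
    "schema": [GenericTestNode],                   # legacy alias for generic
    "data": [GenericTestNode, SingularTestNode],   # legacy name, now includes both
    "singular": [SingularTestNode],
    "unit": [UnitTestDefinition],
}

def matches(node, selector):
    # isinstance accepts a tuple of types, as in the search method above
    if selector not in SEARCH_TYPES:
        raise ValueError(f"Invalid test type selector {selector}")
    return isinstance(node, tuple(SEARCH_TYPES[selector]))
```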
class StateSelectorMethod(SelectorMethod):
@@ -618,7 +645,9 @@ class StateSelectorMethod(SelectorMethod):
def check_modified_content(
self, old: Optional[SelectorTarget], new: SelectorTarget, adapter_type: str
) -> bool:
if isinstance(new, (SourceDefinition, Exposure, Metric, SemanticModel)):
if isinstance(
new, (SourceDefinition, Exposure, Metric, SemanticModel, UnitTestDefinition)
):
# these all overwrite `same_contents`
different_contents = not new.same_contents(old) # type: ignore
else:
@@ -709,6 +738,10 @@ class StateSelectorMethod(SelectorMethod):
previous_node = manifest.exposures[node]
elif node in manifest.metrics:
previous_node = manifest.metrics[node]
elif node in manifest.semantic_models:
previous_node = manifest.semantic_models[node]
elif node in manifest.unit_tests:
previous_node = manifest.unit_tests[node]
keyword_args = {}
if checker.__name__ in [

View File

@@ -99,6 +99,7 @@ class SelectionCriteria:
except ValueError as exc:
raise InvalidSelectorError(f"'{method_parts[0]}' is not a valid method name") from exc
# Following is for cases like config.severity and config.materialized
method_arguments: List[str] = method_parts[1:]
return method_name, method_arguments

View File

@@ -12,3 +12,31 @@
{{ "limit " ~ limit if limit != none }}
) dbt_internal_test
{%- endmacro %}
{% macro get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
{{ adapter.dispatch('get_unit_test_sql', 'dbt')(main_sql, expected_fixture_sql, expected_column_names) }}
{%- endmacro %}
{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
-- Build actual result given inputs
with dbt_internal_unit_test_actual as (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as {{ adapter.quote("actual_or_expected") }}
from (
{{ main_sql }}
) _dbt_internal_unit_test_actual
),
-- Build expected result
dbt_internal_unit_test_expected as (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as {{ adapter.quote("actual_or_expected") }}
from (
{{ expected_fixture_sql }}
) _dbt_internal_unit_test_expected
)
-- Union actual and expected results
select * from dbt_internal_unit_test_actual
union all
select * from dbt_internal_unit_test_expected
{%- endmacro %}
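Conceptually, the macro above tags every row as `actual` or `expected` and unions the two sets so the result can be diffed downstream. In plain Python, the row set it produces looks like this (sketch only; the real work happens in SQL):

```python
def union_actual_expected(actual_rows, expected_rows):
    # Mirrors default__get_unit_test_sql: tag each row with its origin,
    # then union the two sides into one result set
    tagged = [dict(r, actual_or_expected="actual") for r in actual_rows]
    tagged += [dict(r, actual_or_expected="expected") for r in expected_rows]
    return tagged

out = union_actual_expected([{"id": 1}], [{"id": 1}])
```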

View File

@@ -0,0 +1,28 @@
{%- materialization unit, default -%}
{% set relations = [] %}
{% set expected_rows = config.get('expected_rows') %}
{% set tested_expected_column_names = expected_rows[0].keys() if (expected_rows | length ) > 0 else get_columns_in_query(sql) %}
{%- set target_relation = this.incorporate(type='table') -%}
{%- set temp_relation = make_temp_relation(target_relation)-%}
{% set columns_in_relation = adapter.get_column_schema_from_query(get_empty_subquery_sql(sql)) %}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%}
{%- endfor -%}
{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %}
{% call statement('main', fetch_result=True) -%}
{{ unit_test_sql }}
{%- endcall %}
{% do adapter.drop_relation(temp_relation) %}
{{ return({'relations': relations}) }}
{%- endmaterialization -%}

View File

@@ -0,0 +1,80 @@
{% macro get_fixture_sql(rows, column_name_to_data_types) %}
-- Fixture for {{ model.name }}
{% set default_row = {} %}
{%- if not column_name_to_data_types -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{% set test_sql = get_empty_subquery_sql("select * from " + (this| string)) %}
{% set columns_in_relation = adapter.get_column_schema_from_query(test_sql) %}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{#-- This needs to be a case-insensitive comparison --#}
{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%}
{%- endfor -%}
{%- endif -%}
{%- if not column_name_to_data_types -%}
{{ exceptions.raise_compiler_error("Not able to get columns for unit test '" ~ model.name ~ "' from relation " ~ this) }}
{%- endif -%}
{%- for column_name, column_type in column_name_to_data_types.items() -%}
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
{%- endfor -%}
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
select
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
{% endif %}
{%- endfor -%}
{%- if (rows | length) == 0 -%}
select
{%- for column_name, column_value in default_row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- endfor %}
limit 0
{%- endif -%}
{% endmacro %}
{% macro get_expected_sql(rows, column_name_to_data_types) %}
{%- if (rows | length) == 0 -%}
select * from dbt_internal_unit_test_actual
limit 0
{%- else -%}
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
select
{%- for column_name, column_value in row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
{% endif %}
{%- endfor -%}
{%- endif -%}
{% endmacro %}
{%- macro format_row(row, column_name_to_data_types) -%}
{#-- wrap yaml strings in quotes, apply cast --#}
{%- for column_name, column_value in row.items() -%}
{% set row_update = {column_name: column_value} %}
{%- if column_value is string -%}
{%- set row_update = {column_name: safe_cast(dbt.string_literal(dbt.escape_single_quotes(column_value)), column_name_to_data_types[column_name]) } -%}
{%- elif column_value is none -%}
{%- set row_update = {column_name: safe_cast('null', column_name_to_data_types[column_name]) } -%}
{%- else -%}
{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%}
{%- endif -%}
{%- do row.update(row_update) -%}
{%- endfor -%}
{%- endmacro -%}
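A simplified Python rendering of what `get_fixture_sql` emits: one `select` per row, unspecified columns defaulted to `null`, rows joined with `union all` (casts and quoting via `safe_cast`/`string_literal` omitted for brevity):

```python
def fixture_sql(rows, columns):
    # columns: ordered column names; missing values become null,
    # mirroring the default_row behaviour in get_fixture_sql
    def render(row):
        parts = []
        for col in columns:
            val = row.get(col)
            if val is None:
                parts.append(f"null as {col}")
            elif isinstance(val, str):
                parts.append(f"'{val}' as {col}")
            else:
                parts.append(f"{val} as {col}")
        return "select " + ", ".join(parts)
    if not rows:
        # empty fixture: all-null row, limit 0, so only the schema survives
        return "select " + ", ".join(f"null as {c}" for c in columns) + " limit 0"
    return "\nunion all\n".join(render(r) for r in rows)

sql = fixture_sql([{"id": 1, "name": "a"}, {"id": 2}], ["id", "name"])
```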

View File

@@ -35,6 +35,7 @@ class NodeType(StrEnum):
Group = "group"
SavedQuery = "saved_query"
SemanticModel = "semantic_model"
Unit = "unit_test"
@classmethod
def executable(cls) -> List["NodeType"]:

View File

@@ -1765,8 +1765,9 @@ def write_semantic_manifest(manifest: Manifest, target_path: str) -> None:
semantic_manifest.write_json_to_file(path)
def write_manifest(manifest: Manifest, target_path: str):
path = os.path.join(target_path, MANIFEST_FILE_NAME)
def write_manifest(manifest: Manifest, target_path: str, which: Optional[str] = None):
file_name = MANIFEST_FILE_NAME
path = os.path.join(target_path, file_name)
manifest.write(path)
write_semantic_manifest(manifest=manifest, target_path=target_path)

View File

@@ -608,7 +608,7 @@ class PartialParsing:
self.saved_manifest.files.pop(file_id)
# For each key in a schema file dictionary, process the changed, deleted, and added
# elemnts for the key lists
# elements for the key lists
def handle_schema_file_changes(self, schema_file, saved_yaml_dict, new_yaml_dict):
# loop through comparing previous dict_from_yaml with current dict_from_yaml
# Need to do the deleted/added/changed thing, just like the files lists
@@ -681,6 +681,7 @@ class PartialParsing:
handle_change("metrics", self.delete_schema_metric)
handle_change("groups", self.delete_schema_group)
handle_change("semantic_models", self.delete_schema_semantic_model)
handle_change("unit_tests", self.delete_schema_unit_test)
handle_change("saved_queries", self.delete_schema_saved_query)
def _handle_element_change(
@@ -938,6 +939,17 @@ class PartialParsing:
elif unique_id in self.saved_manifest.disabled:
self.delete_disabled(unique_id, schema_file.file_id)
def delete_schema_unit_test(self, schema_file, unit_test_dict):
unit_test_name = unit_test_dict["name"]
unit_tests = schema_file.unit_tests.copy()
for unique_id in unit_tests:
if unique_id in self.saved_manifest.unit_tests:
unit_test = self.saved_manifest.unit_tests[unique_id]
if unit_test.name == unit_test_name:
self.saved_manifest.unit_tests.pop(unique_id)
schema_file.unit_tests.remove(unique_id)
# No disabled unit tests yet
def get_schema_element(self, elem_list, elem_name):
for element in elem_list:
if "name" in element and element["name"] == elem_name:

View File

@@ -139,6 +139,11 @@ class SchemaParser(SimpleParser[YamlBlock, ModelNode]):
self.root_project, self.project.project_name, self.schema_yaml_vars
)
# This is unnecessary, but mypy was requiring it. Clean up parser code so
# we don't have to do this.
def parse_from_dict(self, dct):
pass
@classmethod
def get_compiled_path(cls, block: FileBlock) -> str:
# should this raise an error?
@@ -226,6 +231,12 @@ class SchemaParser(SimpleParser[YamlBlock, ModelNode]):
semantic_model_parser = SemanticModelParser(self, yaml_block)
semantic_model_parser.parse()
if "unit_tests" in dct:
from dbt.parser.unit_tests import UnitTestParser
unit_test_parser = UnitTestParser(self, yaml_block)
unit_test_parser.parse()
if "saved_queries" in dct:
from dbt.parser.schema_yaml_readers import SavedQueryParser
@@ -251,12 +262,13 @@ class ParseResult:
# abstract base class (ABCMeta)
# Four subclasses: MetricParser, ExposureParser, GroupParser, SourceParser, PatchParser
# Many subclasses: MetricParser, ExposureParser, GroupParser, SourceParser,
# PatchParser, SemanticModelParser, SavedQueryParser, UnitTestParser
class YamlReader(metaclass=ABCMeta):
def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, key: str) -> None:
self.schema_parser = schema_parser
# key: models, seeds, snapshots, sources, macros,
# analyses, exposures
# analyses, exposures, unit_tests
self.key = key
self.yaml = yaml
self.schema_yaml_vars = SchemaYamlVars()
@@ -304,7 +316,7 @@ class YamlReader(metaclass=ABCMeta):
if coerce_dict_str(entry) is None:
raise YamlParseListError(path, self.key, data, "expected a dict with string keys")
if "name" not in entry:
if "name" not in entry and "model" not in entry:
raise ParsingError("Entry did not contain a name or model")
# Render the data (except for tests and descriptions).

View File

@@ -0,0 +1,351 @@
from csv import DictReader
from pathlib import Path
from typing import List, Set, Dict, Any
import os
from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore
from dbt import utils
from dbt.config import RuntimeConfig
from dbt.context.context_config import ContextConfig
from dbt.context.providers import generate_parse_exposure, get_rendered
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.model_config import UnitTestNodeConfig, ModelConfig
from dbt.contracts.graph.nodes import (
ModelNode,
UnitTestNode,
UnitTestDefinition,
DependsOn,
UnitTestConfig,
UnitTestSourceDefinition,
)
from dbt.contracts.graph.unparsed import UnparsedUnitTest
from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput
from dbt.graph import UniqueId
from dbt.node_types import NodeType
from dbt.parser.schemas import (
SchemaParser,
YamlBlock,
ValidationError,
JSONValidationError,
YamlParseDictError,
YamlReader,
ParseResult,
)
from dbt.utils import get_pseudo_test_path
class UnitTestManifestLoader:
def __init__(self, manifest, root_project, selected) -> None:
self.manifest: Manifest = manifest
self.root_project: RuntimeConfig = root_project
# selected comes from the initial selection against a "regular" manifest
self.selected: Set[UniqueId] = selected
self.unit_test_manifest = Manifest(macros=manifest.macros)
def load(self) -> Manifest:
for unique_id in self.selected:
if unique_id in self.manifest.unit_tests:
unit_test_case = self.manifest.unit_tests[unique_id]
self.parse_unit_test_case(unit_test_case)
return self.unit_test_manifest
def parse_unit_test_case(self, test_case: UnitTestDefinition):
# Create unit test node based on the node being tested
tested_node = self.manifest.ref_lookup.perform_lookup(
f"model.{test_case.package_name}.{test_case.model}", self.manifest
)
assert isinstance(tested_node, ModelNode)
# Create UnitTestNode based on model being tested. Since selection has
# already been done, we don't have to care about fields that are necessary
# for selection.
# Note: no depends_on, that's added later using input nodes
name = f"{test_case.model}__{test_case.name}"
unit_test_node = UnitTestNode(
name=name,
resource_type=NodeType.Unit,
package_name=test_case.package_name,
path=get_pseudo_test_path(name, test_case.original_file_path),
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
config=UnitTestNodeConfig(
materialized="unit",
expected_rows=test_case.expect.get_rows(
self.root_project.project_root, self.root_project.fixture_paths
),
),
raw_code=tested_node.raw_code,
database=tested_node.database,
schema=tested_node.schema,
alias=name,
fqn=test_case.unique_id.split("."),
checksum=FileHash.empty(),
tested_node_unique_id=tested_node.unique_id,
overrides=test_case.overrides,
)
ctx = generate_parse_exposure(
unit_test_node, # type: ignore
self.root_project,
self.manifest,
test_case.package_name,
)
get_rendered(unit_test_node.raw_code, ctx, unit_test_node, capture_macros=True)
# unit_test_node now has a populated refs/sources
self.unit_test_manifest.nodes[unit_test_node.unique_id] = unit_test_node
# Now create input_nodes for the test inputs
"""
given:
- input: ref('my_model_a')
rows: []
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
"""
# Add the model "input" nodes, consisting of all referenced models in the unit test.
# This creates an ephemeral model for every input in every test, so there may be multiple
# input models substituting for the same input ref'd model. Note that since these are
# always "ephemeral" they just wrap the tested_node SQL in additional CTEs. No actual table
# or view is created.
for given in test_case.given:
# extract the original_input_node from the ref in the "input" key of the given list
original_input_node = self._get_original_input_node(given.input, tested_node)
project_root = self.root_project.project_root
common_fields = {
"resource_type": NodeType.Model,
"package_name": test_case.package_name,
"original_file_path": original_input_node.original_file_path,
"config": ModelConfig(materialized="ephemeral"),
"database": original_input_node.database,
"alias": original_input_node.identifier,
"schema": original_input_node.schema,
"fqn": original_input_node.fqn,
"checksum": FileHash.empty(),
"raw_code": self._build_fixture_raw_code(
given.get_rows(project_root, self.root_project.fixture_paths), None
),
}
if original_input_node.resource_type in (
NodeType.Model,
NodeType.Seed,
NodeType.Snapshot,
):
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_node = ModelNode(
**common_fields,
unique_id=f"model.{test_case.package_name}.{input_name}",
name=input_name,
path=original_input_node.path,
)
elif original_input_node.resource_type == NodeType.Source:
# We are reusing the database/schema/identifier from the original source,
# but that shouldn't matter since this acts as an ephemeral model which just
# wraps a CTE around the unit test node.
input_name = f"{unit_test_node.name}__{original_input_node.search_name}__{original_input_node.name}"
input_node = UnitTestSourceDefinition(
**common_fields,
unique_id=f"model.{test_case.package_name}.{input_name}",
name=original_input_node.name, # must be the same name for source lookup to work
path=input_name + ".sql", # for writing out compiled_code
source_name=original_input_node.source_name, # needed for source lookup
)
# Sources need to go in the sources dictionary in order to create the right lookup
self.unit_test_manifest.sources[input_node.unique_id] = input_node # type: ignore
# Both ModelNode and UnitTestSourceDefinition need to go in nodes dictionary
self.unit_test_manifest.nodes[input_node.unique_id] = input_node
# Populate this_input_node_unique_id if input fixture represents node being tested
if original_input_node == tested_node:
unit_test_node.this_input_node_unique_id = input_node.unique_id
# Add unique ids of input_nodes to depends_on
unit_test_node.depends_on.nodes.append(input_node.unique_id)
def _build_fixture_raw_code(self, rows, column_name_to_data_types) -> str:
# We're not currently using column_name_to_data_types, but leaving here for
# possible future use.
return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
rows=rows, column_name_to_data_types=column_name_to_data_types
)
def _get_original_input_node(self, input: str, tested_node: ModelNode):
"""
Returns the original input node as defined in the project given an input reference
and the node being tested.
input: str representing how input node is referenced in tested model sql
* examples:
- "ref('my_model_a')"
- "source('my_source_schema', 'my_source_name')"
- "this"
tested_node: ModelNode representing the node being tested
"""
if input.strip() == "this":
original_input_node = tested_node
else:
try:
statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}")
except ExtractionError:
raise InvalidUnitTestGivenInput(input=input)
if statically_parsed["refs"]:
ref = list(statically_parsed["refs"])[0]
name = ref.get("name")
package = ref.get("package")
version = ref.get("version")
# TODO: disabled lookup, versioned lookup, public models
original_input_node = self.manifest.ref_lookup.find(
name, package, version, self.manifest
)
elif statically_parsed["sources"]:
source = list(statically_parsed["sources"])[0]
input_source_name, input_name = source
original_input_node = self.manifest.source_lookup.find(
f"{input_source_name}.{input_name}",
None,
self.manifest,
)
else:
raise InvalidUnitTestGivenInput(input=input)
return original_input_node
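dbt uses `dbt_extractor.py_extract_from_source` for the static parsing above; as a much-simplified illustration, a regex stand-in that classifies only the plain `this`, `ref('name')`, and `source('a', 'b')` forms (not the real implementation, which also handles packages and versions) might look like:

```python
import re

def classify_input(input_str):
    # Toy stand-in for py_extract_from_source; handles only the simple forms
    s = input_str.strip()
    if s == "this":
        return ("this", None)
    m = re.fullmatch(r"ref\(\s*'([^']+)'\s*\)", s)
    if m:
        return ("ref", m.group(1))
    m = re.fullmatch(r"source\(\s*'([^']+)'\s*,\s*'([^']+)'\s*\)", s)
    if m:
        return ("source", (m.group(1), m.group(2)))
    raise ValueError(
        f"Unit test given inputs must be either a 'ref', 'source' or 'this' call. Got: '{input_str}'."
    )
```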
class UnitTestParser(YamlReader):
def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None:
super().__init__(schema_parser, yaml, "unit_tests")
self.schema_parser = schema_parser
self.yaml = yaml
def parse(self) -> ParseResult:
for data in self.get_key_dicts():
unit_test = self._get_unit_test(data)
tested_model_node = self._find_tested_model_node(unit_test)
unit_test_case_unique_id = (
f"{NodeType.Unit}.{self.project.project_name}.{unit_test.model}.{unit_test.name}"
)
unit_test_fqn = self._build_fqn(
self.project.project_name,
self.yaml.path.original_file_path,
unit_test.model,
unit_test.name,
)
unit_test_config = self._build_unit_test_config(unit_test_fqn, unit_test.config)
# Check that format and type of rows matches for each given input
for input in unit_test.given:
if input.rows is None and input.fixture is None:
input.rows = self._load_rows_from_seed(input.input)
input.validate_fixture("input", unit_test.name)
unit_test.expect.validate_fixture("expected", unit_test.name)
unit_test_definition = UnitTestDefinition(
name=unit_test.name,
model=unit_test.model,
resource_type=NodeType.Unit,
package_name=self.project.project_name,
path=self.yaml.path.relative_path,
original_file_path=self.yaml.path.original_file_path,
unique_id=unit_test_case_unique_id,
given=unit_test.given,
expect=unit_test.expect,
description=unit_test.description,
overrides=unit_test.overrides,
depends_on=DependsOn(nodes=[tested_model_node.unique_id]),
fqn=unit_test_fqn,
config=unit_test_config,
schema=tested_model_node.schema,
)
# for calculating state:modified
unit_test_definition.build_unit_test_checksum(
self.schema_parser.project.project_root, self.schema_parser.project.fixture_paths
)
self.manifest.add_unit_test(self.yaml.file, unit_test_definition)
return ParseResult()
def _get_unit_test(self, data: Dict[str, Any]) -> UnparsedUnitTest:
try:
UnparsedUnitTest.validate(data)
return UnparsedUnitTest.from_dict(data)

except (ValidationError, JSONValidationError) as exc:
raise YamlParseDictError(self.yaml.path, self.key, data, exc) from exc
def _find_tested_model_node(self, unit_test: UnparsedUnitTest) -> ModelNode:
package_name = self.project.project_name
model_name_split = unit_test.model.split()
model_name = model_name_split[0]
model_version = model_name_split[1] if len(model_name_split) == 2 else None
tested_node = self.manifest.ref_lookup.find(
model_name, package_name, model_version, self.manifest
)
if not tested_node:
raise ParsingError(
f"Unable to find model '{package_name}.{unit_test.model}' for unit tests in {self.yaml.path.original_file_path}"
)
return tested_node
def _build_unit_test_config(
self, unit_test_fqn: List[str], config_dict: Dict[str, Any]
) -> UnitTestConfig:
config = ContextConfig(
self.schema_parser.root_project,
unit_test_fqn,
NodeType.Unit,
self.schema_parser.project.project_name,
)
unit_test_config_dict = config.build_config_dict(patch_config_dict=config_dict)
unit_test_config_dict = self.render_entry(unit_test_config_dict)
return UnitTestConfig.from_dict(unit_test_config_dict)
def _build_fqn(self, package_name, original_file_path, model_name, test_name):
# This code comes from "get_fqn" and "get_fqn_prefix" in the base parser.
# We need to get the directories underneath the model-path.
path = Path(original_file_path)
relative_path = str(path.relative_to(*path.parts[:1]))
no_ext = os.path.splitext(relative_path)[0]
fqn = [package_name]
fqn.extend(utils.split_path(no_ext)[:-1])
fqn.append(model_name)
fqn.append(test_name)
return fqn
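`_build_fqn` strips the top-level path component (e.g. `models/`) and the file name, keeping the intermediate directories; a standalone sketch:

```python
import os
from pathlib import Path

def build_fqn(package_name, original_file_path, model_name, test_name):
    # Mirrors UnitTestParser._build_fqn: package + subdirectories + model + test
    path = Path(original_file_path)
    relative_path = str(path.relative_to(*path.parts[:1]))  # drop e.g. "models/"
    no_ext = os.path.splitext(relative_path)[0]
    dirs = no_ext.replace("\\", "/").split("/")[:-1]  # directories only, no file name
    return [package_name, *dirs, model_name, test_name]

fqn = build_fqn("my_project", "models/staging/schema.yml", "stg_orders", "valid_orders")
# fqn == ['my_project', 'staging', 'stg_orders', 'valid_orders']
```

The directory components matter because path-based selectors match against the fqn.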
def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
"""Read rows from seed file on disk if not specified in YAML config. If seed file doesn't exist, return empty list."""
ref = py_extract_from_source("{{ " + ref_str + " }}")["refs"][0]
rows: List[Dict[str, Any]] = []
seed_name = ref["name"]
package_name = ref.get("package", self.project.project_name)
seed_node = self.manifest.ref_lookup.find(seed_name, package_name, None, self.manifest)
if not seed_node or seed_node.resource_type != NodeType.Seed:
# Seed not found in custom package specified
if package_name != self.project.project_name:
raise ParsingError(
f"Unable to find seed '{package_name}.{seed_name}' for unit tests in '{package_name}' package"
)
else:
raise ParsingError(
f"Unable to find seed '{package_name}.{seed_name}' for unit tests in directories: {self.project.seed_paths}"
)
seed_path = Path(seed_node.root_path) / seed_node.original_file_path
with open(seed_path, "r") as f:
for row in DictReader(f):
rows.append(row)
return rows
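The seed-loading loop above is ordinary `csv.DictReader` iteration; a self-contained sketch with an inline CSV standing in for the seed file on disk (the data is illustrative, and note every value comes back as a string):

```python
import csv
import io

# Self-contained sketch of the DictReader loop above, with an inline CSV
# standing in for the seed file on disk.
seed_csv = "id,name\n1,apple\n2,banana\n"
rows = [row for row in csv.DictReader(io.StringIO(seed_csv))]
print(rows)  # [{'id': '1', 'name': 'apple'}, {'id': '2', 'name': 'banana'}]
```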


@@ -308,7 +308,7 @@ class BaseRunner(metaclass=ABCMeta):
with collect_timing_info("compile", ctx.timing.append):
# if we fail here, we still have a compiled node to return
# this has the benefit of showing a build path for the errant
# model. This calls the 'compile' method in CompileTask
ctx.node = self.compile(manifest)
# for ephemeral nodes, we only want to compile, not run


@@ -84,6 +84,7 @@ class BuildTask(RunTask):
NodeType.Snapshot: snapshot_model_runner,
NodeType.Seed: seed_runner,
NodeType.Test: test_runner,
NodeType.Unit: test_runner,
}
ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()})


@@ -122,6 +122,7 @@ class GraphRunnableTask(ConfiguredTask):
fire_event(DefaultSelector(name=default_selector_name))
spec = self.config.get_selector(default_selector_name)
else:
# This path is used when there is no default selector and no selection arguments
# use --select and --exclude args
spec = parse_difference(self.selection_arg, self.exclusion_arg, indirect_selection)
return spec
@@ -136,6 +137,7 @@ class GraphRunnableTask(ConfiguredTask):
def get_graph_queue(self) -> GraphQueue:
selector = self.get_node_selector()
# Following uses self.selection_arg and self.exclusion_arg
spec = self.get_selection_spec()
return selector.get_graph_queue(spec)
@@ -155,9 +157,11 @@ class GraphRunnableTask(ConfiguredTask):
self._flattened_nodes.append(self.manifest.sources[uid])
elif uid in self.manifest.saved_queries:
self._flattened_nodes.append(self.manifest.saved_queries[uid])
elif uid in self.manifest.unit_tests:
self._flattened_nodes.append(self.manifest.unit_tests[uid])
else:
raise DbtInternalError(
f"Node selection returned {uid}, expected a node or a source"
f"Node selection returned {uid}, expected a node, a source, or a unit test"
)
self.num_nodes = len([n for n in self._flattened_nodes if not n.is_ephemeral_model])
@@ -206,6 +210,8 @@ class GraphRunnableTask(ConfiguredTask):
status: Dict[str, str] = {}
try:
result = runner.run_with_hooks(self.manifest)
except Exception as exc:
raise DbtInternalError(f"Unable to execute node: {exc}")
finally:
finishctx = TimestampNamed("finished_at")
with finishctx, DbtModelState(status):


@@ -1,22 +1,24 @@
from distutils.util import strtobool
import agate
import daff
import re
from dataclasses import dataclass
from dbt.utils import _coerce_decimal
from dbt.events.format import pluralize
from dbt.dataclass_schema import dbtClassMixin
import threading
from typing import Dict, Any, Optional, Union, List
from .compile import CompileRunner
from .run import RunTask
from dbt.contracts.graph.nodes import TestNode, UnitTestDefinition
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import TestStatus, PrimitiveDict, RunResult
from dbt.context.providers import generate_runtime_model_context
from dbt.clients.jinja import MacroGenerator
from dbt.clients.agate_helper import list_rows_from_table, json_rows_from_table
from dbt.events.functions import fire_event
from dbt.events.types import (
LogTestResult,
@@ -31,7 +33,16 @@ from dbt.graph import (
ResourceTypeSelector,
)
from dbt.node_types import NodeType
from dbt.parser.unit_tests import UnitTestManifestLoader
from dbt.flags import get_flags
from dbt.ui import green, red
@dataclass
class UnitTestDiff(dbtClassMixin):
actual: List[Dict[str, Any]]
expected: List[Dict[str, Any]]
rendered: str
@dataclass
@@ -59,10 +70,18 @@ class TestResultData(dbtClassMixin):
return bool(field)
@dataclass
class UnitTestResultData(dbtClassMixin):
should_error: bool
adapter_response: Dict[str, Any]
diff: Optional[UnitTestDiff] = None
class TestRunner(CompileRunner):
_ANSI_ESCAPE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
def describe_node(self):
return f"{self.node.resource_type} {self.node.name}"
def print_result_line(self, result):
model = result.node
@@ -143,9 +162,87 @@ class TestRunner(CompileRunner):
TestResultData.validate(test_result_dct)
return TestResultData.from_dict(test_result_dct)
def build_unit_test_manifest_from_test(
self, unit_test_def: UnitTestDefinition, manifest: Manifest
) -> Manifest:
# build a unit test manifest with only the test from this UnitTestDefinition
loader = UnitTestManifestLoader(manifest, self.config, {unit_test_def.unique_id})
return loader.load()
def execute_unit_test(
self, unit_test_def: UnitTestDefinition, manifest: Manifest
) -> UnitTestResultData:
unit_test_manifest = self.build_unit_test_manifest_from_test(unit_test_def, manifest)
# The unit test node and definition have the same unique_id
unit_test_node = unit_test_manifest.nodes[unit_test_def.unique_id]
# Compile the node
compiler = self.adapter.get_compiler()
unit_test_node = compiler.compile_node(unit_test_node, unit_test_manifest, {})
# A dedicated generate_runtime_unit_test_context is not strictly needed here -
# this context is used to run the 'unit' materialization, not to compile node.compiled_code
context = generate_runtime_model_context(unit_test_node, self.config, unit_test_manifest)
materialization_macro = unit_test_manifest.find_materialization_macro_by_name(
self.config.project_name, unit_test_node.get_materialization(), self.adapter.type()
)
if materialization_macro is None:
raise MissingMaterializationError(
materialization=unit_test_node.get_materialization(),
adapter_type=self.adapter.type(),
)
if "config" not in context:
raise DbtInternalError(
"Invalid materialization context generated, missing config: {}".format(context)
)
# generate materialization macro
macro_func = MacroGenerator(materialization_macro, context)
# execute materialization macro
macro_func()
# load results from context
# could eventually be returned directly by materialization
result = context["load_result"]("main")
adapter_response = result["response"].to_dict(omit_none=True)
table = result["table"]
actual = self._get_unit_test_agate_table(table, "actual")
expected = self._get_unit_test_agate_table(table, "expected")
# generate a diff, if one exists
should_error, diff = False, None
daff_diff = self._get_daff_diff(expected, actual)
if daff_diff.hasDifference():
should_error = True
rendered = self._render_daff_diff(daff_diff)
rendered = f"\n\n{red('expected')} differs from {green('actual')}:\n\n{rendered}\n"
diff = UnitTestDiff(
actual=json_rows_from_table(actual),
expected=json_rows_from_table(expected),
rendered=rendered,
)
return UnitTestResultData(
diff=diff,
should_error=should_error,
adapter_response=adapter_response,
)
def execute(self, test: Union[TestNode, UnitTestDefinition], manifest: Manifest):
if isinstance(test, UnitTestDefinition):
unit_test_result = self.execute_unit_test(test, manifest)
return self.build_unit_test_run_result(test, unit_test_result)
else:
# Note: manifest here is a normal manifest
test_result = self.execute_test(test, manifest)
return self.build_test_run_result(test, test_result)
def build_test_run_result(self, test: TestNode, result: TestResultData) -> RunResult:
severity = test.config.severity.upper()
thread_id = threading.current_thread().name
num_errors = pluralize(result.failures, "result")
@@ -167,6 +264,31 @@ class TestRunner(CompileRunner):
else:
status = TestStatus.Pass
run_result = RunResult(
node=test,
status=status,
timing=[],
thread_id=thread_id,
execution_time=0,
message=message,
adapter_response=result.adapter_response,
failures=failures,
)
return run_result
def build_unit_test_run_result(
self, test: UnitTestDefinition, result: UnitTestResultData
) -> RunResult:
thread_id = threading.current_thread().name
status = TestStatus.Pass
message = None
failures = 0
if result.should_error:
status = TestStatus.Fail
message = result.diff.rendered if result.diff else None
failures = 1
return RunResult(
node=test,
status=status,
@@ -181,6 +303,41 @@ class TestRunner(CompileRunner):
def after_execute(self, result):
self.print_result_line(result)
def _get_unit_test_agate_table(self, result_table, actual_or_expected: str):
unit_test_table = result_table.where(
lambda row: row["actual_or_expected"] == actual_or_expected
)
columns = list(unit_test_table.columns.keys())
columns.remove("actual_or_expected")
return unit_test_table.select(columns)
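`_get_unit_test_agate_table` keeps the rows tagged with the requested label and then drops the tag column itself; an agate-free sketch of the same logic over plain dicts (the data is illustrative):

```python
# agate-free sketch of _get_unit_test_agate_table: keep the rows tagged with
# the requested label, then drop the tag column itself.
def split_result_rows(rows, actual_or_expected):
    return [
        {k: v for k, v in row.items() if k != "actual_or_expected"}
        for row in rows
        if row["actual_or_expected"] == actual_or_expected
    ]

combined = [
    {"c": 2, "actual_or_expected": "actual"},
    {"c": 3, "actual_or_expected": "expected"},
]
print(split_result_rows(combined, "actual"))    # [{'c': 2}]
print(split_result_rows(combined, "expected"))  # [{'c': 3}]
```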
def _get_daff_diff(
self, expected: agate.Table, actual: agate.Table, ordered: bool = False
) -> daff.TableDiff:
expected_daff_table = daff.PythonTableView(list_rows_from_table(expected))
actual_daff_table = daff.PythonTableView(list_rows_from_table(actual))
alignment = daff.Coopy.compareTables(expected_daff_table, actual_daff_table).align()
result = daff.PythonTableView([])
flags = daff.CompareFlags()
flags.ordered = ordered
diff = daff.TableDiff(alignment, flags)
diff.hilite(result)
return diff
def _render_daff_diff(self, daff_diff: daff.TableDiff) -> str:
result = daff.PythonTableView([])
daff_diff.hilite(result)
rendered = daff.TerminalDiffRender().render(result)
# strip colors if necessary
if not self.config.args.use_colors:
rendered = self._ANSI_ESCAPE.sub("", rendered)
return rendered
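The `_ANSI_ESCAPE` pattern above strips terminal color codes when `--use-colors` is off; a quick stdlib demonstration of the same regex on a color-coded string:

```python
import re

# The same ANSI escape pattern as _ANSI_ESCAPE above, applied to a
# color-coded string such as the rendered daff diff.
ANSI_ESCAPE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
colored = "\x1b[31mexpected\x1b[0m differs"
print(ANSI_ESCAPE.sub("", colored))  # expected differs
```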
class TestSelector(ResourceTypeSelector):
def __init__(self, graph, manifest, previous_state) -> None:
@@ -188,7 +345,7 @@ class TestSelector(ResourceTypeSelector):
graph=graph,
manifest=manifest,
previous_state=previous_state,
resource_types=[NodeType.Test, NodeType.Unit],
)


@@ -81,6 +81,7 @@ setup(
"protobuf>=4.0.0",
"pytz>=2015.7",
"pyyaml>=6.0",
"daff>=1.3.46",
"typing-extensions>=4.4",
# ----
# Match snowflake-connector-python, to ensure compatibility in dbt-snowflake


@@ -2,6 +2,19 @@ from dbt.adapters.base import Column
class PostgresColumn(Column):
TYPE_LABELS = {
"STRING": "TEXT",
"DATETIME": "TIMESTAMP",
"DATETIMETZ": "TIMESTAMPTZ",
"STRINGARRAY": "TEXT[]",
"INTEGERARRAY": "INT[]",
"DECIMALARRAY": "DECIMAL[]",
"BOOLEANARRAY": "BOOL[]",
"DATEARRAY": "DATE[]",
"DATETIMEARRAY": "TIMESTAMP[]",
"DATETIMETZARRAY": "TIMESTAMPTZ[]",
}
@property
def data_type(self):
# on postgres, do not convert 'text' or 'varchar' to 'varchar()'
@@ -9,4 +22,5 @@ class PostgresColumn(Column):
self.dtype.lower() == "character varying" and self.char_size is None
):
return self.dtype
return super().data_type


@@ -3401,6 +3401,517 @@
"config"
]
},
"UnitTestNodeConfig": {
"type": "object",
"title": "UnitTestNodeConfig",
"properties": {
"_extra": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"enabled": {
"type": "boolean",
"default": true
},
"alias": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"schema": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"database": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"tags": {
"anyOf": [
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "string"
}
]
},
"meta": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"group": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"materialized": {
"type": "string",
"default": "view"
},
"incremental_strategy": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"persist_docs": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"post-hook": {
"type": "array",
"items": {
"$ref": "#/$defs/Hook"
}
},
"pre-hook": {
"type": "array",
"items": {
"$ref": "#/$defs/Hook"
}
},
"quoting": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"column_types": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"full_refresh": {
"anyOf": [
{
"type": "boolean"
},
{
"type": "null"
}
],
"default": null
},
"unique_key": {
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "null"
}
],
"default": null
},
"on_schema_change": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": "ignore"
},
"on_configuration_change": {
"enum": [
"apply",
"continue",
"fail"
]
},
"grants": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"packages": {
"type": "array",
"items": {
"type": "string"
}
},
"docs": {
"$ref": "#/$defs/Docs"
},
"contract": {
"$ref": "#/$defs/ContractConfig"
},
"expected_rows": {
"type": "array",
"items": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
},
"additionalProperties": true
},
"UnitTestOverrides": {
"type": "object",
"title": "UnitTestOverrides",
"properties": {
"macros": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"vars": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"env_vars": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
},
"additionalProperties": false
},
"UnitTestNode": {
"type": "object",
"title": "UnitTestNode",
"properties": {
"database": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
]
},
"schema": {
"type": "string"
},
"name": {
"type": "string"
},
"resource_type": {
"enum": [
"model",
"analysis",
"test",
"snapshot",
"operation",
"seed",
"rpc",
"sql_operation",
"doc",
"source",
"macro",
"exposure",
"metric",
"group",
"saved_query",
"semantic_model",
"unit_test"
]
},
"package_name": {
"type": "string"
},
"path": {
"type": "string"
},
"original_file_path": {
"type": "string"
},
"unique_id": {
"type": "string"
},
"fqn": {
"type": "array",
"items": {
"type": "string"
}
},
"alias": {
"type": "string"
},
"checksum": {
"$ref": "#/$defs/FileHash"
},
"config": {
"$ref": "#/$defs/UnitTestNodeConfig"
},
"_event_status": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"tags": {
"type": "array",
"items": {
"type": "string"
}
},
"description": {
"type": "string",
"default": ""
},
"columns": {
"type": "object",
"additionalProperties": {
"$ref": "#/$defs/ColumnInfo"
},
"propertyNames": {
"type": "string"
}
},
"meta": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"group": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"docs": {
"$ref": "#/$defs/Docs"
},
"patch_path": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"build_path": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"deferred": {
"type": "boolean",
"default": false
},
"unrendered_config": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"created_at": {
"type": "number"
},
"config_call_dict": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"relation_name": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"raw_code": {
"type": "string",
"default": ""
},
"language": {
"type": "string",
"default": "sql"
},
"refs": {
"type": "array",
"items": {
"$ref": "#/$defs/RefArgs"
}
},
"sources": {
"type": "array",
"items": {
"type": "array",
"items": {
"type": "string"
}
}
},
"metrics": {
"type": "array",
"items": {
"type": "array",
"items": {
"type": "string"
}
}
},
"depends_on": {
"$ref": "#/$defs/DependsOn"
},
"compiled_path": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"compiled": {
"type": "boolean",
"default": false
},
"compiled_code": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"extra_ctes_injected": {
"type": "boolean",
"default": false
},
"extra_ctes": {
"type": "array",
"items": {
"$ref": "#/$defs/InjectedCTE"
}
},
"_pre_injected_sql": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"contract": {
"$ref": "#/$defs/Contract"
},
"attached_node": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
},
"overrides": {
"anyOf": [
{
"$ref": "#/$defs/UnitTestOverrides"
},
{
"type": "null"
}
],
"default": null
}
},
"additionalProperties": false,
"required": [
"database",
"schema",
"name",
"resource_type",
"package_name",
"path",
"original_file_path",
"unique_id",
"fqn",
"alias",
"checksum"
]
},
"SeedConfig": {
"type": "object",
"title": "SeedConfig",
@@ -5251,7 +5762,8 @@
"metric",
"group",
"saved_query",
"semantic_model"
"semantic_model",
"unit_test"
]
},
"package_name": {
@@ -5822,7 +6334,8 @@
"metric",
"group",
"saved_query",
"semantic_model"
"semantic_model",
"unit_test"
]
},
"package_name": {
@@ -5975,6 +6488,200 @@
"node_relation"
]
},
"UnitTestInputFixture": {
"type": "object",
"title": "UnitTestInputFixture",
"properties": {
"input": {
"type": "string"
},
"rows": {
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
],
"default": ""
},
"format": {
"enum": [
"csv",
"dict"
],
"default": "dict"
}
},
"additionalProperties": false,
"required": [
"input"
]
},
"UnitTestOutputFixture": {
"type": "object",
"title": "UnitTestOutputFixture",
"properties": {
"rows": {
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
],
"default": ""
},
"format": {
"enum": [
"csv",
"dict"
],
"default": "dict"
}
},
"additionalProperties": false
},
"UnitTestConfig": {
"type": "object",
"title": "UnitTestConfig",
"properties": {
"_extra": {
"type": "object",
"propertyNames": {
"type": "string"
}
},
"tags": {
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "string"
}
}
]
},
"meta": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
},
"additionalProperties": true
},
"UnitTestDefinition": {
"type": "object",
"title": "UnitTestDefinition",
"properties": {
"name": {
"type": "string"
},
"resource_type": {
"enum": [
"model",
"analysis",
"test",
"snapshot",
"operation",
"seed",
"rpc",
"sql_operation",
"doc",
"source",
"macro",
"exposure",
"metric",
"group",
"saved_query",
"semantic_model",
"unit_test"
]
},
"package_name": {
"type": "string"
},
"path": {
"type": "string"
},
"original_file_path": {
"type": "string"
},
"unique_id": {
"type": "string"
},
"fqn": {
"type": "array",
"items": {
"type": "string"
}
},
"model": {
"type": "string"
},
"given": {
"type": "array",
"items": {
"$ref": "#/$defs/UnitTestInputFixture"
}
},
"expect": {
"$ref": "#/$defs/UnitTestOutputFixture"
},
"description": {
"type": "string",
"default": ""
},
"overrides": {
"anyOf": [
{
"$ref": "#/$defs/UnitTestOverrides"
},
{
"type": "null"
}
],
"default": null
},
"depends_on": {
"$ref": "#/$defs/DependsOn"
},
"config": {
"$ref": "#/$defs/UnitTestConfig"
}
},
"additionalProperties": false,
"required": [
"name",
"resource_type",
"package_name",
"path",
"original_file_path",
"unique_id",
"fqn",
"model",
"given",
"expect"
]
},
"WritableManifest": {
"type": "object",
"title": "WritableManifest",
@@ -6012,6 +6719,9 @@
{
"$ref": "#/$defs/SnapshotNode"
},
{
"$ref": "#/$defs/UnitTestNode"
},
{
"$ref": "#/$defs/SeedNode"
}
@@ -6121,6 +6831,9 @@
{
"$ref": "#/$defs/SnapshotNode"
},
{
"$ref": "#/$defs/UnitTestNode"
},
{
"$ref": "#/$defs/SeedNode"
},
@@ -6138,6 +6851,9 @@
},
{
"$ref": "#/$defs/SemanticModel"
},
{
"$ref": "#/$defs/UnitTestDefinition"
}
]
}
@@ -6230,6 +6946,16 @@
"propertyNames": {
"type": "string"
}
},
"unit_tests": {
"type": "object",
"description": "The unit tests defined in the project",
"additionalProperties": {
"$ref": "#/$defs/UnitTestDefinition"
},
"propertyNames": {
"type": "string"
}
}
},
"additionalProperties": false,
@@ -6248,7 +6974,8 @@
"child_map",
"group_map",
"saved_queries",
"semantic_models"
"semantic_models",
"unit_tests"
]
}
},


@@ -69,10 +69,10 @@ class BaseConstraintsColumnsEqual:
["1", schema_int_type, int_type],
["'1'", string_type, string_type],
["true", "bool", "BOOL"],
["'2013-11-03 00:00:00-07'::timestamptz", "timestamptz", "TIMESTAMPTZ"],
["'2013-11-03 00:00:00-07'::timestamp", "timestamp", "TIMESTAMP"],
["ARRAY['a','b','c']", "text[]", "TEXT[]"],
["ARRAY[1,2,3]", "int[]", "INT[]"],
["'1'::numeric", "numeric", "DECIMAL"],
["""'{"bar": "baz", "balance": 7.77, "active": false}'::json""", "json", "JSON"],
]


@@ -0,0 +1,93 @@
import pytest
from dbt.tests.util import write_file, run_dbt
my_model_sql = """
select
tested_column from {{ ref('my_upstream_model')}}
"""
my_upstream_model_sql = """
select
{sql_value} as tested_column,
{sql_value} as untested_column
"""
test_my_model_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_upstream_model')
rows:
- {{ tested_column: {yaml_value} }}
expect:
rows:
- {{ tested_column: {yaml_value} }}
"""
class BaseUnitTestingTypes:
@pytest.fixture
def data_types(self):
# sql_value, yaml_value
return [
["1", "1"],
["1.0", "1.0"],
["'1'", "1"],
["'1'::numeric", "1"],
["'string'", "string"],
["true", "true"],
["DATE '2020-01-02'", "2020-01-02"],
["TIMESTAMP '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
["TIMESTAMPTZ '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
["ARRAY[1,2,3]", """'{1, 2, 3}'"""],
["ARRAY[1.0,2.0,3.0]", """'{1.0, 2.0, 3.0}'"""],
["ARRAY[1::numeric,2::numeric,3::numeric]", """'{1.0, 2.0, 3.0}'"""],
["ARRAY['a','b','c']", """'{"a", "b", "c"}'"""],
["ARRAY[true,true,false]", """'{true, true, false}'"""],
["ARRAY[DATE '2020-01-02']", """'{"2020-01-02"}'"""],
["ARRAY[TIMESTAMP '2013-11-03 00:00:00-0']", """'{"2013-11-03 00:00:00-0"}'"""],
["ARRAY[TIMESTAMPTZ '2013-11-03 00:00:00-0']", """'{"2013-11-03 00:00:00-0"}'"""],
[
"""'{"bar": "baz", "balance": 7.77, "active": false}'::json""",
"""'{"bar": "baz", "balance": 7.77, "active": false}'""",
],
]
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_upstream_model.sql": my_upstream_model_sql,
"schema.yml": test_my_model_yml,
}
def test_unit_test_data_type(self, project, data_types):
for (sql_value, yaml_value) in data_types:
# Write parametrized type value to sql files
write_file(
my_upstream_model_sql.format(sql_value=sql_value),
"models",
"my_upstream_model.sql",
)
# Write parametrized type value to unit test yaml definition
write_file(
test_my_model_yml.format(yaml_value=yaml_value),
"models",
"schema.yml",
)
results = run_dbt(["run", "--select", "my_upstream_model"])
assert len(results) == 1
try:
run_dbt(["test", "--select", "my_model"])
except Exception:
raise AssertionError(f"unit test failed when testing model with {sql_value}")
class TestUnitTestingTypes(BaseUnitTestingTypes):
pass

File diff suppressed because one or more lines are too long


@@ -890,6 +890,7 @@ def expected_seeded_manifest(project, model_database=None, quote_model=False):
},
"disabled": {},
"semantic_models": {},
"unit_tests": {},
"saved_queries": {},
}
@@ -1450,6 +1451,7 @@ def expected_references_manifest(project):
}
},
"semantic_models": {},
"unit_tests": {},
"saved_queries": {},
}
@@ -1930,5 +1932,6 @@ def expected_versions_manifest(project):
"disabled": {},
"macros": {},
"semantic_models": {},
"unit_tests": {},
"saved_queries": {},
}


@@ -469,6 +469,7 @@ def verify_manifest(project, expected_manifest, start_time, manifest_schema_path
"exposures",
"selectors",
"semantic_models",
"unit_tests",
"saved_queries",
}


@@ -137,6 +137,19 @@ models:
- not_null
"""
unit_tests__yml = """
unit_tests:
- name: ut_model_3
model: model_3
given:
- input: ref('model_1')
rows:
- {iso3: ABW, name: Aruba}
expect:
rows:
- {iso3: ABW, name: Aruba}
"""
models_failing_tests__tests_yml = """
version: 2


@@ -1,7 +1,7 @@
import pytest
from dbt.tests.util import run_dbt
from tests.functional.build_command.fixtures import (
seeds__country_csv,
snapshots__snap_0,
snapshots__snap_1,
@@ -24,6 +24,7 @@ from tests.functional.build.fixtures import (
models_interdependent__model_b_sql,
models_interdependent__model_b_null_sql,
models_interdependent__model_c_sql,
unit_tests__yml,
)
@@ -56,8 +57,9 @@ class TestPassingBuild(TestBuildBase):
"model_0.sql": models__model_0_sql,
"model_1.sql": models__model_1_sql,
"model_2.sql": models__model_2_sql,
"model_3.sql": models__model_3_sql,
"model_99.sql": models__model_99_sql,
"test.yml": models__test_yml,
"test.yml": models__test_yml + unit_tests__yml,
}
def test_build_happy_path(self, project):
@@ -73,14 +75,14 @@ class TestFailingBuild(TestBuildBase):
"model_2.sql": models__model_2_sql,
"model_3.sql": models__model_3_sql,
"model_99.sql": models__model_99_sql,
"test.yml": models__test_yml,
"test.yml": models__test_yml + unit_tests__yml,
}
def test_failing_test_skips_downstream(self, project):
results = run_dbt(["build"], expect_pass=False)
assert len(results) == 14
actual = [str(r.status) for r in results]
expected = ["error"] * 1 + ["skipped"] * 5 + ["pass"] * 2 + ["success"] * 5
expected = ["error"] * 1 + ["skipped"] * 6 + ["pass"] * 2 + ["success"] * 5
assert sorted(actual) == sorted(expected)
@@ -210,7 +212,9 @@ class TestDownstreamSelection:
def test_downstream_selection(self, project):
"""Ensure that selecting test+ does not select model_a's other children"""
results = run_dbt(["build", "--select", "model_a not_null_model_a_id+"], expect_pass=True)
# fails with "Got 1 result, configured to fail if != 0"
# model_a is defined as select null as id
results = run_dbt(["build", "--select", "model_a not_null_model_a_id+"], expect_pass=False)
assert len(results) == 2
@@ -226,5 +230,6 @@ class TestLimitedUpstreamSelection:
def test_limited_upstream_selection(self, project):
"""Ensure that selecting 1+model_c only selects up to model_b (+ tests of both)"""
results = run_dbt(["build", "--select", "1+model_c"], expect_pass=True)
# Fails with "relation "test17005969872609282880_test_build.model_a" does not exist"
results = run_dbt(["build", "--select", "1+model_c"], expect_pass=False)
assert len(results) == 4


@@ -157,7 +157,6 @@ class TestGraphSelection(SelectionFixtures):
]
# ["list", "--project-dir", str(project.project_root), "--select", "models/test/subdir*"]
)
print(f"--- results: {results}")
assert len(results) == 1
def test_locally_qualified_name_model_with_dots(self, project):


@@ -0,0 +1,600 @@
my_model_vars_sql = """
SELECT
a+b as c,
concat(string_a, string_b) as string_c,
not_testing, date_a,
{{ dbt.string_literal(type_numeric()) }} as macro_call,
{{ dbt.string_literal(var('my_test')) }} as var_call,
{{ dbt.string_literal(env_var('MY_TEST', 'default')) }} as env_var_call,
{{ dbt.string_literal(invocation_id) }} as invocation_id
FROM {{ ref('my_model_a')}} my_model_a
JOIN {{ ref('my_model_b' )}} my_model_b
ON my_model_a.id = my_model_b.id
"""
my_model_sql = """
SELECT
a+b as c,
concat(string_a, string_b) as string_c,
not_testing, date_a
FROM {{ ref('my_model_a')}} my_model_a
JOIN {{ ref('my_model_b' )}} my_model_b
ON my_model_a.id = my_model_b.id
"""
my_model_a_sql = """
SELECT
1 as a,
1 as id,
2 as not_testing,
'a' as string_a,
DATE '2020-01-02' as date_a
"""
my_model_b_sql = """
SELECT
2 as b,
1 as id,
2 as c,
'b' as string_b
"""
test_my_model_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, a: 1}
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
expect:
rows:
- {c: 2}
- name: test_my_model_empty
model: my_model
given:
- input: ref('my_model_a')
rows: []
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
expect:
rows: []
- name: test_my_model_overrides
model: my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, a: 1}
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
overrides:
macros:
type_numeric: override
invocation_id: 123
vars:
my_test: var_override
env_vars:
MY_TEST: env_var_override
expect:
rows:
- {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123}
- name: test_my_model_string_concat
model: my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, string_a: a}
- input: ref('my_model_b')
rows:
- {id: 1, string_b: b}
expect:
rows:
- {string_c: ab}
config:
tags: test_this
"""
test_my_model_simple_fixture_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, a: 1}
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
expect:
rows:
- {c: 2}
- name: test_depends_on_fixture
model: my_model
given:
- input: ref('my_model_a')
rows: []
- input: ref('my_model_b')
format: csv
fixture: test_my_model_fixture
expect:
rows: []
- name: test_my_model_overrides
model: my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, a: 1}
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
overrides:
macros:
type_numeric: override
invocation_id: 123
vars:
my_test: var_override
env_vars:
MY_TEST: env_var_override
expect:
rows:
- {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123}
- name: test_has_string_c_ab
model: my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, string_a: a}
- input: ref('my_model_b')
rows:
- {id: 1, string_b: b}
expect:
rows:
- {string_c: ab}
config:
tags: test_this
"""
datetime_test = """
- name: test_my_model_datetime
model: my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, date_a: "2020-01-01"}
- input: ref('my_model_b')
rows:
- {id: 1}
expect:
rows:
- {date_a: "2020-01-01"}
"""
event_sql = """
select DATE '2020-01-01' as event_time, 1 as event
union all
select DATE '2020-01-02' as event_time, 2 as event
union all
select DATE '2020-01-03' as event_time, 3 as event
"""
datetime_test_invalid_format_key = """
- name: test_my_model_datetime
model: my_model
given:
- input: ref('my_model_a')
format: xxxx
rows:
- {id: 1, date_a: "2020-01-01"}
- input: ref('my_model_b')
rows:
- {id: 1}
expect:
rows:
- {date_a: "2020-01-01"}
"""
datetime_test_invalid_csv_values = """
- name: test_my_model_datetime
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows:
- {id: 1, date_a: "2020-01-01"}
- input: ref('my_model_b')
rows:
- {id: 1}
expect:
rows:
- {date_a: "2020-01-01"}
"""
datetime_test_invalid_csv_file_values = """
- name: test_my_model_datetime
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows:
- {id: 1, date_a: "2020-01-01"}
- input: ref('my_model_b')
rows:
- {id: 1}
expect:
rows:
- {date_a: "2020-01-01"}
"""
event_sql = """
select DATE '2020-01-01' as event_time, 1 as event
union all
select DATE '2020-01-02' as event_time, 2 as event
union all
select DATE '2020-01-03' as event_time, 3 as event
"""
my_incremental_model_sql = """
{{
config(
materialized='incremental'
)
}}
select * from {{ ref('events') }}
{% if is_incremental() %}
where event_time > (select max(event_time) from {{ this }})
{% endif %}
"""
test_my_model_incremental_yml = """
unit_tests:
- name: incremental_false
model: my_incremental_model
overrides:
macros:
is_incremental: false
given:
- input: ref('events')
rows:
- {event_time: "2020-01-01", event: 1}
expect:
rows:
- {event_time: "2020-01-01", event: 1}
- name: incremental_true
model: my_incremental_model
overrides:
macros:
is_incremental: true
given:
- input: ref('events')
rows:
- {event_time: "2020-01-01", event: 1}
- {event_time: "2020-01-02", event: 2}
- {event_time: "2020-01-03", event: 3}
- input: this
rows:
- {event_time: "2020-01-01", event: 1}
expect:
rows:
- {event_time: "2020-01-02", event: 2}
- {event_time: "2020-01-03", event: 3}
"""
# -- inline csv tests
test_my_model_csv_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows: |
id,a
1,1
- input: ref('my_model_b')
format: csv
rows: |
id,b
1,2
2,2
expect:
format: csv
rows: |
c
2
- name: test_my_model_empty
model: my_model
given:
- input: ref('my_model_a')
rows: []
- input: ref('my_model_b')
format: csv
rows: |
id,b
1,2
2,2
expect:
rows: []
- name: test_my_model_overrides
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows: |
id,a
1,1
- input: ref('my_model_b')
format: csv
rows: |
id,b
1,2
2,2
overrides:
macros:
type_numeric: override
invocation_id: 123
vars:
my_test: var_override
env_vars:
MY_TEST: env_var_override
expect:
rows:
- {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123}
- name: test_my_model_string_concat
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows: |
id,string_a
1,a
- input: ref('my_model_b')
format: csv
rows: |
id,string_b
1,b
expect:
format: csv
rows: |
string_c
ab
config:
tags: test_this
"""
# -- csv file tests
test_my_model_file_csv_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_model_a')
format: csv
fixture: test_my_model_a_numeric_fixture
- input: ref('my_model_b')
format: csv
fixture: test_my_model_fixture
expect:
format: csv
fixture: test_my_model_basic_fixture
- name: test_my_model_empty
model: my_model
given:
- input: ref('my_model_a')
format: csv
fixture: test_my_model_a_empty_fixture
- input: ref('my_model_b')
format: csv
fixture: test_my_model_fixture
expect:
format: csv
fixture: test_my_model_a_empty_fixture
- name: test_my_model_overrides
model: my_model
given:
- input: ref('my_model_a')
format: csv
fixture: test_my_model_a_numeric_fixture
- input: ref('my_model_b')
format: csv
fixture: test_my_model_fixture
overrides:
macros:
type_numeric: override
invocation_id: 123
vars:
my_test: var_override
env_vars:
MY_TEST: env_var_override
expect:
rows:
- {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123}
- name: test_my_model_string_concat
model: my_model
given:
- input: ref('my_model_a')
format: csv
fixture: test_my_model_a_fixture
- input: ref('my_model_b')
format: csv
fixture: test_my_model_b_fixture
expect:
format: csv
fixture: test_my_model_concat_fixture
config:
tags: test_this
"""
test_my_model_fixture_csv = """id,b
1,2
2,2
"""
test_my_model_a_fixture_csv = """id,string_a
1,a
"""
test_my_model_a_empty_fixture_csv = """
"""
test_my_model_a_numeric_fixture_csv = """id,a
1,1
"""
test_my_model_b_fixture_csv = """id,string_b
1,b
"""
test_my_model_basic_fixture_csv = """c
2
"""
test_my_model_concat_fixture_csv = """string_c
ab
"""
# -- mixed inline and file csv
test_my_model_mixed_csv_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows: |
id,a
1,1
- input: ref('my_model_b')
format: csv
rows: |
id,b
1,2
2,2
expect:
format: csv
fixture: test_my_model_basic_fixture
- name: test_my_model_empty
model: my_model
given:
- input: ref('my_model_a')
format: csv
fixture: test_my_model_a_empty_fixture
- input: ref('my_model_b')
format: csv
rows: |
id,b
1,2
2,2
expect:
format: csv
fixture: test_my_model_a_empty_fixture
- name: test_my_model_overrides
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows: |
id,a
1,1
- input: ref('my_model_b')
format: csv
fixture: test_my_model_fixture
overrides:
macros:
type_numeric: override
invocation_id: 123
vars:
my_test: var_override
env_vars:
MY_TEST: env_var_override
expect:
rows:
- {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123}
- name: test_my_model_string_concat
model: my_model
given:
- input: ref('my_model_a')
format: csv
fixture: test_my_model_a_fixture
- input: ref('my_model_b')
format: csv
fixture: test_my_model_b_fixture
expect:
format: csv
rows: |
string_c
ab
config:
tags: test_this
"""
# unit tests with errors
# -- fixture file doesn't exist
test_my_model_missing_csv_yml = """
unit_tests:
- name: test_missing_csv_file
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows: |
id,a
1,1
- input: ref('my_model_b')
format: csv
rows: |
id,b
1,2
2,2
expect:
format: csv
fixture: fake_fixture
"""
test_my_model_duplicate_csv_yml = """
unit_tests:
- name: test_missing_csv_file
model: my_model
given:
- input: ref('my_model_a')
format: csv
rows: |
id,a
1,1
- input: ref('my_model_b')
format: csv
rows: |
id,b
1,2
2,2
expect:
format: csv
fixture: test_my_model_basic_fixture
"""


@@ -0,0 +1,221 @@
import pytest
from dbt.exceptions import ParsingError, YamlParseDictError
from dbt.tests.util import run_dbt, write_file
from fixtures import (
my_model_sql,
my_model_a_sql,
my_model_b_sql,
test_my_model_csv_yml,
datetime_test,
datetime_test_invalid_format_key,
datetime_test_invalid_csv_values,
test_my_model_file_csv_yml,
test_my_model_fixture_csv,
test_my_model_a_fixture_csv,
test_my_model_b_fixture_csv,
test_my_model_basic_fixture_csv,
test_my_model_a_numeric_fixture_csv,
test_my_model_a_empty_fixture_csv,
test_my_model_concat_fixture_csv,
test_my_model_mixed_csv_yml,
test_my_model_missing_csv_yml,
test_my_model_duplicate_csv_yml,
)
class TestUnitTestsWithInlineCSV:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_model_a.sql": my_model_a_sql,
"my_model_b.sql": my_model_b_sql,
"test_my_model.yml": test_my_model_csv_yml + datetime_test,
}
def test_unit_test(self, project):
results = run_dbt(["run"])
assert len(results) == 3
# Select by model name
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
assert len(results) == 5
# Check error with invalid format key
write_file(
test_my_model_csv_yml + datetime_test_invalid_format_key,
project.project_root,
"models",
"test_my_model.yml",
)
with pytest.raises(YamlParseDictError):
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
# Check error with csv format defined but dict on rows
write_file(
test_my_model_csv_yml + datetime_test_invalid_csv_values,
project.project_root,
"models",
"test_my_model.yml",
)
with pytest.raises(ParsingError):
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
class TestUnitTestsWithFileCSV:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_model_a.sql": my_model_a_sql,
"my_model_b.sql": my_model_b_sql,
"test_my_model.yml": test_my_model_file_csv_yml + datetime_test,
}
@pytest.fixture(scope="class")
def tests(self):
return {
"fixtures": {
"test_my_model_fixture.csv": test_my_model_fixture_csv,
"test_my_model_a_fixture.csv": test_my_model_a_fixture_csv,
"test_my_model_b_fixture.csv": test_my_model_b_fixture_csv,
"test_my_model_basic_fixture.csv": test_my_model_basic_fixture_csv,
"test_my_model_a_numeric_fixture.csv": test_my_model_a_numeric_fixture_csv,
"test_my_model_a_empty_fixture.csv": test_my_model_a_empty_fixture_csv,
"test_my_model_concat_fixture.csv": test_my_model_concat_fixture_csv,
}
}
def test_unit_test(self, project):
results = run_dbt(["run"])
assert len(results) == 3
# Select by model name
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
assert len(results) == 5
# Check error with invalid format key
write_file(
test_my_model_file_csv_yml + datetime_test_invalid_format_key,
project.project_root,
"models",
"test_my_model.yml",
)
with pytest.raises(YamlParseDictError):
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
# Check error with csv format defined but dict on rows
write_file(
test_my_model_file_csv_yml + datetime_test_invalid_csv_values,
project.project_root,
"models",
"test_my_model.yml",
)
with pytest.raises(ParsingError):
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
class TestUnitTestsWithMixedCSV:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_model_a.sql": my_model_a_sql,
"my_model_b.sql": my_model_b_sql,
"test_my_model.yml": test_my_model_mixed_csv_yml + datetime_test,
}
@pytest.fixture(scope="class")
def tests(self):
return {
"fixtures": {
"test_my_model_fixture.csv": test_my_model_fixture_csv,
"test_my_model_a_fixture.csv": test_my_model_a_fixture_csv,
"test_my_model_b_fixture.csv": test_my_model_b_fixture_csv,
"test_my_model_basic_fixture.csv": test_my_model_basic_fixture_csv,
"test_my_model_a_numeric_fixture.csv": test_my_model_a_numeric_fixture_csv,
"test_my_model_a_empty_fixture.csv": test_my_model_a_empty_fixture_csv,
"test_my_model_concat_fixture.csv": test_my_model_concat_fixture_csv,
}
}
def test_unit_test(self, project):
results = run_dbt(["run"])
assert len(results) == 3
# Select by model name
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
assert len(results) == 5
# Check error with invalid format key
write_file(
test_my_model_mixed_csv_yml + datetime_test_invalid_format_key,
project.project_root,
"models",
"test_my_model.yml",
)
with pytest.raises(YamlParseDictError):
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
# Check error with csv format defined but dict on rows
write_file(
test_my_model_mixed_csv_yml + datetime_test_invalid_csv_values,
project.project_root,
"models",
"test_my_model.yml",
)
with pytest.raises(ParsingError):
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
class TestUnitTestsMissingCSVFile:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_model_a.sql": my_model_a_sql,
"my_model_b.sql": my_model_b_sql,
"test_my_model.yml": test_my_model_missing_csv_yml,
}
def test_missing(self, project):
results = run_dbt(["run"])
assert len(results) == 3
# Select by model name
expected_error = "Could not find fixture file fake_fixture for unit test"
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
assert expected_error in results[0].message
class TestUnitTestsDuplicateCSVFile:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_model_a.sql": my_model_a_sql,
"my_model_b.sql": my_model_b_sql,
"test_my_model.yml": test_my_model_duplicate_csv_yml,
}
@pytest.fixture(scope="class")
def tests(self):
return {
"fixtures": {
"one-folder": {
"test_my_model_basic_fixture.csv": test_my_model_basic_fixture_csv,
},
"another-folder": {
"test_my_model_basic_fixture.csv": test_my_model_basic_fixture_csv,
},
}
}
def test_duplicate(self, project):
results = run_dbt(["run"])
assert len(results) == 3
# Select by model name
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
expected_error = "Found multiple fixture files named test_my_model_basic_fixture"
assert expected_error in results[0].message
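`TestUnitTestsDuplicateCSVFile` places the same fixture filename under two folders. A hypothetical sketch of how such a collision can be detected by counting fixture file stems (paths here are illustrative):

```python
from collections import Counter
from pathlib import Path

# Hypothetical fixture search results: the same stem under two folders.
fixture_paths = [
    Path("tests/fixtures/one-folder/test_my_model_basic_fixture.csv"),
    Path("tests/fixtures/another-folder/test_my_model_basic_fixture.csv"),
]
counts = Counter(p.stem for p in fixture_paths)
duplicates = sorted(name for name, n in counts.items() if n > 1)
print(duplicates)  # ['test_my_model_basic_fixture']
```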


@@ -0,0 +1,135 @@
import os
import pytest
import shutil
from copy import deepcopy
from dbt.tests.util import (
run_dbt,
write_file,
write_config_file,
)
from fixtures import (
my_model_vars_sql,
my_model_a_sql,
my_model_b_sql,
test_my_model_simple_fixture_yml,
test_my_model_fixture_csv,
test_my_model_b_fixture_csv as test_my_model_fixture_csv_modified,
)
class UnitTestState:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_vars_sql,
"my_model_a.sql": my_model_a_sql,
"my_model_b.sql": my_model_b_sql,
"test_my_model.yml": test_my_model_simple_fixture_yml,
}
@pytest.fixture(scope="class")
def tests(self):
return {
"fixtures": {
"test_my_model_fixture.csv": test_my_model_fixture_csv,
}
}
@pytest.fixture(scope="class")
def project_config_update(self):
return {"vars": {"my_test": "my_test_var"}}
def copy_state(self, project_root):
state_path = os.path.join(project_root, "state")
if not os.path.exists(state_path):
os.makedirs(state_path)
shutil.copyfile(
f"{project_root}/target/manifest.json", f"{project_root}/state/manifest.json"
)
shutil.copyfile(
f"{project_root}/target/run_results.json", f"{project_root}/state/run_results.json"
)
class TestUnitTestStateModified(UnitTestState):
def test_state_modified(self, project):
run_dbt(["run"])
run_dbt(["test"], expect_pass=False)
self.copy_state(project.project_root)
# no changes
results = run_dbt(["test", "--select", "state:modified", "--state", "state"])
assert len(results) == 0
# change underlying fixture file
write_file(
test_my_model_fixture_csv_modified,
project.project_root,
"tests",
"fixtures",
"test_my_model_fixture.csv",
)
# TODO: remove --no-partial-parse as part of https://github.com/dbt-labs/dbt-core/issues/9067
results = run_dbt(
["--no-partial-parse", "test", "--select", "state:modified", "--state", "state"],
expect_pass=True,
)
assert len(results) == 1
assert results[0].node.name.endswith("test_depends_on_fixture")
# reset changes
self.copy_state(project.project_root)
# change unit test definition of a single unit test
with_changes = test_my_model_simple_fixture_yml.replace("{string_c: ab}", "{string_c: bc}")
write_config_file(with_changes, project.project_root, "models", "test_my_model.yml")
results = run_dbt(
["test", "--select", "state:modified", "--state", "state"], expect_pass=False
)
assert len(results) == 1
assert results[0].node.name.endswith("test_has_string_c_ab")
# change underlying model logic
write_config_file(
test_my_model_simple_fixture_yml, project.project_root, "models", "test_my_model.yml"
)
write_file(
my_model_vars_sql.replace("a+b as c,", "a + b as c,"),
project.project_root,
"models",
"my_model.sql",
)
results = run_dbt(
["test", "--select", "state:modified", "--state", "state"], expect_pass=False
)
assert len(results) == 4
class TestUnitTestRetry(UnitTestState):
def test_unit_test_retry(self, project):
run_dbt(["run"])
run_dbt(["test"], expect_pass=False)
self.copy_state(project.project_root)
results = run_dbt(["retry"], expect_pass=False)
assert len(results) == 1
class TestUnitTestDeferState(UnitTestState):
@pytest.fixture(scope="class")
def other_schema(self, unique_schema):
return unique_schema + "_other"
@pytest.fixture(scope="class")
def profiles_config_update(self, dbt_profile_target, unique_schema, other_schema):
outputs = {"default": dbt_profile_target, "otherschema": deepcopy(dbt_profile_target)}
outputs["default"]["schema"] = unique_schema
outputs["otherschema"]["schema"] = other_schema
return {"test": {"outputs": outputs, "target": "default"}}
def test_unit_test_defer_state(self, project):
run_dbt(["run", "--target", "otherschema"])
self.copy_state(project.project_root)
results = run_dbt(["test", "--defer", "--state", "state"], expect_pass=False)
assert len(results) == 4
assert sorted([r.status for r in results]) == ["fail", "pass", "pass", "pass"]


@@ -0,0 +1,232 @@
import pytest
from dbt.tests.util import (
run_dbt,
write_file,
get_manifest,
)
from dbt.exceptions import DuplicateResourceNameError, ParsingError
from fixtures import (
my_model_vars_sql,
my_model_a_sql,
my_model_b_sql,
test_my_model_yml,
datetime_test,
my_incremental_model_sql,
event_sql,
test_my_model_incremental_yml,
)
class TestUnitTests:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_vars_sql,
"my_model_a.sql": my_model_a_sql,
"my_model_b.sql": my_model_b_sql,
"test_my_model.yml": test_my_model_yml + datetime_test,
}
@pytest.fixture(scope="class")
def project_config_update(self):
return {"vars": {"my_test": "my_test_var"}}
def test_basic(self, project):
results = run_dbt(["run"])
assert len(results) == 3
# Select by model name
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
assert len(results) == 5
# Test select by test name
results = run_dbt(["test", "--select", "test_name:test_my_model_string_concat"])
assert len(results) == 1
# Select, method not specified
results = run_dbt(["test", "--select", "test_my_model_overrides"])
assert len(results) == 1
# Select using tag
results = run_dbt(["test", "--select", "tag:test_this"])
assert len(results) == 1
# Partial parsing... remove test
write_file(test_my_model_yml, project.project_root, "models", "test_my_model.yml")
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
assert len(results) == 4
# Partial parsing... put back removed test
write_file(
test_my_model_yml + datetime_test, project.project_root, "models", "test_my_model.yml"
)
results = run_dbt(["test", "--select", "my_model"], expect_pass=False)
assert len(results) == 5
manifest = get_manifest(project.project_root)
assert len(manifest.unit_tests) == 5
# Every unit test has a depends_on to the model it tests
for unit_test_definition in manifest.unit_tests.values():
assert unit_test_definition.depends_on.nodes[0] == "model.test.my_model"
# Check for duplicate unit test name
# this doesn't currently pass with partial parsing because of the root problem
# described in https://github.com/dbt-labs/dbt-core/issues/8982
write_file(
test_my_model_yml + datetime_test + datetime_test,
project.project_root,
"models",
"test_my_model.yml",
)
with pytest.raises(DuplicateResourceNameError):
run_dbt(["run", "--no-partial-parse", "--select", "my_model"])
class TestUnitTestIncrementalModel:
@pytest.fixture(scope="class")
def models(self):
return {
"my_incremental_model.sql": my_incremental_model_sql,
"events.sql": event_sql,
"test_my_incremental_model.yml": test_my_model_incremental_yml,
}
def test_basic(self, project):
results = run_dbt(["run"])
assert len(results) == 2
# Select by model name
results = run_dbt(["test", "--select", "my_incremental_model"], expect_pass=True)
assert len(results) == 2
my_new_model = """
select
my_favorite_seed.id,
a + b as c
from {{ ref('my_favorite_seed') }} as my_favorite_seed
inner join {{ ref('my_favorite_model') }} as my_favorite_model
on my_favorite_seed.id = my_favorite_model.id
"""
my_favorite_model = """
select
2 as id,
3 as b
"""
seed_my_favorite_seed = """id,a
1,5
2,4
3,3
4,2
5,1
"""
schema_yml_explicit_seed = """
unit_tests:
- name: t
model: my_new_model
given:
- input: ref('my_favorite_seed')
rows:
- {id: 1, a: 10}
- input: ref('my_favorite_model')
rows:
- {id: 1, b: 2}
expect:
rows:
- {id: 1, c: 12}
"""
schema_yml_implicit_seed = """
unit_tests:
- name: t
model: my_new_model
given:
- input: ref('my_favorite_seed')
- input: ref('my_favorite_model')
rows:
- {id: 1, b: 2}
expect:
rows:
- {id: 1, c: 7}
"""
schema_yml_nonexistent_seed = """
unit_tests:
- name: t
model: my_new_model
given:
- input: ref('my_second_favorite_seed')
- input: ref('my_favorite_model')
rows:
- {id: 1, b: 2}
expect:
rows:
- {id: 1, c: 7}
"""
class TestUnitTestExplicitSeed:
@pytest.fixture(scope="class")
def seeds(self):
return {"my_favorite_seed.csv": seed_my_favorite_seed}
@pytest.fixture(scope="class")
def models(self):
return {
"my_new_model.sql": my_new_model,
"my_favorite_model.sql": my_favorite_model,
"schema.yml": schema_yml_explicit_seed,
}
def test_explicit_seed(self, project):
run_dbt(["seed"])
run_dbt(["run"])
# Select by model name
results = run_dbt(["test", "--select", "my_new_model"], expect_pass=True)
assert len(results) == 1
class TestUnitTestImplicitSeed:
@pytest.fixture(scope="class")
def seeds(self):
return {"my_favorite_seed.csv": seed_my_favorite_seed}
@pytest.fixture(scope="class")
def models(self):
return {
"my_new_model.sql": my_new_model,
"my_favorite_model.sql": my_favorite_model,
"schema.yml": schema_yml_implicit_seed,
}
def test_implicit_seed(self, project):
run_dbt(["seed"])
run_dbt(["run"])
# Select by model name
results = run_dbt(["test", "--select", "my_new_model"], expect_pass=True)
assert len(results) == 1
class TestUnitTestNonexistentSeed:
@pytest.fixture(scope="class")
def seeds(self):
return {"my_favorite_seed.csv": seed_my_favorite_seed}
@pytest.fixture(scope="class")
def models(self):
return {
"my_new_model.sql": my_new_model,
"my_favorite_model.sql": my_favorite_model,
"schema.yml": schema_yml_nonexistent_seed,
}
def test_nonexistent_seed(self, project):
with pytest.raises(
ParsingError, match="Unable to find seed 'test.my_second_favorite_seed' for unit tests"
):
run_dbt(["test", "--select", "my_new_model"], expect_pass=False)


@@ -0,0 +1,114 @@
import pytest
from dbt.tests.util import run_dbt, get_unique_ids_in_results
from dbt.tests.fixtures.project import write_project_files
local_dependency__dbt_project_yml = """
name: 'local_dep'
version: '1.0'
seeds:
quote_columns: False
"""
local_dependency__schema_yml = """
sources:
- name: seed_source
schema: "{{ var('schema_override', target.schema) }}"
tables:
- name: "seed"
columns:
- name: id
tests:
- unique
unit_tests:
- name: test_dep_model_id
model: dep_model
given:
- input: ref('seed')
rows:
- {id: 1, name: Joe}
expect:
rows:
- {name_id: Joe_1}
"""
local_dependency__dep_model_sql = """
select name || '_' || id as name_id from {{ ref('seed') }}
"""
local_dependency__seed_csv = """id,name
1,Mary
2,Sam
3,John
"""
my_model_sql = """
select * from {{ ref('dep_model') }}
"""
my_model_schema_yml = """
unit_tests:
- name: test_my_model_name_id
model: my_model
given:
- input: ref('dep_model')
rows:
- {name_id: Joe_1}
expect:
rows:
- {name_id: Joe_1}
"""
class TestUnitTestingInDependency:
@pytest.fixture(scope="class", autouse=True)
def setUp(self, project_root):
local_dependency_files = {
"dbt_project.yml": local_dependency__dbt_project_yml,
"models": {
"schema.yml": local_dependency__schema_yml,
"dep_model.sql": local_dependency__dep_model_sql,
},
"seeds": {"seed.csv": local_dependency__seed_csv},
}
write_project_files(project_root, "local_dependency", local_dependency_files)
@pytest.fixture(scope="class")
def packages(self):
return {"packages": [{"local": "local_dependency"}]}
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"schema.yml": my_model_schema_yml,
}
def test_unit_test_in_dependency(self, project):
run_dbt(["deps"])
run_dbt(["seed"])
results = run_dbt(["run"])
assert len(results) == 2
results = run_dbt(["test"])
assert len(results) == 3
unique_ids = get_unique_ids_in_results(results)
assert "unit_test.local_dep.dep_model.test_dep_model_id" in unique_ids
results = run_dbt(["test", "--select", "test_type:unit"])
# two unit tests: one in the root package, one in the local_dep package
assert len(results) == 2
results = run_dbt(["test", "--select", "local_dep"])
# 2 tests in local_dep package
assert len(results) == 2
results = run_dbt(["test", "--select", "test"])
# 1 test in root package
assert len(results) == 1


@@ -0,0 +1,75 @@
import pytest
from dbt.tests.util import run_dbt
raw_customers_csv = """id,first_name,last_name,email
1,Michael,Perez,mperez0@chronoengine.com
2,Shawn,Mccoy,smccoy1@reddit.com
3,Kathleen,Payne,kpayne2@cargocollective.com
4,Jimmy,Cooper,jcooper3@cargocollective.com
5,Katherine,Rice,krice4@typepad.com
6,Sarah,Ryan,sryan5@gnu.org
7,Martin,Mcdonald,mmcdonald6@opera.com
8,Frank,Robinson,frobinson7@wunderground.com
9,Jennifer,Franklin,jfranklin8@mail.ru
10,Henry,Welch,hwelch9@list-manage.com
"""
schema_sources_yml = """
sources:
- name: seed_sources
schema: "{{ target.schema }}"
tables:
- name: raw_customers
columns:
- name: id
tests:
- not_null:
severity: "{{ 'error' if target.name == 'prod' else 'warn' }}"
- unique
- name: first_name
- name: last_name
- name: email
unit_tests:
- name: test_customers
model: customers
given:
- input: source('seed_sources', 'raw_customers')
rows:
- {id: 1, first_name: Emily}
expect:
rows:
- {id: 1, first_name: Emily}
"""
customers_sql = """
select * from {{ source('seed_sources', 'raw_customers') }}
"""
class TestUnitTestSourceInput:
@pytest.fixture(scope="class")
def seeds(self):
return {
"raw_customers.csv": raw_customers_csv,
}
@pytest.fixture(scope="class")
def models(self):
return {
"customers.sql": customers_sql,
"sources.yml": schema_sources_yml,
}
def test_source_input(self, project):
results = run_dbt(["seed"])
results = run_dbt(["run"])
assert len(results) == 1
results = run_dbt(["test", "--select", "test_type:unit"])
assert len(results) == 1
results = run_dbt(["build"])
assert len(results) == 5
result_unique_ids = [result.node.unique_id for result in results]
assert len(result_unique_ids) == 5
assert "unit_test.test.customers.test_customers" in result_unique_ids


@@ -28,10 +28,16 @@ from dbt.contracts.graph.nodes import (
TestMetadata,
ColumnInfo,
AccessType,
UnitTestDefinition,
)
from dbt.contracts.graph.manifest import Manifest, ManifestMetadata
from dbt.contracts.graph.saved_queries import QueryParams
from dbt.contracts.graph.unparsed import (
ExposureType,
Owner,
UnitTestInputFixture,
UnitTestOutputFixture,
)
from dbt.contracts.state import PreviousState
from dbt.node_types import NodeType
from dbt.graph.selector_methods import (
@@ -223,16 +229,16 @@ def make_macro(pkg, name, macro_sql, path=None, depends_on_macros=None):
def make_unique_test(pkg, test_model, column_name, path=None, refs=None, sources=None, tags=None):
return make_schema_test(pkg, "unique", test_model, {}, column_name=column_name)
return make_generic_test(pkg, "unique", test_model, {}, column_name=column_name)
def make_not_null_test(
pkg, test_model, column_name, path=None, refs=None, sources=None, tags=None
):
return make_schema_test(pkg, "not_null", test_model, {}, column_name=column_name)
return make_generic_test(pkg, "not_null", test_model, {}, column_name=column_name)
def make_generic_test(
pkg,
test_name,
test_model,
@@ -323,7 +329,33 @@ def make_schema_test(
)
def make_unit_test(
pkg,
test_name,
test_model,
):
input_fixture = UnitTestInputFixture(
input="ref('table_model')",
rows=[{"id": 1, "string_a": "a"}],
)
output_fixture = UnitTestOutputFixture(
rows=[{"id": 1, "string_a": "a"}],
)
return UnitTestDefinition(
name=test_name,
model=test_model,
package_name=pkg,
resource_type=NodeType.Unit,
path="unit_tests.yml",
original_file_path="models/unit_tests.yml",
unique_id=f"unit.{pkg}.{test_model.name}__{test_name}",
given=[input_fixture],
expect=output_fixture,
fqn=[pkg, test_model.name, test_name],
)
def make_singular_test(
pkg, name, sql, refs=None, sources=None, tags=None, path=None, config_kwargs=None
):
@@ -746,7 +778,7 @@ def ext_source_id_unique(ext_source):
@pytest.fixture
def view_test_nothing(view_model):
return make_singular_test(
"pkg",
"view_test_nothing",
'select * from {{ ref("view_model") }} limit 0',
@@ -754,6 +786,15 @@ def view_test_nothing(view_model):
)
@pytest.fixture
def unit_test_table_model(table_model):
return make_unit_test(
"pkg",
"unit_test_table_model",
table_model,
)
# Support dots as namespace separators
@pytest.fixture
def namespaced_seed():
@@ -818,6 +859,7 @@ def manifest(
macro_default_test_unique,
macro_test_not_null,
macro_default_test_not_null,
unit_test_table_model,
):
nodes = [
seed,
@@ -849,10 +891,12 @@ def manifest(
macro_test_not_null,
macro_default_test_not_null,
]
unit_tests = [unit_test_table_model]
manifest = Manifest(
nodes={n.unique_id: n for n in nodes},
sources={s.unique_id: s for s in sources},
macros={m.unique_id: m for m in macros},
unit_tests={t.unique_id: t for t in unit_tests},
semantic_models={},
docs={},
files={},
@@ -873,7 +917,8 @@ def search_manifest_using_method(manifest, method, selection):
| set(manifest.exposures)
| set(manifest.metrics)
| set(manifest.semantic_models)
| set(manifest.saved_queries)
| set(manifest.unit_tests),
selection,
)
results = {manifest.expect(uid).search_name for uid in selected}
@@ -908,6 +953,7 @@ def test_select_fqn(manifest):
"mynamespace.union_model",
"mynamespace.ephemeral_model",
"mynamespace.seed",
"unit_test_table_model",
}
assert search_manifest_using_method(manifest, method, "ext") == {"ext_model"}
# versions
@@ -934,6 +980,7 @@ def test_select_fqn(manifest):
"mynamespace.union_model",
"mynamespace.ephemeral_model",
"union_model",
"unit_test_table_model",
}
# multiple wildcards
assert search_manifest_using_method(manifest, method, "*unions*") == {
@@ -947,6 +994,7 @@ def test_select_fqn(manifest):
"table_model",
"table_model_py",
"table_model_csv",
"unit_test_table_model",
}
# wildcard and ? (matches exactly one character)
assert search_manifest_using_method(manifest, method, "*ext_m?del") == {"ext_model"}
@@ -1143,6 +1191,7 @@ def test_select_package(manifest):
"mynamespace.seed",
"mynamespace.ephemeral_model",
"mynamespace.union_model",
"unit_test_table_model",
}
assert search_manifest_using_method(manifest, method, "ext") == {
"ext_model",
@@ -1255,7 +1304,16 @@ def test_select_test_type(manifest):
"unique_view_model_id",
"unique_ext_raw_ext_source_id",
}
assert search_manifest_using_method(manifest, method, "data") == {"view_test_nothing"}
assert search_manifest_using_method(manifest, method, "data") == {
"view_test_nothing",
"unique_table_model_id",
"not_null_table_model_id",
"unique_view_model_id",
"unique_ext_raw_ext_source_id",
}
assert search_manifest_using_method(manifest, method, "unit") == {
"unit_test_table_model",
}
def test_select_version(manifest):


@@ -398,6 +398,7 @@ class ManifestTest(unittest.TestCase):
"docs": {},
"disabled": {},
"semantic_models": {},
"unit_tests": {},
"saved_queries": {},
},
)
@@ -582,6 +583,7 @@ class ManifestTest(unittest.TestCase):
},
"disabled": {},
"semantic_models": {},
"unit_tests": {},
"saved_queries": {},
},
)
@@ -921,6 +923,7 @@ class MixedManifestTest(unittest.TestCase):
"docs": {},
"disabled": {},
"semantic_models": {},
"unit_tests": {},
"saved_queries": {},
},
)


@@ -17,6 +17,7 @@ node_type_pluralizations = {
NodeType.Metric: "metrics",
NodeType.Group: "groups",
NodeType.SemanticModel: "semantic_models",
NodeType.Unit: "unit_tests",
NodeType.SavedQuery: "saved_queries",
}


@@ -176,13 +176,14 @@ class BaseParserTest(unittest.TestCase):
return FileBlock(file=source_file)
def assert_has_manifest_lengths(
self, manifest, macros=3, nodes=0, sources=0, docs=0, disabled=0, unit_tests=0
):
self.assertEqual(len(manifest.macros), macros)
self.assertEqual(len(manifest.nodes), nodes)
self.assertEqual(len(manifest.sources), sources)
self.assertEqual(len(manifest.docs), docs)
self.assertEqual(len(manifest.disabled), disabled)
self.assertEqual(len(manifest.unit_tests), unit_tests)
def assertEqualNodes(node_one, node_two):
@@ -371,8 +372,8 @@ class SchemaParserTest(BaseParserTest):
manifest=self.manifest,
)
def file_block_for(self, data, filename, searched="models"):
return super().file_block_for(data, filename, searched)
def yaml_block_for(self, test_yml: str, filename: str):
file_block = self.file_block_for(data=test_yml, filename=filename)


@@ -0,0 +1,183 @@
from dbt.contracts.graph.nodes import UnitTestDefinition, UnitTestConfig, DependsOn, NodeType
from dbt.exceptions import ParsingError
from dbt.parser import SchemaParser
from dbt.parser.unit_tests import UnitTestParser
from .utils import MockNode
from .test_parser import SchemaParserTest, assertEqualNodes
from unittest import mock
from dbt.contracts.graph.unparsed import UnitTestOutputFixture


UNIT_TEST_MODEL_NOT_FOUND_SOURCE = """
unit_tests:
  - name: test_my_model_doesnt_exist
    model: my_model_doesnt_exist
    description: "unit test description"
    given: []
    expect:
      rows:
        - {a: 1}
"""

UNIT_TEST_SOURCE = """
unit_tests:
  - name: test_my_model
    model: my_model
    description: "unit test description"
    given: []
    expect:
      rows:
        - {a: 1}
"""

UNIT_TEST_VERSIONED_MODEL_SOURCE = """
unit_tests:
  - name: test_my_model_versioned
    model: my_model_versioned.v1
    description: "unit test description"
    given: []
    expect:
      rows:
        - {a: 1}
"""

UNIT_TEST_CONFIG_SOURCE = """
unit_tests:
  - name: test_my_model
    model: my_model
    config:
      tags: "schema_tag"
      meta:
        meta_key: meta_value
        meta_jinja_key: '{{ 1 + 1 }}'
    description: "unit test description"
    given: []
    expect:
      rows:
        - {a: 1}
"""

UNIT_TEST_MULTIPLE_SOURCE = """
unit_tests:
  - name: test_my_model
    model: my_model
    description: "unit test description"
    given: []
    expect:
      rows:
        - {a: 1}
  - name: test_my_model2
    model: my_model
    description: "unit test description"
    given: []
    expect:
      rows:
        - {a: 1}
"""


class UnitTestParserTest(SchemaParserTest):
    def setUp(self):
        super().setUp()
        my_model_node = MockNode(
            package="snowplow",
            name="my_model",
            config=mock.MagicMock(enabled=True),
            schema="test_schema",
            refs=[],
            sources=[],
            patch_path=None,
        )
        self.manifest.nodes = {my_model_node.unique_id: my_model_node}
        self.parser = SchemaParser(
            project=self.snowplow_project_config,
            manifest=self.manifest,
            root_project=self.root_project_config,
        )

    def file_block_for(self, data, filename):
        return super().file_block_for(data, filename, "unit_tests")

    def test_basic_model_not_found(self):
        block = self.yaml_block_for(UNIT_TEST_MODEL_NOT_FOUND_SOURCE, "test_my_model.yml")

        with self.assertRaises(ParsingError):
            UnitTestParser(self.parser, block).parse()

    def test_basic(self):
        block = self.yaml_block_for(UNIT_TEST_SOURCE, "test_my_model.yml")

        UnitTestParser(self.parser, block).parse()

        self.assert_has_manifest_lengths(self.parser.manifest, nodes=1, unit_tests=1)
        unit_test = list(self.parser.manifest.unit_tests.values())[0]
        expected = UnitTestDefinition(
            name="test_my_model",
            model="my_model",
            resource_type=NodeType.Unit,
            package_name="snowplow",
            path=block.path.relative_path,
            original_file_path=block.path.original_file_path,
            unique_id="unit_test.snowplow.my_model.test_my_model",
            given=[],
            expect=UnitTestOutputFixture(rows=[{"a": 1}]),
            description="unit test description",
            overrides=None,
            depends_on=DependsOn(nodes=["model.snowplow.my_model"]),
            fqn=["snowplow", "my_model", "test_my_model"],
            config=UnitTestConfig(),
            schema="test_schema",
        )
        expected.build_unit_test_checksum("anything", "anything")
        assertEqualNodes(unit_test, expected)

    def test_unit_test_config(self):
        block = self.yaml_block_for(UNIT_TEST_CONFIG_SOURCE, "test_my_model.yml")
        self.root_project_config.unit_tests = {
            "snowplow": {"my_model": {"+tags": ["project_tag"]}}
        }

        UnitTestParser(self.parser, block).parse()

        self.assert_has_manifest_lengths(self.parser.manifest, nodes=1, unit_tests=1)
        unit_test = self.parser.manifest.unit_tests["unit_test.snowplow.my_model.test_my_model"]
        self.assertEqual(sorted(unit_test.config.tags), sorted(["schema_tag", "project_tag"]))
        self.assertEqual(unit_test.config.meta, {"meta_key": "meta_value", "meta_jinja_key": "2"})

    def test_unit_test_versioned_model(self):
        block = self.yaml_block_for(UNIT_TEST_VERSIONED_MODEL_SOURCE, "test_my_model.yml")
        my_model_versioned_node = MockNode(
            package="snowplow",
            name="my_model_versioned",
            config=mock.MagicMock(enabled=True),
            refs=[],
            sources=[],
            patch_path=None,
            version=1,
        )
        self.manifest.nodes[my_model_versioned_node.unique_id] = my_model_versioned_node

        UnitTestParser(self.parser, block).parse()

        self.assert_has_manifest_lengths(self.parser.manifest, nodes=2, unit_tests=1)
        unit_test = self.parser.manifest.unit_tests[
            "unit_test.snowplow.my_model_versioned.v1.test_my_model_versioned"
        ]
        self.assertEqual(len(unit_test.depends_on.nodes), 1)
        self.assertEqual(unit_test.depends_on.nodes[0], "model.snowplow.my_model_versioned.v1")

    def test_multiple_unit_tests(self):
        block = self.yaml_block_for(UNIT_TEST_MULTIPLE_SOURCE, "test_my_model.yml")

        UnitTestParser(self.parser, block).parse()

        self.assert_has_manifest_lengths(self.parser.manifest, nodes=1, unit_tests=2)
        for unit_test in self.parser.manifest.unit_tests.values():
            self.assertEqual(len(unit_test.depends_on.nodes), 1)
            self.assertEqual(unit_test.depends_on.nodes[0], "model.snowplow.my_model")


@@ -336,7 +336,7 @@ def MockNode(package, name, resource_type=None, **kwargs):
    version = kwargs.get("version")
    search_name = name if version is None else f"{name}.v{version}"
-    unique_id = f"{str(resource_type)}.{package}.{name}"
+    unique_id = f"{str(resource_type)}.{package}.{search_name}"
    node = mock.MagicMock(
        __class__=cls,
        resource_type=resource_type,
File diff suppressed because it is too large.